Diffstat (limited to 'contrib/llvm/lib/Analysis')
91 files changed, 73155 insertions, 0 deletions
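The files below implement LLVM's aggregated alias-analysis interface (AAResults) and the "aa-eval" precision evaluator. As a quick orientation, here is a minimal, hedged sketch of how a client queries the aggregation these sources build; the helper function and its name are illustrative only, while the query APIs it uses (AAResults::alias, AAResults::getModRefInfo, isNoModRef, MemoryLocation::get) and the AAResultsWrapperPass / AAManager entry points are the ones defined in the diff itself.

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Returns true when, according to the aggregated alias analyses, the store
// cannot be observed through the load and is neither read nor written by the
// given call site.
static bool storeIsIndependent(AAResults &AA, StoreInst &SI, LoadInst &LI,
                               ImmutableCallSite Call) {
  const MemoryLocation StoreLoc = MemoryLocation::get(&SI);
  const MemoryLocation LoadLoc = MemoryLocation::get(&LI);

  // Pointer-vs-pointer query: NoAlias means the two locations never overlap.
  if (AA.alias(StoreLoc, LoadLoc) != NoAlias)
    return false;

  // Call-vs-location query: NoModRef means the call neither reads nor writes
  // the stored-to memory.
  return isNoModRef(AA.getModRefInfo(Call, StoreLoc));
}

A legacy-PM pass would obtain AA as getAnalysis<AAResultsWrapperPass>().getAAResults(); under the new pass manager the equivalent is AM.getResult<AAManager>(F), exactly as AAEvaluator::run does below.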
diff --git a/contrib/llvm/lib/Analysis/AliasAnalysis.cpp b/contrib/llvm/lib/Analysis/AliasAnalysis.cpp new file mode 100644 index 000000000000..a6585df949f8 --- /dev/null +++ b/contrib/llvm/lib/Analysis/AliasAnalysis.cpp @@ -0,0 +1,840 @@ +//==- AliasAnalysis.cpp - Generic Alias Analysis Interface Implementation --==// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the generic AliasAnalysis interface which is used as the +// common interface used by all clients and implementations of alias analysis. +// +// This file also implements the default version of the AliasAnalysis interface +// that is to be used when no other implementation is specified.  This does some +// simple tests that detect obvious cases: two different global pointers cannot +// alias, a global cannot alias a malloc, two different mallocs cannot alias, +// etc. +// +// This alias analysis implementation really isn't very good for anything, but +// it is very fast, and makes a nice clean default implementation.  Because it +// handles lots of little corner cases, other, more complex, alias analysis +// implementations may choose to rely on this pass to resolve these simple and +// easy cases. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "llvm/Analysis/CFLSteensAliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/ObjCARCAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScopedNoAliasAA.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> +#include <cassert> +#include <functional> +#include <iterator> + +using namespace llvm; + +/// Allow disabling BasicAA from the AA results. This is particularly useful +/// when testing to isolate a single AA implementation. +static cl::opt<bool> DisableBasicAA("disable-basicaa", cl::Hidden, +                                    cl::init(false)); + +AAResults::AAResults(AAResults &&Arg) +    : TLI(Arg.TLI), AAs(std::move(Arg.AAs)), AADeps(std::move(Arg.AADeps)) { +  for (auto &AA : AAs) +    AA->setAAResults(this); +} + +AAResults::~AAResults() { +// FIXME; It would be nice to at least clear out the pointers back to this +// aggregation here, but we end up with non-nesting lifetimes in the legacy +// pass manager that prevent this from working. In the legacy pass manager +// we'll end up with dangling references here in some cases. 
+#if 0 +  for (auto &AA : AAs) +    AA->setAAResults(nullptr); +#endif +} + +bool AAResults::invalidate(Function &F, const PreservedAnalyses &PA, +                           FunctionAnalysisManager::Invalidator &Inv) { +  // Check if the AA manager itself has been invalidated. +  auto PAC = PA.getChecker<AAManager>(); +  if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>()) +    return true; // The manager needs to be blown away, clear everything. + +  // Check all of the dependencies registered. +  for (AnalysisKey *ID : AADeps) +    if (Inv.invalidate(ID, F, PA)) +      return true; + +  // Everything we depend on is still fine, so are we. Nothing to invalidate. +  return false; +} + +//===----------------------------------------------------------------------===// +// Default chaining methods +//===----------------------------------------------------------------------===// + +AliasResult AAResults::alias(const MemoryLocation &LocA, +                             const MemoryLocation &LocB) { +  for (const auto &AA : AAs) { +    auto Result = AA->alias(LocA, LocB); +    if (Result != MayAlias) +      return Result; +  } +  return MayAlias; +} + +bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc, +                                       bool OrLocal) { +  for (const auto &AA : AAs) +    if (AA->pointsToConstantMemory(Loc, OrLocal)) +      return true; + +  return false; +} + +ModRefInfo AAResults::getArgModRefInfo(ImmutableCallSite CS, unsigned ArgIdx) { +  ModRefInfo Result = ModRefInfo::ModRef; + +  for (const auto &AA : AAs) { +    Result = intersectModRef(Result, AA->getArgModRefInfo(CS, ArgIdx)); + +    // Early-exit the moment we reach the bottom of the lattice. +    if (isNoModRef(Result)) +      return ModRefInfo::NoModRef; +  } + +  return Result; +} + +ModRefInfo AAResults::getModRefInfo(Instruction *I, ImmutableCallSite Call) { +  // We may have two calls. +  if (auto CS = ImmutableCallSite(I)) { +    // Check if the two calls modify the same memory. +    return getModRefInfo(CS, Call); +  } else if (I->isFenceLike()) { +    // If this is a fence, just return ModRef. +    return ModRefInfo::ModRef; +  } else { +    // Otherwise, check if the call modifies or references the +    // location this memory access defines.  The best we can say +    // is that if the call references what this instruction +    // defines, it must be clobbered by this location. +    const MemoryLocation DefLoc = MemoryLocation::get(I); +    ModRefInfo MR = getModRefInfo(Call, DefLoc); +    if (isModOrRefSet(MR)) +      return setModAndRef(MR); +  } +  return ModRefInfo::NoModRef; +} + +ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, +                                    const MemoryLocation &Loc) { +  ModRefInfo Result = ModRefInfo::ModRef; + +  for (const auto &AA : AAs) { +    Result = intersectModRef(Result, AA->getModRefInfo(CS, Loc)); + +    // Early-exit the moment we reach the bottom of the lattice. +    if (isNoModRef(Result)) +      return ModRefInfo::NoModRef; +  } + +  // Try to refine the mod-ref info further using other API entry points to the +  // aggregate set of AA results. 
+  auto MRB = getModRefBehavior(CS); +  if (MRB == FMRB_DoesNotAccessMemory || +      MRB == FMRB_OnlyAccessesInaccessibleMem) +    return ModRefInfo::NoModRef; + +  if (onlyReadsMemory(MRB)) +    Result = clearMod(Result); +  else if (doesNotReadMemory(MRB)) +    Result = clearRef(Result); + +  if (onlyAccessesArgPointees(MRB) || onlyAccessesInaccessibleOrArgMem(MRB)) { +    bool DoesAlias = false; +    bool IsMustAlias = true; +    ModRefInfo AllArgsMask = ModRefInfo::NoModRef; +    if (doesAccessArgPointees(MRB)) { +      for (auto AI = CS.arg_begin(), AE = CS.arg_end(); AI != AE; ++AI) { +        const Value *Arg = *AI; +        if (!Arg->getType()->isPointerTy()) +          continue; +        unsigned ArgIdx = std::distance(CS.arg_begin(), AI); +        MemoryLocation ArgLoc = MemoryLocation::getForArgument(CS, ArgIdx, TLI); +        AliasResult ArgAlias = alias(ArgLoc, Loc); +        if (ArgAlias != NoAlias) { +          ModRefInfo ArgMask = getArgModRefInfo(CS, ArgIdx); +          DoesAlias = true; +          AllArgsMask = unionModRef(AllArgsMask, ArgMask); +        } +        // Conservatively clear IsMustAlias unless only MustAlias is found. +        IsMustAlias &= (ArgAlias == MustAlias); +      } +    } +    // Return NoModRef if no alias found with any argument. +    if (!DoesAlias) +      return ModRefInfo::NoModRef; +    // Logical & between other AA analyses and argument analysis. +    Result = intersectModRef(Result, AllArgsMask); +    // If only MustAlias found above, set Must bit. +    Result = IsMustAlias ? setMust(Result) : clearMust(Result); +  } + +  // If Loc is a constant memory location, the call definitely could not +  // modify the memory location. +  if (isModSet(Result) && pointsToConstantMemory(Loc, /*OrLocal*/ false)) +    Result = clearMod(Result); + +  return Result; +} + +ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS1, +                                    ImmutableCallSite CS2) { +  ModRefInfo Result = ModRefInfo::ModRef; + +  for (const auto &AA : AAs) { +    Result = intersectModRef(Result, AA->getModRefInfo(CS1, CS2)); + +    // Early-exit the moment we reach the bottom of the lattice. +    if (isNoModRef(Result)) +      return ModRefInfo::NoModRef; +  } + +  // Try to refine the mod-ref info further using other API entry points to the +  // aggregate set of AA results. + +  // If CS1 or CS2 are readnone, they don't interact. +  auto CS1B = getModRefBehavior(CS1); +  if (CS1B == FMRB_DoesNotAccessMemory) +    return ModRefInfo::NoModRef; + +  auto CS2B = getModRefBehavior(CS2); +  if (CS2B == FMRB_DoesNotAccessMemory) +    return ModRefInfo::NoModRef; + +  // If they both only read from memory, there is no dependence. +  if (onlyReadsMemory(CS1B) && onlyReadsMemory(CS2B)) +    return ModRefInfo::NoModRef; + +  // If CS1 only reads memory, the only dependence on CS2 can be +  // from CS1 reading memory written by CS2. +  if (onlyReadsMemory(CS1B)) +    Result = clearMod(Result); +  else if (doesNotReadMemory(CS1B)) +    Result = clearRef(Result); + +  // If CS2 only access memory through arguments, accumulate the mod/ref +  // information from CS1's references to the memory referenced by +  // CS2's arguments. 
+  if (onlyAccessesArgPointees(CS2B)) { +    if (!doesAccessArgPointees(CS2B)) +      return ModRefInfo::NoModRef; +    ModRefInfo R = ModRefInfo::NoModRef; +    bool IsMustAlias = true; +    for (auto I = CS2.arg_begin(), E = CS2.arg_end(); I != E; ++I) { +      const Value *Arg = *I; +      if (!Arg->getType()->isPointerTy()) +        continue; +      unsigned CS2ArgIdx = std::distance(CS2.arg_begin(), I); +      auto CS2ArgLoc = MemoryLocation::getForArgument(CS2, CS2ArgIdx, TLI); + +      // ArgModRefCS2 indicates what CS2 might do to CS2ArgLoc, and the +      // dependence of CS1 on that location is the inverse: +      // - If CS2 modifies location, dependence exists if CS1 reads or writes. +      // - If CS2 only reads location, dependence exists if CS1 writes. +      ModRefInfo ArgModRefCS2 = getArgModRefInfo(CS2, CS2ArgIdx); +      ModRefInfo ArgMask = ModRefInfo::NoModRef; +      if (isModSet(ArgModRefCS2)) +        ArgMask = ModRefInfo::ModRef; +      else if (isRefSet(ArgModRefCS2)) +        ArgMask = ModRefInfo::Mod; + +      // ModRefCS1 indicates what CS1 might do to CS2ArgLoc, and we use +      // above ArgMask to update dependence info. +      ModRefInfo ModRefCS1 = getModRefInfo(CS1, CS2ArgLoc); +      ArgMask = intersectModRef(ArgMask, ModRefCS1); + +      // Conservatively clear IsMustAlias unless only MustAlias is found. +      IsMustAlias &= isMustSet(ModRefCS1); + +      R = intersectModRef(unionModRef(R, ArgMask), Result); +      if (R == Result) { +        // On early exit, not all args were checked, cannot set Must. +        if (I + 1 != E) +          IsMustAlias = false; +        break; +      } +    } + +    if (isNoModRef(R)) +      return ModRefInfo::NoModRef; + +    // If MustAlias found above, set Must bit. +    return IsMustAlias ? setMust(R) : clearMust(R); +  } + +  // If CS1 only accesses memory through arguments, check if CS2 references +  // any of the memory referenced by CS1's arguments. If not, return NoModRef. +  if (onlyAccessesArgPointees(CS1B)) { +    if (!doesAccessArgPointees(CS1B)) +      return ModRefInfo::NoModRef; +    ModRefInfo R = ModRefInfo::NoModRef; +    bool IsMustAlias = true; +    for (auto I = CS1.arg_begin(), E = CS1.arg_end(); I != E; ++I) { +      const Value *Arg = *I; +      if (!Arg->getType()->isPointerTy()) +        continue; +      unsigned CS1ArgIdx = std::distance(CS1.arg_begin(), I); +      auto CS1ArgLoc = MemoryLocation::getForArgument(CS1, CS1ArgIdx, TLI); + +      // ArgModRefCS1 indicates what CS1 might do to CS1ArgLoc; if CS1 might +      // Mod CS1ArgLoc, then we care about either a Mod or a Ref by CS2. If +      // CS1 might Ref, then we care only about a Mod by CS2. +      ModRefInfo ArgModRefCS1 = getArgModRefInfo(CS1, CS1ArgIdx); +      ModRefInfo ModRefCS2 = getModRefInfo(CS2, CS1ArgLoc); +      if ((isModSet(ArgModRefCS1) && isModOrRefSet(ModRefCS2)) || +          (isRefSet(ArgModRefCS1) && isModSet(ModRefCS2))) +        R = intersectModRef(unionModRef(R, ArgModRefCS1), Result); + +      // Conservatively clear IsMustAlias unless only MustAlias is found. +      IsMustAlias &= isMustSet(ModRefCS2); + +      if (R == Result) { +        // On early exit, not all args were checked, cannot set Must. +        if (I + 1 != E) +          IsMustAlias = false; +        break; +      } +    } + +    if (isNoModRef(R)) +      return ModRefInfo::NoModRef; + +    // If MustAlias found above, set Must bit. +    return IsMustAlias ? 
setMust(R) : clearMust(R); +  } + +  return Result; +} + +FunctionModRefBehavior AAResults::getModRefBehavior(ImmutableCallSite CS) { +  FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior; + +  for (const auto &AA : AAs) { +    Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(CS)); + +    // Early-exit the moment we reach the bottom of the lattice. +    if (Result == FMRB_DoesNotAccessMemory) +      return Result; +  } + +  return Result; +} + +FunctionModRefBehavior AAResults::getModRefBehavior(const Function *F) { +  FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior; + +  for (const auto &AA : AAs) { +    Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(F)); + +    // Early-exit the moment we reach the bottom of the lattice. +    if (Result == FMRB_DoesNotAccessMemory) +      return Result; +  } + +  return Result; +} + +raw_ostream &llvm::operator<<(raw_ostream &OS, AliasResult AR) { +  switch (AR) { +  case NoAlias: +    OS << "NoAlias"; +    break; +  case MustAlias: +    OS << "MustAlias"; +    break; +  case MayAlias: +    OS << "MayAlias"; +    break; +  case PartialAlias: +    OS << "PartialAlias"; +    break; +  } +  return OS; +} + +//===----------------------------------------------------------------------===// +// Helper method implementation +//===----------------------------------------------------------------------===// + +ModRefInfo AAResults::getModRefInfo(const LoadInst *L, +                                    const MemoryLocation &Loc) { +  // Be conservative in the face of atomic. +  if (isStrongerThan(L->getOrdering(), AtomicOrdering::Unordered)) +    return ModRefInfo::ModRef; + +  // If the load address doesn't alias the given address, it doesn't read +  // or write the specified memory. +  if (Loc.Ptr) { +    AliasResult AR = alias(MemoryLocation::get(L), Loc); +    if (AR == NoAlias) +      return ModRefInfo::NoModRef; +    if (AR == MustAlias) +      return ModRefInfo::MustRef; +  } +  // Otherwise, a load just reads. +  return ModRefInfo::Ref; +} + +ModRefInfo AAResults::getModRefInfo(const StoreInst *S, +                                    const MemoryLocation &Loc) { +  // Be conservative in the face of atomic. +  if (isStrongerThan(S->getOrdering(), AtomicOrdering::Unordered)) +    return ModRefInfo::ModRef; + +  if (Loc.Ptr) { +    AliasResult AR = alias(MemoryLocation::get(S), Loc); +    // If the store address cannot alias the pointer in question, then the +    // specified memory cannot be modified by the store. +    if (AR == NoAlias) +      return ModRefInfo::NoModRef; + +    // If the pointer is a pointer to constant memory, then it could not have +    // been modified by this store. +    if (pointsToConstantMemory(Loc)) +      return ModRefInfo::NoModRef; + +    // If the store address aliases the pointer as must alias, set Must. +    if (AR == MustAlias) +      return ModRefInfo::MustMod; +  } + +  // Otherwise, a store just writes. +  return ModRefInfo::Mod; +} + +ModRefInfo AAResults::getModRefInfo(const FenceInst *S, const MemoryLocation &Loc) { +  // If we know that the location is a constant memory location, the fence +  // cannot modify this location. 
+  if (Loc.Ptr && pointsToConstantMemory(Loc)) +    return ModRefInfo::Ref; +  return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const VAArgInst *V, +                                    const MemoryLocation &Loc) { +  if (Loc.Ptr) { +    AliasResult AR = alias(MemoryLocation::get(V), Loc); +    // If the va_arg address cannot alias the pointer in question, then the +    // specified memory cannot be accessed by the va_arg. +    if (AR == NoAlias) +      return ModRefInfo::NoModRef; + +    // If the pointer is a pointer to constant memory, then it could not have +    // been modified by this va_arg. +    if (pointsToConstantMemory(Loc)) +      return ModRefInfo::NoModRef; + +    // If the va_arg aliases the pointer as must alias, set Must. +    if (AR == MustAlias) +      return ModRefInfo::MustModRef; +  } + +  // Otherwise, a va_arg reads and writes. +  return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const CatchPadInst *CatchPad, +                                    const MemoryLocation &Loc) { +  if (Loc.Ptr) { +    // If the pointer is a pointer to constant memory, +    // then it could not have been modified by this catchpad. +    if (pointsToConstantMemory(Loc)) +      return ModRefInfo::NoModRef; +  } + +  // Otherwise, a catchpad reads and writes. +  return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet, +                                    const MemoryLocation &Loc) { +  if (Loc.Ptr) { +    // If the pointer is a pointer to constant memory, +    // then it could not have been modified by this catchpad. +    if (pointsToConstantMemory(Loc)) +      return ModRefInfo::NoModRef; +  } + +  // Otherwise, a catchret reads and writes. +  return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX, +                                    const MemoryLocation &Loc) { +  // Acquire/Release cmpxchg has properties that matter for arbitrary addresses. +  if (isStrongerThanMonotonic(CX->getSuccessOrdering())) +    return ModRefInfo::ModRef; + +  if (Loc.Ptr) { +    AliasResult AR = alias(MemoryLocation::get(CX), Loc); +    // If the cmpxchg address does not alias the location, it does not access +    // it. +    if (AR == NoAlias) +      return ModRefInfo::NoModRef; + +    // If the cmpxchg address aliases the pointer as must alias, set Must. +    if (AR == MustAlias) +      return ModRefInfo::MustModRef; +  } + +  return ModRefInfo::ModRef; +} + +ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW, +                                    const MemoryLocation &Loc) { +  // Acquire/Release atomicrmw has properties that matter for arbitrary addresses. +  if (isStrongerThanMonotonic(RMW->getOrdering())) +    return ModRefInfo::ModRef; + +  if (Loc.Ptr) { +    AliasResult AR = alias(MemoryLocation::get(RMW), Loc); +    // If the atomicrmw address does not alias the location, it does not access +    // it. +    if (AR == NoAlias) +      return ModRefInfo::NoModRef; + +    // If the atomicrmw address aliases the pointer as must alias, set Must. +    if (AR == MustAlias) +      return ModRefInfo::MustModRef; +  } + +  return ModRefInfo::ModRef; +} + +/// Return information about whether a particular call site modifies +/// or reads the specified memory location \p MemLoc before instruction \p I +/// in a BasicBlock. An ordered basic block \p OBB can be used to speed up +/// instruction-ordering queries inside the BasicBlock containing \p I. 
+/// FIXME: this is really just shoring-up a deficiency in alias analysis. +/// BasicAA isn't willing to spend linear time determining whether an alloca +/// was captured before or after this particular call, while we are. However, +/// with a smarter AA in place, this test is just wasting compile time. +ModRefInfo AAResults::callCapturesBefore(const Instruction *I, +                                         const MemoryLocation &MemLoc, +                                         DominatorTree *DT, +                                         OrderedBasicBlock *OBB) { +  if (!DT) +    return ModRefInfo::ModRef; + +  const Value *Object = +      GetUnderlyingObject(MemLoc.Ptr, I->getModule()->getDataLayout()); +  if (!isIdentifiedObject(Object) || isa<GlobalValue>(Object) || +      isa<Constant>(Object)) +    return ModRefInfo::ModRef; + +  ImmutableCallSite CS(I); +  if (!CS.getInstruction() || CS.getInstruction() == Object) +    return ModRefInfo::ModRef; + +  if (PointerMayBeCapturedBefore(Object, /* ReturnCaptures */ true, +                                 /* StoreCaptures */ true, I, DT, +                                 /* include Object */ true, +                                 /* OrderedBasicBlock */ OBB)) +    return ModRefInfo::ModRef; + +  unsigned ArgNo = 0; +  ModRefInfo R = ModRefInfo::NoModRef; +  bool IsMustAlias = true; +  // Set flag only if no May found and all operands processed. +  for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end(); +       CI != CE; ++CI, ++ArgNo) { +    // Only look at the no-capture or byval pointer arguments.  If this +    // pointer were passed to arguments that were neither of these, then it +    // couldn't be no-capture. +    if (!(*CI)->getType()->isPointerTy() || +        (!CS.doesNotCapture(ArgNo) && +         ArgNo < CS.getNumArgOperands() && !CS.isByValArgument(ArgNo))) +      continue; + +    AliasResult AR = alias(MemoryLocation(*CI), MemoryLocation(Object)); +    // If this is a no-capture pointer argument, see if we can tell that it +    // is impossible to alias the pointer we're checking.  If not, we have to +    // assume that the call could touch the pointer, even though it doesn't +    // escape. +    if (AR != MustAlias) +      IsMustAlias = false; +    if (AR == NoAlias) +      continue; +    if (CS.doesNotAccessMemory(ArgNo)) +      continue; +    if (CS.onlyReadsMemory(ArgNo)) { +      R = ModRefInfo::Ref; +      continue; +    } +    // Not returning MustModRef since we have not seen all the arguments. +    return ModRefInfo::ModRef; +  } +  return IsMustAlias ? setMust(R) : clearMust(R); +} + +/// canBasicBlockModify - Return true if it is possible for execution of the +/// specified basic block to modify the location Loc. +/// +bool AAResults::canBasicBlockModify(const BasicBlock &BB, +                                    const MemoryLocation &Loc) { +  return canInstructionRangeModRef(BB.front(), BB.back(), Loc, ModRefInfo::Mod); +} + +/// canInstructionRangeModRef - Return true if it is possible for the +/// execution of the specified instructions to mod\ref (according to the +/// mode) the location Loc. The instructions to consider are all +/// of the instructions in the range of [I1,I2] INCLUSIVE. +/// I1 and I2 must be in the same basic block. 
+bool AAResults::canInstructionRangeModRef(const Instruction &I1, +                                          const Instruction &I2, +                                          const MemoryLocation &Loc, +                                          const ModRefInfo Mode) { +  assert(I1.getParent() == I2.getParent() && +         "Instructions not in same basic block!"); +  BasicBlock::const_iterator I = I1.getIterator(); +  BasicBlock::const_iterator E = I2.getIterator(); +  ++E;  // Convert from inclusive to exclusive range. + +  for (; I != E; ++I) // Check every instruction in range +    if (isModOrRefSet(intersectModRef(getModRefInfo(&*I, Loc), Mode))) +      return true; +  return false; +} + +// Provide a definition for the root virtual destructor. +AAResults::Concept::~Concept() = default; + +// Provide a definition for the static object used to identify passes. +AnalysisKey AAManager::Key; + +namespace { + +/// A wrapper pass for external alias analyses. This just squirrels away the +/// callback used to run any analyses and register their results. +struct ExternalAAWrapperPass : ImmutablePass { +  using CallbackT = std::function<void(Pass &, Function &, AAResults &)>; + +  CallbackT CB; + +  static char ID; + +  ExternalAAWrapperPass() : ImmutablePass(ID) { +    initializeExternalAAWrapperPassPass(*PassRegistry::getPassRegistry()); +  } + +  explicit ExternalAAWrapperPass(CallbackT CB) +      : ImmutablePass(ID), CB(std::move(CB)) { +    initializeExternalAAWrapperPassPass(*PassRegistry::getPassRegistry()); +  } + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesAll(); +  } +}; + +} // end anonymous namespace + +char ExternalAAWrapperPass::ID = 0; + +INITIALIZE_PASS(ExternalAAWrapperPass, "external-aa", "External Alias Analysis", +                false, true) + +ImmutablePass * +llvm::createExternalAAWrapperPass(ExternalAAWrapperPass::CallbackT Callback) { +  return new ExternalAAWrapperPass(std::move(Callback)); +} + +AAResultsWrapperPass::AAResultsWrapperPass() : FunctionPass(ID) { +  initializeAAResultsWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +char AAResultsWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(AAResultsWrapperPass, "aa", +                      "Function Alias Analysis Results", false, true) +INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CFLAndersAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CFLSteensAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ExternalAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ObjCARCAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScopedNoAliasAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TypeBasedAAWrapperPass) +INITIALIZE_PASS_END(AAResultsWrapperPass, "aa", +                    "Function Alias Analysis Results", false, true) + +FunctionPass *llvm::createAAResultsWrapperPass() { +  return new AAResultsWrapperPass(); +} + +/// Run the wrapper pass to rebuild an aggregation over known AA passes. +/// +/// This is the legacy pass manager's interface to the new-style AA results +/// aggregation object. Because this is somewhat shoe-horned into the legacy +/// pass manager, we hard code all the specific alias analyses available into +/// it. While the particular set enabled is configured via commandline flags, +/// adding a new alias analysis to LLVM will require adding support for it to +/// this list. +bool AAResultsWrapperPass::runOnFunction(Function &F) { +  // NB! 
This *must* be reset before adding new AA results to the new +  // AAResults object because in the legacy pass manager, each instance +  // of these will refer to the *same* immutable analyses, registering and +  // unregistering themselves with them. We need to carefully tear down the +  // previous object first, in this case replacing it with an empty one, before +  // registering new results. +  AAR.reset( +      new AAResults(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI())); + +  // BasicAA is always available for function analyses. Also, we add it first +  // so that it can trump TBAA results when it proves MustAlias. +  // FIXME: TBAA should have an explicit mode to support this and then we +  // should reconsider the ordering here. +  if (!DisableBasicAA) +    AAR->addAAResult(getAnalysis<BasicAAWrapperPass>().getResult()); + +  // Populate the results with the currently available AAs. +  if (auto *WrapperPass = getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = getAnalysisIfAvailable<TypeBasedAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = +          getAnalysisIfAvailable<objcarc::ObjCARCAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = getAnalysisIfAvailable<GlobalsAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = getAnalysisIfAvailable<SCEVAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = getAnalysisIfAvailable<CFLAndersAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = getAnalysisIfAvailable<CFLSteensAAWrapperPass>()) +    AAR->addAAResult(WrapperPass->getResult()); + +  // If available, run an external AA providing callback over the results as +  // well. +  if (auto *WrapperPass = getAnalysisIfAvailable<ExternalAAWrapperPass>()) +    if (WrapperPass->CB) +      WrapperPass->CB(*this, F, *AAR); + +  // Analyses don't mutate the IR, so return false. +  return false; +} + +void AAResultsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<BasicAAWrapperPass>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); + +  // We also need to mark all the alias analysis passes we will potentially +  // probe in runOnFunction as used here to ensure the legacy pass manager +  // preserves them. This hard coding of lists of alias analyses is specific to +  // the legacy pass manager. +  AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>(); +  AU.addUsedIfAvailable<TypeBasedAAWrapperPass>(); +  AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>(); +  AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); +  AU.addUsedIfAvailable<SCEVAAWrapperPass>(); +  AU.addUsedIfAvailable<CFLAndersAAWrapperPass>(); +  AU.addUsedIfAvailable<CFLSteensAAWrapperPass>(); +} + +AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, +                                        BasicAAResult &BAR) { +  AAResults AAR(P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()); + +  // Add in our explicitly constructed BasicAA results. +  if (!DisableBasicAA) +    AAR.addAAResult(BAR); + +  // Populate the results with the other currently available AAs. 
+  if (auto *WrapperPass = +          P.getAnalysisIfAvailable<ScopedNoAliasAAWrapperPass>()) +    AAR.addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = P.getAnalysisIfAvailable<TypeBasedAAWrapperPass>()) +    AAR.addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = +          P.getAnalysisIfAvailable<objcarc::ObjCARCAAWrapperPass>()) +    AAR.addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = P.getAnalysisIfAvailable<GlobalsAAWrapperPass>()) +    AAR.addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLAndersAAWrapperPass>()) +    AAR.addAAResult(WrapperPass->getResult()); +  if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLSteensAAWrapperPass>()) +    AAR.addAAResult(WrapperPass->getResult()); + +  return AAR; +} + +bool llvm::isNoAliasCall(const Value *V) { +  if (auto CS = ImmutableCallSite(V)) +    return CS.hasRetAttr(Attribute::NoAlias); +  return false; +} + +bool llvm::isNoAliasArgument(const Value *V) { +  if (const Argument *A = dyn_cast<Argument>(V)) +    return A->hasNoAliasAttr(); +  return false; +} + +bool llvm::isIdentifiedObject(const Value *V) { +  if (isa<AllocaInst>(V)) +    return true; +  if (isa<GlobalValue>(V) && !isa<GlobalAlias>(V)) +    return true; +  if (isNoAliasCall(V)) +    return true; +  if (const Argument *A = dyn_cast<Argument>(V)) +    return A->hasNoAliasAttr() || A->hasByValAttr(); +  return false; +} + +bool llvm::isIdentifiedFunctionLocal(const Value *V) { +  return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V); +} + +void llvm::getAAResultsAnalysisUsage(AnalysisUsage &AU) { +  // This function needs to be in sync with llvm::createLegacyPMAAResults -- if +  // more alias analyses are added to llvm::createLegacyPMAAResults, they need +  // to be added here also. +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +  AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>(); +  AU.addUsedIfAvailable<TypeBasedAAWrapperPass>(); +  AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>(); +  AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); +  AU.addUsedIfAvailable<CFLAndersAAWrapperPass>(); +  AU.addUsedIfAvailable<CFLSteensAAWrapperPass>(); +} diff --git a/contrib/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp b/contrib/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp new file mode 100644 index 000000000000..764ae9160350 --- /dev/null +++ b/contrib/llvm/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -0,0 +1,433 @@ +//===- AliasAnalysisEvaluator.cpp - Alias Analysis Accuracy Evaluator -----===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysisEvaluator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +static cl::opt<bool> PrintAll("print-all-alias-modref-info", cl::ReallyHidden); + +static cl::opt<bool> PrintNoAlias("print-no-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintMayAlias("print-may-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintPartialAlias("print-partial-aliases", cl::ReallyHidden); +static cl::opt<bool> PrintMustAlias("print-must-aliases", cl::ReallyHidden); + +static cl::opt<bool> PrintNoModRef("print-no-modref", cl::ReallyHidden); +static cl::opt<bool> PrintRef("print-ref", cl::ReallyHidden); +static cl::opt<bool> PrintMod("print-mod", cl::ReallyHidden); +static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden); +static cl::opt<bool> PrintMust("print-must", cl::ReallyHidden); +static cl::opt<bool> PrintMustRef("print-mustref", cl::ReallyHidden); +static cl::opt<bool> PrintMustMod("print-mustmod", cl::ReallyHidden); +static cl::opt<bool> PrintMustModRef("print-mustmodref", cl::ReallyHidden); + +static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden); + +static void PrintResults(AliasResult AR, bool P, const Value *V1, +                         const Value *V2, const Module *M) { +  if (PrintAll || P) { +    std::string o1, o2; +    { +      raw_string_ostream os1(o1), os2(o2); +      V1->printAsOperand(os1, true, M); +      V2->printAsOperand(os2, true, M); +    } + +    if (o2 < o1) +      std::swap(o1, o2); +    errs() << "  " << AR << ":\t" << o1 << ", " << o2 << "\n"; +  } +} + +static inline void PrintModRefResults(const char *Msg, bool P, Instruction *I, +                                      Value *Ptr, Module *M) { +  if (PrintAll || P) { +    errs() << "  " << Msg << ":  Ptr: "; +    Ptr->printAsOperand(errs(), true, M); +    errs() << "\t<->" << *I << '\n'; +  } +} + +static inline void PrintModRefResults(const char *Msg, bool P, CallSite CSA, +                                      CallSite CSB, Module *M) { +  if (PrintAll || P) { +    errs() << "  " << Msg << ": " << *CSA.getInstruction() << " <-> " +           << *CSB.getInstruction() << '\n'; +  } +} + +static inline void PrintLoadStoreResults(AliasResult AR, bool P, +                                         const Value *V1, const Value *V2, +                                         const Module *M) { +  if (PrintAll || P) { +    errs() << "  " << AR << ": " << *V1 << " <-> " << *V2 << '\n'; +  } +} + +static inline bool isInterestingPointer(Value *V) { +  return V->getType()->isPointerTy() +      && !isa<ConstantPointerNull>(V); +} + +PreservedAnalyses AAEvaluator::run(Function &F, FunctionAnalysisManager &AM) { +  runInternal(F, AM.getResult<AAManager>(F)); +  return PreservedAnalyses::all(); +} + +void AAEvaluator::runInternal(Function &F, AAResults &AA) { +  const DataLayout &DL = F.getParent()->getDataLayout(); + +  ++FunctionCount; + +  SetVector<Value *> Pointers; +  SmallSetVector<CallSite, 16> CallSites; +  SetVector<Value *> Loads; +  SetVector<Value *> Stores; + 
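  // [Editorial aside - not part of the imported file.] A hedged usage note:
  // this evaluator is normally driven through `opt` via the legacy pass
  // registered below under the name "aa-eval", for example
  //
  //   opt -aa-eval -print-all-alias-modref-info -disable-output input.ll
  //
  // which exercises the exhaustive pairwise queries issued by this function
  // and prints the per-query lines plus the percentage summary emitted from
  // ~AAEvaluator(). The flag name comes from the cl::opt declarations above;
  // -disable-output is the usual opt switch for analysis-only runs.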
+  for (auto &I : F.args()) +    if (I.getType()->isPointerTy())    // Add all pointer arguments. +      Pointers.insert(&I); + +  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { +    if (I->getType()->isPointerTy()) // Add all pointer instructions. +      Pointers.insert(&*I); +    if (EvalAAMD && isa<LoadInst>(&*I)) +      Loads.insert(&*I); +    if (EvalAAMD && isa<StoreInst>(&*I)) +      Stores.insert(&*I); +    Instruction &Inst = *I; +    if (auto CS = CallSite(&Inst)) { +      Value *Callee = CS.getCalledValue(); +      // Skip actual functions for direct function calls. +      if (!isa<Function>(Callee) && isInterestingPointer(Callee)) +        Pointers.insert(Callee); +      // Consider formals. +      for (Use &DataOp : CS.data_ops()) +        if (isInterestingPointer(DataOp)) +          Pointers.insert(DataOp); +      CallSites.insert(CS); +    } else { +      // Consider all operands. +      for (Instruction::op_iterator OI = Inst.op_begin(), OE = Inst.op_end(); +           OI != OE; ++OI) +        if (isInterestingPointer(*OI)) +          Pointers.insert(*OI); +    } +  } + +  if (PrintAll || PrintNoAlias || PrintMayAlias || PrintPartialAlias || +      PrintMustAlias || PrintNoModRef || PrintMod || PrintRef || PrintModRef) +    errs() << "Function: " << F.getName() << ": " << Pointers.size() +           << " pointers, " << CallSites.size() << " call sites\n"; + +  // iterate over the worklist, and run the full (n^2)/2 disambiguations +  for (SetVector<Value *>::iterator I1 = Pointers.begin(), E = Pointers.end(); +       I1 != E; ++I1) { +    uint64_t I1Size = MemoryLocation::UnknownSize; +    Type *I1ElTy = cast<PointerType>((*I1)->getType())->getElementType(); +    if (I1ElTy->isSized()) I1Size = DL.getTypeStoreSize(I1ElTy); + +    for (SetVector<Value *>::iterator I2 = Pointers.begin(); I2 != I1; ++I2) { +      uint64_t I2Size = MemoryLocation::UnknownSize; +      Type *I2ElTy =cast<PointerType>((*I2)->getType())->getElementType(); +      if (I2ElTy->isSized()) I2Size = DL.getTypeStoreSize(I2ElTy); + +      AliasResult AR = AA.alias(*I1, I1Size, *I2, I2Size); +      switch (AR) { +      case NoAlias: +        PrintResults(AR, PrintNoAlias, *I1, *I2, F.getParent()); +        ++NoAliasCount; +        break; +      case MayAlias: +        PrintResults(AR, PrintMayAlias, *I1, *I2, F.getParent()); +        ++MayAliasCount; +        break; +      case PartialAlias: +        PrintResults(AR, PrintPartialAlias, *I1, *I2, F.getParent()); +        ++PartialAliasCount; +        break; +      case MustAlias: +        PrintResults(AR, PrintMustAlias, *I1, *I2, F.getParent()); +        ++MustAliasCount; +        break; +      } +    } +  } + +  if (EvalAAMD) { +    // iterate over all pairs of load, store +    for (Value *Load : Loads) { +      for (Value *Store : Stores) { +        AliasResult AR = AA.alias(MemoryLocation::get(cast<LoadInst>(Load)), +                                  MemoryLocation::get(cast<StoreInst>(Store))); +        switch (AR) { +        case NoAlias: +          PrintLoadStoreResults(AR, PrintNoAlias, Load, Store, F.getParent()); +          ++NoAliasCount; +          break; +        case MayAlias: +          PrintLoadStoreResults(AR, PrintMayAlias, Load, Store, F.getParent()); +          ++MayAliasCount; +          break; +        case PartialAlias: +          PrintLoadStoreResults(AR, PrintPartialAlias, Load, Store, F.getParent()); +          ++PartialAliasCount; +          break; +        case MustAlias: +          PrintLoadStoreResults(AR, 
PrintMustAlias, Load, Store, F.getParent()); +          ++MustAliasCount; +          break; +        } +      } +    } + +    // iterate over all pairs of store, store +    for (SetVector<Value *>::iterator I1 = Stores.begin(), E = Stores.end(); +         I1 != E; ++I1) { +      for (SetVector<Value *>::iterator I2 = Stores.begin(); I2 != I1; ++I2) { +        AliasResult AR = AA.alias(MemoryLocation::get(cast<StoreInst>(*I1)), +                                  MemoryLocation::get(cast<StoreInst>(*I2))); +        switch (AR) { +        case NoAlias: +          PrintLoadStoreResults(AR, PrintNoAlias, *I1, *I2, F.getParent()); +          ++NoAliasCount; +          break; +        case MayAlias: +          PrintLoadStoreResults(AR, PrintMayAlias, *I1, *I2, F.getParent()); +          ++MayAliasCount; +          break; +        case PartialAlias: +          PrintLoadStoreResults(AR, PrintPartialAlias, *I1, *I2, F.getParent()); +          ++PartialAliasCount; +          break; +        case MustAlias: +          PrintLoadStoreResults(AR, PrintMustAlias, *I1, *I2, F.getParent()); +          ++MustAliasCount; +          break; +        } +      } +    } +  } + +  // Mod/ref alias analysis: compare all pairs of calls and values +  for (CallSite C : CallSites) { +    Instruction *I = C.getInstruction(); + +    for (auto Pointer : Pointers) { +      uint64_t Size = MemoryLocation::UnknownSize; +      Type *ElTy = cast<PointerType>(Pointer->getType())->getElementType(); +      if (ElTy->isSized()) Size = DL.getTypeStoreSize(ElTy); + +      switch (AA.getModRefInfo(C, Pointer, Size)) { +      case ModRefInfo::NoModRef: +        PrintModRefResults("NoModRef", PrintNoModRef, I, Pointer, +                           F.getParent()); +        ++NoModRefCount; +        break; +      case ModRefInfo::Mod: +        PrintModRefResults("Just Mod", PrintMod, I, Pointer, F.getParent()); +        ++ModCount; +        break; +      case ModRefInfo::Ref: +        PrintModRefResults("Just Ref", PrintRef, I, Pointer, F.getParent()); +        ++RefCount; +        break; +      case ModRefInfo::ModRef: +        PrintModRefResults("Both ModRef", PrintModRef, I, Pointer, +                           F.getParent()); +        ++ModRefCount; +        break; +      case ModRefInfo::Must: +        PrintModRefResults("Must", PrintMust, I, Pointer, F.getParent()); +        ++MustCount; +        break; +      case ModRefInfo::MustMod: +        PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, I, Pointer, +                           F.getParent()); +        ++MustModCount; +        break; +      case ModRefInfo::MustRef: +        PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, I, Pointer, +                           F.getParent()); +        ++MustRefCount; +        break; +      case ModRefInfo::MustModRef: +        PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, I, +                           Pointer, F.getParent()); +        ++MustModRefCount; +        break; +      } +    } +  } + +  // Mod/ref alias analysis: compare all pairs of calls +  for (auto C = CallSites.begin(), Ce = CallSites.end(); C != Ce; ++C) { +    for (auto D = CallSites.begin(); D != Ce; ++D) { +      if (D == C) +        continue; +      switch (AA.getModRefInfo(*C, *D)) { +      case ModRefInfo::NoModRef: +        PrintModRefResults("NoModRef", PrintNoModRef, *C, *D, F.getParent()); +        ++NoModRefCount; +        break; +      case ModRefInfo::Mod: +        PrintModRefResults("Just Mod", PrintMod, *C, *D, F.getParent()); +     
   ++ModCount; +        break; +      case ModRefInfo::Ref: +        PrintModRefResults("Just Ref", PrintRef, *C, *D, F.getParent()); +        ++RefCount; +        break; +      case ModRefInfo::ModRef: +        PrintModRefResults("Both ModRef", PrintModRef, *C, *D, F.getParent()); +        ++ModRefCount; +        break; +      case ModRefInfo::Must: +        PrintModRefResults("Must", PrintMust, *C, *D, F.getParent()); +        ++MustCount; +        break; +      case ModRefInfo::MustMod: +        PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, *C, *D, +                           F.getParent()); +        ++MustModCount; +        break; +      case ModRefInfo::MustRef: +        PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, *C, *D, +                           F.getParent()); +        ++MustRefCount; +        break; +      case ModRefInfo::MustModRef: +        PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, *C, *D, +                           F.getParent()); +        ++MustModRefCount; +        break; +      } +    } +  } +} + +static void PrintPercent(int64_t Num, int64_t Sum) { +  errs() << "(" << Num * 100LL / Sum << "." << ((Num * 1000LL / Sum) % 10) +         << "%)\n"; +} + +AAEvaluator::~AAEvaluator() { +  if (FunctionCount == 0) +    return; + +  int64_t AliasSum = +      NoAliasCount + MayAliasCount + PartialAliasCount + MustAliasCount; +  errs() << "===== Alias Analysis Evaluator Report =====\n"; +  if (AliasSum == 0) { +    errs() << "  Alias Analysis Evaluator Summary: No pointers!\n"; +  } else { +    errs() << "  " << AliasSum << " Total Alias Queries Performed\n"; +    errs() << "  " << NoAliasCount << " no alias responses "; +    PrintPercent(NoAliasCount, AliasSum); +    errs() << "  " << MayAliasCount << " may alias responses "; +    PrintPercent(MayAliasCount, AliasSum); +    errs() << "  " << PartialAliasCount << " partial alias responses "; +    PrintPercent(PartialAliasCount, AliasSum); +    errs() << "  " << MustAliasCount << " must alias responses "; +    PrintPercent(MustAliasCount, AliasSum); +    errs() << "  Alias Analysis Evaluator Pointer Alias Summary: " +           << NoAliasCount * 100 / AliasSum << "%/" +           << MayAliasCount * 100 / AliasSum << "%/" +           << PartialAliasCount * 100 / AliasSum << "%/" +           << MustAliasCount * 100 / AliasSum << "%\n"; +  } + +  // Display the summary for mod/ref analysis +  int64_t ModRefSum = NoModRefCount + RefCount + ModCount + ModRefCount + +                      MustCount + MustRefCount + MustModCount + MustModRefCount; +  if (ModRefSum == 0) { +    errs() << "  Alias Analysis Mod/Ref Evaluator Summary: no " +              "mod/ref!\n"; +  } else { +    errs() << "  " << ModRefSum << " Total ModRef Queries Performed\n"; +    errs() << "  " << NoModRefCount << " no mod/ref responses "; +    PrintPercent(NoModRefCount, ModRefSum); +    errs() << "  " << ModCount << " mod responses "; +    PrintPercent(ModCount, ModRefSum); +    errs() << "  " << RefCount << " ref responses "; +    PrintPercent(RefCount, ModRefSum); +    errs() << "  " << ModRefCount << " mod & ref responses "; +    PrintPercent(ModRefCount, ModRefSum); +    errs() << "  " << MustCount << " must responses "; +    PrintPercent(MustCount, ModRefSum); +    errs() << "  " << MustModCount << " must mod responses "; +    PrintPercent(MustModCount, ModRefSum); +    errs() << "  " << MustRefCount << " must ref responses "; +    PrintPercent(MustRefCount, ModRefSum); +    errs() << "  " << MustModRefCount << " 
must mod & ref responses "; +    PrintPercent(MustModRefCount, ModRefSum); +    errs() << "  Alias Analysis Evaluator Mod/Ref Summary: " +           << NoModRefCount * 100 / ModRefSum << "%/" +           << ModCount * 100 / ModRefSum << "%/" << RefCount * 100 / ModRefSum +           << "%/" << ModRefCount * 100 / ModRefSum << "%/" +           << MustCount * 100 / ModRefSum << "%/" +           << MustRefCount * 100 / ModRefSum << "%/" +           << MustModCount * 100 / ModRefSum << "%/" +           << MustModRefCount * 100 / ModRefSum << "%\n"; +  } +} + +namespace llvm { +class AAEvalLegacyPass : public FunctionPass { +  std::unique_ptr<AAEvaluator> P; + +public: +  static char ID; // Pass identification, replacement for typeid +  AAEvalLegacyPass() : FunctionPass(ID) { +    initializeAAEvalLegacyPassPass(*PassRegistry::getPassRegistry()); +  } + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addRequired<AAResultsWrapperPass>(); +    AU.setPreservesAll(); +  } + +  bool doInitialization(Module &M) override { +    P.reset(new AAEvaluator()); +    return false; +  } + +  bool runOnFunction(Function &F) override { +    P->runInternal(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); +    return false; +  } +  bool doFinalization(Module &M) override { +    P.reset(); +    return false; +  } +}; +} + +char AAEvalLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AAEvalLegacyPass, "aa-eval", +                      "Exhaustive Alias Analysis Precision Evaluator", false, +                      true) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(AAEvalLegacyPass, "aa-eval", +                    "Exhaustive Alias Analysis Precision Evaluator", false, +                    true) + +FunctionPass *llvm::createAAEvalPass() { return new AAEvalLegacyPass(); } diff --git a/contrib/llvm/lib/Analysis/AliasAnalysisSummary.cpp b/contrib/llvm/lib/Analysis/AliasAnalysisSummary.cpp new file mode 100644 index 000000000000..2b4879453beb --- /dev/null +++ b/contrib/llvm/lib/Analysis/AliasAnalysisSummary.cpp @@ -0,0 +1,103 @@ +#include "AliasAnalysisSummary.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Compiler.h" + +namespace llvm { +namespace cflaa { + +namespace { +const unsigned AttrEscapedIndex = 0; +const unsigned AttrUnknownIndex = 1; +const unsigned AttrGlobalIndex = 2; +const unsigned AttrCallerIndex = 3; +const unsigned AttrFirstArgIndex = 4; +const unsigned AttrLastArgIndex = NumAliasAttrs; +const unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; + +// It would be *slightly* prettier if we changed these to AliasAttrs, but it +// seems that both GCC and MSVC emit dynamic initializers for const bitsets. 
+using AliasAttr = unsigned; +const AliasAttr AttrNone = 0; +const AliasAttr AttrEscaped = 1 << AttrEscapedIndex; +const AliasAttr AttrUnknown = 1 << AttrUnknownIndex; +const AliasAttr AttrGlobal = 1 << AttrGlobalIndex; +const AliasAttr AttrCaller = 1 << AttrCallerIndex; +const AliasAttr ExternalAttrMask = AttrEscaped | AttrUnknown | AttrGlobal; +} + +AliasAttrs getAttrNone() { return AttrNone; } + +AliasAttrs getAttrUnknown() { return AttrUnknown; } +bool hasUnknownAttr(AliasAttrs Attr) { return Attr.test(AttrUnknownIndex); } + +AliasAttrs getAttrCaller() { return AttrCaller; } +bool hasCallerAttr(AliasAttrs Attr) { return Attr.test(AttrCaller); } +bool hasUnknownOrCallerAttr(AliasAttrs Attr) { +  return Attr.test(AttrUnknownIndex) || Attr.test(AttrCallerIndex); +} + +AliasAttrs getAttrEscaped() { return AttrEscaped; } +bool hasEscapedAttr(AliasAttrs Attr) { return Attr.test(AttrEscapedIndex); } + +static AliasAttr argNumberToAttr(unsigned ArgNum) { +  if (ArgNum >= AttrMaxNumArgs) +    return AttrUnknown; +  // N.B. MSVC complains if we use `1U` here, since AliasAttr' ctor takes +  // an unsigned long long. +  return AliasAttr(1ULL << (ArgNum + AttrFirstArgIndex)); +} + +AliasAttrs getGlobalOrArgAttrFromValue(const Value &Val) { +  if (isa<GlobalValue>(Val)) +    return AttrGlobal; + +  if (auto *Arg = dyn_cast<Argument>(&Val)) +    // Only pointer arguments should have the argument attribute, +    // because things can't escape through scalars without us seeing a +    // cast, and thus, interaction with them doesn't matter. +    if (!Arg->hasNoAliasAttr() && Arg->getType()->isPointerTy()) +      return argNumberToAttr(Arg->getArgNo()); +  return AttrNone; +} + +bool isGlobalOrArgAttr(AliasAttrs Attr) { +  return Attr.reset(AttrEscapedIndex) +      .reset(AttrUnknownIndex) +      .reset(AttrCallerIndex) +      .any(); +} + +AliasAttrs getExternallyVisibleAttrs(AliasAttrs Attr) { +  return Attr & AliasAttrs(ExternalAttrMask); +} + +Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue IValue, +                                                      CallSite CS) { +  auto Index = IValue.Index; +  auto Value = (Index == 0) ? CS.getInstruction() : CS.getArgument(Index - 1); +  if (Value->getType()->isPointerTy()) +    return InstantiatedValue{Value, IValue.DerefLevel}; +  return None; +} + +Optional<InstantiatedRelation> +instantiateExternalRelation(ExternalRelation ERelation, CallSite CS) { +  auto From = instantiateInterfaceValue(ERelation.From, CS); +  if (!From) +    return None; +  auto To = instantiateInterfaceValue(ERelation.To, CS); +  if (!To) +    return None; +  return InstantiatedRelation{*From, *To, ERelation.Offset}; +} + +Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute EAttr, +                                                        CallSite CS) { +  auto Value = instantiateInterfaceValue(EAttr.IValue, CS); +  if (!Value) +    return None; +  return InstantiatedAttr{*Value, EAttr.Attr}; +} +} +} diff --git a/contrib/llvm/lib/Analysis/AliasAnalysisSummary.h b/contrib/llvm/lib/Analysis/AliasAnalysisSummary.h new file mode 100644 index 000000000000..fb93a12420f8 --- /dev/null +++ b/contrib/llvm/lib/Analysis/AliasAnalysisSummary.h @@ -0,0 +1,265 @@ +//=====- CFLSummary.h - Abstract stratified sets implementation. --------=====// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines various utility types and functions useful to +/// summary-based alias analysis. +/// +/// Summary-based analysis, also known as bottom-up analysis, is a style of +/// interprocedrual static analysis that tries to analyze the callees before the +/// callers get analyzed. The key idea of summary-based analysis is to first +/// process each function independently, outline its behavior in a condensed +/// summary, and then instantiate the summary at the callsite when the said +/// function is called elsewhere. This is often in contrast to another style +/// called top-down analysis, in which callers are always analyzed first before +/// the callees. +/// +/// In a summary-based analysis, functions must be examined independently and +/// out-of-context. We have no information on the state of the memory, the +/// arguments, the global values, and anything else external to the function. To +/// carry out the analysis conservative assumptions have to be made about those +/// external states. In exchange for the potential loss of precision, the +/// summary we obtain this way is highly reusable, which makes the analysis +/// easier to scale to large programs even if carried out context-sensitively. +/// +/// Currently, all CFL-based alias analyses adopt the summary-based approach +/// and therefore heavily rely on this header. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H +#define LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CallSite.h" +#include <bitset> + +namespace llvm { +namespace cflaa { + +//===----------------------------------------------------------------------===// +// AliasAttr related stuffs +//===----------------------------------------------------------------------===// + +/// The number of attributes that AliasAttr should contain. Attributes are +/// described below, and 32 was an arbitrary choice because it fits nicely in 32 +/// bits (because we use a bitset for AliasAttr). +static const unsigned NumAliasAttrs = 32; + +/// These are attributes that an alias analysis can use to mark certain special +/// properties of a given pointer. Refer to the related functions below to see +/// what kinds of attributes are currently defined. +typedef std::bitset<NumAliasAttrs> AliasAttrs; + +/// Attr represent whether the said pointer comes from an unknown source +/// (such as opaque memory or an integer cast). +AliasAttrs getAttrNone(); + +/// AttrUnknown represent whether the said pointer comes from a source not known +/// to alias analyses (such as opaque memory or an integer cast). +AliasAttrs getAttrUnknown(); +bool hasUnknownAttr(AliasAttrs); + +/// AttrCaller represent whether the said pointer comes from a source not known +/// to the current function but known to the caller. Values pointed to by the +/// arguments of the current function have this attribute set +AliasAttrs getAttrCaller(); +bool hasCallerAttr(AliasAttrs); +bool hasUnknownOrCallerAttr(AliasAttrs); + +/// AttrEscaped represent whether the said pointer comes from a known source but +/// escapes to the unknown world (e.g. casted to an integer, or passed as an +/// argument to opaque function). Unlike non-escaped pointers, escaped ones may +/// alias pointers coming from unknown sources. 
+AliasAttrs getAttrEscaped(); +bool hasEscapedAttr(AliasAttrs); + +/// AttrGlobal represent whether the said pointer is a global value. +/// AttrArg represent whether the said pointer is an argument, and if so, what +/// index the argument has. +AliasAttrs getGlobalOrArgAttrFromValue(const Value &); +bool isGlobalOrArgAttr(AliasAttrs); + +/// Given an AliasAttrs, return a new AliasAttrs that only contains attributes +/// meaningful to the caller. This function is primarily used for +/// interprocedural analysis +/// Currently, externally visible AliasAttrs include AttrUnknown, AttrGlobal, +/// and AttrEscaped +AliasAttrs getExternallyVisibleAttrs(AliasAttrs); + +//===----------------------------------------------------------------------===// +// Function summary related stuffs +//===----------------------------------------------------------------------===// + +/// The maximum number of arguments we can put into a summary. +static const unsigned MaxSupportedArgsInSummary = 50; + +/// We use InterfaceValue to describe parameters/return value, as well as +/// potential memory locations that are pointed to by parameters/return value, +/// of a function. +/// Index is an integer which represents a single parameter or a return value. +/// When the index is 0, it refers to the return value. Non-zero index i refers +/// to the i-th parameter. +/// DerefLevel indicates the number of dereferences one must perform on the +/// parameter/return value to get this InterfaceValue. +struct InterfaceValue { +  unsigned Index; +  unsigned DerefLevel; +}; + +inline bool operator==(InterfaceValue LHS, InterfaceValue RHS) { +  return LHS.Index == RHS.Index && LHS.DerefLevel == RHS.DerefLevel; +} +inline bool operator!=(InterfaceValue LHS, InterfaceValue RHS) { +  return !(LHS == RHS); +} +inline bool operator<(InterfaceValue LHS, InterfaceValue RHS) { +  return LHS.Index < RHS.Index || +         (LHS.Index == RHS.Index && LHS.DerefLevel < RHS.DerefLevel); +} +inline bool operator>(InterfaceValue LHS, InterfaceValue RHS) { +  return RHS < LHS; +} +inline bool operator<=(InterfaceValue LHS, InterfaceValue RHS) { +  return !(RHS < LHS); +} +inline bool operator>=(InterfaceValue LHS, InterfaceValue RHS) { +  return !(LHS < RHS); +} + +// We use UnknownOffset to represent pointer offsets that cannot be determined +// at compile time. Note that MemoryLocation::UnknownSize cannot be used here +// because we require a signed value. +static const int64_t UnknownOffset = INT64_MAX; + +inline int64_t addOffset(int64_t LHS, int64_t RHS) { +  if (LHS == UnknownOffset || RHS == UnknownOffset) +    return UnknownOffset; +  // FIXME: Do we need to guard against integer overflow here? +  return LHS + RHS; +} + +/// We use ExternalRelation to describe an externally visible aliasing relations +/// between parameters/return value of a function. 
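+
+// ---------------------------------------------------------------------------
+// Editorial sketch, added for this review and not part of the original
+// header: a hypothetical example of how the helpers declared above compose.
+// The function name aliasSummaryUsageSketch and all local variables are made
+// up for illustration; the block is guarded out so it cannot affect builds.
+#if 0
+inline void aliasSummaryUsageSketch() {
+  // Attributes are plain bits in a bitset and can be merged with operator|.
+  AliasAttrs A = getAttrEscaped() | getAttrCaller();
+  bool Escaped = hasEscapedAttr(A);            // true
+  bool Caller = hasCallerAttr(A);              // true
+
+  // Only Unknown/Global/Escaped survive summarization across a call.
+  AliasAttrs Visible = getExternallyVisibleAttrs(A);
+  bool StillCaller = hasCallerAttr(Visible);   // false
+
+  // InterfaceValue: Index 0 is the return value, Index i the i-th parameter;
+  // DerefLevel counts how many dereferences reach this value.
+  InterfaceValue RetVal{0, 0};          // the return value itself
+  InterfaceValue Arg1Pointee{1, 1};     // what the first parameter points to
+  bool Ordered = RetVal < Arg1Pointee;  // ordered by (Index, DerefLevel)
+
+  // Offsets saturate once anything becomes unknown.
+  int64_t Sum = addOffset(8, UnknownOffset);   // == UnknownOffset
+  (void)Escaped; (void)Caller; (void)StillCaller; (void)Ordered; (void)Sum;
+}
+#endif
+// ---------------------------------------------------------------------------
+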
+struct ExternalRelation { +  InterfaceValue From, To; +  int64_t Offset; +}; + +inline bool operator==(ExternalRelation LHS, ExternalRelation RHS) { +  return LHS.From == RHS.From && LHS.To == RHS.To && LHS.Offset == RHS.Offset; +} +inline bool operator!=(ExternalRelation LHS, ExternalRelation RHS) { +  return !(LHS == RHS); +} +inline bool operator<(ExternalRelation LHS, ExternalRelation RHS) { +  if (LHS.From < RHS.From) +    return true; +  if (LHS.From > RHS.From) +    return false; +  if (LHS.To < RHS.To) +    return true; +  if (LHS.To > RHS.To) +    return false; +  return LHS.Offset < RHS.Offset; +} +inline bool operator>(ExternalRelation LHS, ExternalRelation RHS) { +  return RHS < LHS; +} +inline bool operator<=(ExternalRelation LHS, ExternalRelation RHS) { +  return !(RHS < LHS); +} +inline bool operator>=(ExternalRelation LHS, ExternalRelation RHS) { +  return !(LHS < RHS); +} + +/// We use ExternalAttribute to describe an externally visible AliasAttrs +/// for parameters/return value. +struct ExternalAttribute { +  InterfaceValue IValue; +  AliasAttrs Attr; +}; + +/// AliasSummary is just a collection of ExternalRelation and ExternalAttribute +struct AliasSummary { +  // RetParamRelations is a collection of ExternalRelations. +  SmallVector<ExternalRelation, 8> RetParamRelations; + +  // RetParamAttributes is a collection of ExternalAttributes. +  SmallVector<ExternalAttribute, 8> RetParamAttributes; +}; + +/// This is the result of instantiating InterfaceValue at a particular callsite +struct InstantiatedValue { +  Value *Val; +  unsigned DerefLevel; +}; +Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue, CallSite); + +inline bool operator==(InstantiatedValue LHS, InstantiatedValue RHS) { +  return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel; +} +inline bool operator!=(InstantiatedValue LHS, InstantiatedValue RHS) { +  return !(LHS == RHS); +} +inline bool operator<(InstantiatedValue LHS, InstantiatedValue RHS) { +  return std::less<Value *>()(LHS.Val, RHS.Val) || +         (LHS.Val == RHS.Val && LHS.DerefLevel < RHS.DerefLevel); +} +inline bool operator>(InstantiatedValue LHS, InstantiatedValue RHS) { +  return RHS < LHS; +} +inline bool operator<=(InstantiatedValue LHS, InstantiatedValue RHS) { +  return !(RHS < LHS); +} +inline bool operator>=(InstantiatedValue LHS, InstantiatedValue RHS) { +  return !(LHS < RHS); +} + +/// This is the result of instantiating ExternalRelation at a particular +/// callsite +struct InstantiatedRelation { +  InstantiatedValue From, To; +  int64_t Offset; +}; +Optional<InstantiatedRelation> instantiateExternalRelation(ExternalRelation, +                                                           CallSite); + +/// This is the result of instantiating ExternalAttribute at a particular +/// callsite +struct InstantiatedAttr { +  InstantiatedValue IValue; +  AliasAttrs Attr; +}; +Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute, +                                                        CallSite); +} + +template <> struct DenseMapInfo<cflaa::InstantiatedValue> { +  static inline cflaa::InstantiatedValue getEmptyKey() { +    return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getEmptyKey(), +                                    DenseMapInfo<unsigned>::getEmptyKey()}; +  } +  static inline cflaa::InstantiatedValue getTombstoneKey() { +    return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getTombstoneKey(), +                                    DenseMapInfo<unsigned>::getTombstoneKey()}; +  } 
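+
+  // (Editorial note, added for review) This specialization is what allows
+  // clients to key a DenseMap directly by instantiated values, e.g.
+  //   DenseMap<cflaa::InstantiatedValue, unsigned> DerefLevels;
+  // where the name DerefLevels is made up for illustration. The empty and
+  // tombstone keys above, together with the pair-based hash below, keep such
+  // a map well-formed.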
+  static unsigned getHashValue(const cflaa::InstantiatedValue &IV) { +    return DenseMapInfo<std::pair<Value *, unsigned>>::getHashValue( +        std::make_pair(IV.Val, IV.DerefLevel)); +  } +  static bool isEqual(const cflaa::InstantiatedValue &LHS, +                      const cflaa::InstantiatedValue &RHS) { +    return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel; +  } +}; +} + +#endif diff --git a/contrib/llvm/lib/Analysis/AliasSetTracker.cpp b/contrib/llvm/lib/Analysis/AliasSetTracker.cpp new file mode 100644 index 000000000000..8f903fa4f1e8 --- /dev/null +++ b/contrib/llvm/lib/Analysis/AliasSetTracker.cpp @@ -0,0 +1,729 @@ +//===- AliasSetTracker.cpp - Alias Sets Tracker implementation-------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the AliasSetTracker and AliasSet classes. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <vector> + +using namespace llvm; + +static cl::opt<unsigned> +    SaturationThreshold("alias-set-saturation-threshold", cl::Hidden, +                        cl::init(250), +                        cl::desc("The maximum number of pointers may-alias " +                                 "sets may contain before degradation")); + +/// mergeSetIn - Merge the specified alias set into this alias set. +/// +void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) { +  assert(!AS.Forward && "Alias set is already forwarding!"); +  assert(!Forward && "This set is a forwarding set!!"); + +  bool WasMustAlias = (Alias == SetMustAlias); +  // Update the alias and access types of this set... +  Access |= AS.Access; +  Alias  |= AS.Alias; +  Volatile |= AS.Volatile; + +  if (Alias == SetMustAlias) { +    // Check that these two merged sets really are must aliases.  Since both +    // used to be must-alias sets, we can just check any pointer from each set +    // for aliasing. +    AliasAnalysis &AA = AST.getAliasAnalysis(); +    PointerRec *L = getSomePointer(); +    PointerRec *R = AS.getSomePointer(); + +    // If the pointers are not a must-alias pair, this set becomes a may alias. 
+    if (AA.alias(MemoryLocation(L->getValue(), L->getSize(), L->getAAInfo()), +                 MemoryLocation(R->getValue(), R->getSize(), R->getAAInfo())) != +        MustAlias) +      Alias = SetMayAlias; +  } + +  if (Alias == SetMayAlias) { +    if (WasMustAlias) +      AST.TotalMayAliasSetSize += size(); +    if (AS.Alias == SetMustAlias) +      AST.TotalMayAliasSetSize += AS.size(); +  } + +  bool ASHadUnknownInsts = !AS.UnknownInsts.empty(); +  if (UnknownInsts.empty()) {            // Merge call sites... +    if (ASHadUnknownInsts) { +      std::swap(UnknownInsts, AS.UnknownInsts); +      addRef(); +    } +  } else if (ASHadUnknownInsts) { +    UnknownInsts.insert(UnknownInsts.end(), AS.UnknownInsts.begin(), AS.UnknownInsts.end()); +    AS.UnknownInsts.clear(); +  } + +  AS.Forward = this; // Forward across AS now... +  addRef();          // AS is now pointing to us... + +  // Merge the list of constituent pointers... +  if (AS.PtrList) { +    SetSize += AS.size(); +    AS.SetSize = 0; +    *PtrListEnd = AS.PtrList; +    AS.PtrList->setPrevInList(PtrListEnd); +    PtrListEnd = AS.PtrListEnd; + +    AS.PtrList = nullptr; +    AS.PtrListEnd = &AS.PtrList; +    assert(*AS.PtrListEnd == nullptr && "End of list is not null?"); +  } +  if (ASHadUnknownInsts) +    AS.dropRef(AST); +} + +void AliasSetTracker::removeAliasSet(AliasSet *AS) { +  if (AliasSet *Fwd = AS->Forward) { +    Fwd->dropRef(*this); +    AS->Forward = nullptr; +  } + +  if (AS->Alias == AliasSet::SetMayAlias) +    TotalMayAliasSetSize -= AS->size(); + +  AliasSets.erase(AS); +} + +void AliasSet::removeFromTracker(AliasSetTracker &AST) { +  assert(RefCount == 0 && "Cannot remove non-dead alias set from tracker!"); +  AST.removeAliasSet(this); +} + +void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, +                          LocationSize Size, const AAMDNodes &AAInfo, +                          bool KnownMustAlias) { +  assert(!Entry.hasAliasSet() && "Entry already in set!"); + +  // Check to see if we have to downgrade to _may_ alias. +  if (isMustAlias() && !KnownMustAlias) +    if (PointerRec *P = getSomePointer()) { +      AliasAnalysis &AA = AST.getAliasAnalysis(); +      AliasResult Result = +          AA.alias(MemoryLocation(P->getValue(), P->getSize(), P->getAAInfo()), +                   MemoryLocation(Entry.getValue(), Size, AAInfo)); +      if (Result != MustAlias) { +        Alias = SetMayAlias; +        AST.TotalMayAliasSetSize += size(); +      } else { +        // First entry of must alias must have maximum size! +        P->updateSizeAndAAInfo(Size, AAInfo); +      } +      assert(Result != NoAlias && "Cannot be part of must set!"); +    } + +  Entry.setAliasSet(this); +  Entry.updateSizeAndAAInfo(Size, AAInfo); + +  // Add it to the end of the list... +  ++SetSize; +  assert(*PtrListEnd == nullptr && "End of list is not null?"); +  *PtrListEnd = &Entry; +  PtrListEnd = Entry.setPrevInList(PtrListEnd); +  assert(*PtrListEnd == nullptr && "End of list is not null?"); +  // Entry points to alias set. 
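+  // (Editorial note, added for review) The entry now refers back to this
+  // set, so take a reference below; the set stays alive until every pointer
+  // record and unknown instruction referring to it has been dropped.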
+  addRef(); + +  if (Alias == SetMayAlias) +    AST.TotalMayAliasSetSize++; +} + +void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) { +  if (UnknownInsts.empty()) +    addRef(); +  UnknownInsts.emplace_back(I); + +  if (!I->mayWriteToMemory()) { +    Alias = SetMayAlias; +    Access |= RefAccess; +    return; +  } + +  // FIXME: This should use mod/ref information to make this not suck so bad +  Alias = SetMayAlias; +  Access = ModRefAccess; +} + +/// aliasesPointer - Return true if the specified pointer "may" (or must) +/// alias one of the members in the set. +/// +bool AliasSet::aliasesPointer(const Value *Ptr, LocationSize Size, +                              const AAMDNodes &AAInfo, +                              AliasAnalysis &AA) const { +  if (AliasAny) +    return true; + +  if (Alias == SetMustAlias) { +    assert(UnknownInsts.empty() && "Illegal must alias set!"); + +    // If this is a set of MustAliases, only check to see if the pointer aliases +    // SOME value in the set. +    PointerRec *SomePtr = getSomePointer(); +    assert(SomePtr && "Empty must-alias set??"); +    return AA.alias(MemoryLocation(SomePtr->getValue(), SomePtr->getSize(), +                                   SomePtr->getAAInfo()), +                    MemoryLocation(Ptr, Size, AAInfo)); +  } + +  // If this is a may-alias set, we have to check all of the pointers in the set +  // to be sure it doesn't alias the set... +  for (iterator I = begin(), E = end(); I != E; ++I) +    if (AA.alias(MemoryLocation(Ptr, Size, AAInfo), +                 MemoryLocation(I.getPointer(), I.getSize(), I.getAAInfo()))) +      return true; + +  // Check the unknown instructions... +  if (!UnknownInsts.empty()) { +    for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) +      if (auto *Inst = getUnknownInst(i)) +        if (isModOrRefSet( +                AA.getModRefInfo(Inst, MemoryLocation(Ptr, Size, AAInfo)))) +          return true; +  } + +  return false; +} + +bool AliasSet::aliasesUnknownInst(const Instruction *Inst, +                                  AliasAnalysis &AA) const { + +  if (AliasAny) +    return true; + +  if (!Inst->mayReadOrWriteMemory()) +    return false; + +  for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) { +    if (auto *UnknownInst = getUnknownInst(i)) { +      ImmutableCallSite C1(UnknownInst), C2(Inst); +      if (!C1 || !C2 || isModOrRefSet(AA.getModRefInfo(C1, C2)) || +          isModOrRefSet(AA.getModRefInfo(C2, C1))) +        return true; +    } +  } + +  for (iterator I = begin(), E = end(); I != E; ++I) +    if (isModOrRefSet(AA.getModRefInfo( +            Inst, MemoryLocation(I.getPointer(), I.getSize(), I.getAAInfo())))) +      return true; + +  return false; +} + +void AliasSetTracker::clear() { +  // Delete all the PointerRec entries. +  for (PointerMapType::iterator I = PointerMap.begin(), E = PointerMap.end(); +       I != E; ++I) +    I->second->eraseFromList(); + +  PointerMap.clear(); + +  // The alias sets should all be clear now. +  AliasSets.clear(); +} + + +/// mergeAliasSetsForPointer - Given a pointer, merge all alias sets that may +/// alias the pointer. Return the unified set, or nullptr if no set that aliases +/// the pointer was found. 
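+// ---------------------------------------------------------------------------
+// Editorial sketch, added for this review (not part of the original file): a
+// hypothetical example of how a client drives the tracker, placed here before
+// the internal helpers that follow. The printer pass at the end of this file
+// does essentially the same thing; collectAliasSets is a made-up name and the
+// block is guarded out so it cannot affect builds.
+#if 0
+void collectAliasSets(Function &F, AliasAnalysis &AA) {
+  AliasSetTracker Tracker(AA);
+  // Feed every instruction to the tracker; loads, stores and memory
+  // intrinsics go through the typed add() overloads, everything else is
+  // treated as an unknown instruction.
+  for (Instruction &I : instructions(F))
+    Tracker.add(&I);
+  // Walk the resulting partition of memory references.
+  for (const AliasSet &AS : Tracker)
+    AS.print(errs());
+}
+#endif
+// ---------------------------------------------------------------------------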
+AliasSet *AliasSetTracker::mergeAliasSetsForPointer(const Value *Ptr, +                                                    LocationSize Size, +                                                    const AAMDNodes &AAInfo) { +  AliasSet *FoundSet = nullptr; +  for (iterator I = begin(), E = end(); I != E;) { +    iterator Cur = I++; +    if (Cur->Forward || !Cur->aliasesPointer(Ptr, Size, AAInfo, AA)) continue; + +    if (!FoundSet) {      // If this is the first alias set ptr can go into. +      FoundSet = &*Cur;   // Remember it. +    } else {              // Otherwise, we must merge the sets. +      FoundSet->mergeSetIn(*Cur, *this);     // Merge in contents. +    } +  } + +  return FoundSet; +} + +bool AliasSetTracker::containsUnknown(const Instruction *Inst) const { +  for (const AliasSet &AS : *this) +    if (!AS.Forward && AS.aliasesUnknownInst(Inst, AA)) +      return true; +  return false; +} + +AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { +  AliasSet *FoundSet = nullptr; +  for (iterator I = begin(), E = end(); I != E;) { +    iterator Cur = I++; +    if (Cur->Forward || !Cur->aliasesUnknownInst(Inst, AA)) +      continue; +    if (!FoundSet)            // If this is the first alias set ptr can go into. +      FoundSet = &*Cur;       // Remember it. +    else if (!Cur->Forward)   // Otherwise, we must merge the sets. +      FoundSet->mergeSetIn(*Cur, *this);     // Merge in contents. +  } +  return FoundSet; +} + +/// getAliasSetForPointer - Return the alias set that the specified pointer +/// lives in. +AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, +                                                 LocationSize Size, +                                                 const AAMDNodes &AAInfo) { +  AliasSet::PointerRec &Entry = getEntryFor(Pointer); + +  if (AliasAnyAS) { +    // At this point, the AST is saturated, so we only have one active alias +    // set. That means we already know which alias set we want to return, and +    // just need to add the pointer to that set to keep the data structure +    // consistent. +    // This, of course, means that we will never need a merge here. +    if (Entry.hasAliasSet()) { +      Entry.updateSizeAndAAInfo(Size, AAInfo); +      assert(Entry.getAliasSet(*this) == AliasAnyAS && +             "Entry in saturated AST must belong to only alias set"); +    } else { +      AliasAnyAS->addPointer(*this, Entry, Size, AAInfo); +    } +    return *AliasAnyAS; +  } + +  // Check to see if the pointer is already known. +  if (Entry.hasAliasSet()) { +    // If the size changed, we may need to merge several alias sets. +    // Note that we can *not* return the result of mergeAliasSetsForPointer +    // due to a quirk of alias analysis behavior. Since alias(undef, undef) +    // is NoAlias, mergeAliasSetsForPointer(undef, ...) will not find the +    // the right set for undef, even if it exists. +    if (Entry.updateSizeAndAAInfo(Size, AAInfo)) +      mergeAliasSetsForPointer(Pointer, Size, AAInfo); +    // Return the set! +    return *Entry.getAliasSet(*this)->getForwardedTarget(*this); +  } + +  if (AliasSet *AS = mergeAliasSetsForPointer(Pointer, Size, AAInfo)) { +    // Add it to the alias set it aliases. +    AS->addPointer(*this, Entry, Size, AAInfo); +    return *AS; +  } + +  // Otherwise create a new alias set to hold the loaded pointer. 
+  AliasSets.push_back(new AliasSet()); +  AliasSets.back().addPointer(*this, Entry, Size, AAInfo); +  return AliasSets.back(); +} + +void AliasSetTracker::add(Value *Ptr, LocationSize Size, +                          const AAMDNodes &AAInfo) { +  addPointer(Ptr, Size, AAInfo, AliasSet::NoAccess); +} + +void AliasSetTracker::add(LoadInst *LI) { +  if (isStrongerThanMonotonic(LI->getOrdering())) return addUnknown(LI); + +  AAMDNodes AAInfo; +  LI->getAAMetadata(AAInfo); + +  AliasSet::AccessLattice Access = AliasSet::RefAccess; +  const DataLayout &DL = LI->getModule()->getDataLayout(); +  AliasSet &AS = addPointer(LI->getOperand(0), +                            DL.getTypeStoreSize(LI->getType()), AAInfo, Access); +  if (LI->isVolatile()) AS.setVolatile(); +} + +void AliasSetTracker::add(StoreInst *SI) { +  if (isStrongerThanMonotonic(SI->getOrdering())) return addUnknown(SI); + +  AAMDNodes AAInfo; +  SI->getAAMetadata(AAInfo); + +  AliasSet::AccessLattice Access = AliasSet::ModAccess; +  const DataLayout &DL = SI->getModule()->getDataLayout(); +  Value *Val = SI->getOperand(0); +  AliasSet &AS = addPointer( +      SI->getOperand(1), DL.getTypeStoreSize(Val->getType()), AAInfo, Access); +  if (SI->isVolatile()) AS.setVolatile(); +} + +void AliasSetTracker::add(VAArgInst *VAAI) { +  AAMDNodes AAInfo; +  VAAI->getAAMetadata(AAInfo); + +  addPointer(VAAI->getOperand(0), MemoryLocation::UnknownSize, AAInfo, +             AliasSet::ModRefAccess); +} + +void AliasSetTracker::add(AnyMemSetInst *MSI) { +  AAMDNodes AAInfo; +  MSI->getAAMetadata(AAInfo); + +  uint64_t Len; + +  if (ConstantInt *C = dyn_cast<ConstantInt>(MSI->getLength())) +    Len = C->getZExtValue(); +  else +    Len = MemoryLocation::UnknownSize; + +  AliasSet &AS = +      addPointer(MSI->getRawDest(), Len, AAInfo, AliasSet::ModAccess); +  auto *MS = dyn_cast<MemSetInst>(MSI); +  if (MS && MS->isVolatile()) +    AS.setVolatile(); +} + +void AliasSetTracker::add(AnyMemTransferInst *MTI) { +  AAMDNodes AAInfo; +  MTI->getAAMetadata(AAInfo); + +  uint64_t Len; +  if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) +    Len = C->getZExtValue(); +  else +    Len = MemoryLocation::UnknownSize; + +  AliasSet &ASSrc = +      addPointer(MTI->getRawSource(), Len, AAInfo, AliasSet::RefAccess); + +  AliasSet &ASDst = +      addPointer(MTI->getRawDest(), Len, AAInfo, AliasSet::ModAccess); + +  auto* MT = dyn_cast<MemTransferInst>(MTI); +  if (MT && MT->isVolatile()) { +    ASSrc.setVolatile(); +    ASDst.setVolatile(); +  } +} + +void AliasSetTracker::addUnknown(Instruction *Inst) { +  if (isa<DbgInfoIntrinsic>(Inst)) +    return; // Ignore DbgInfo Intrinsics. + +  if (auto *II = dyn_cast<IntrinsicInst>(Inst)) { +    // These intrinsics will show up as affecting memory, but they are just +    // markers. +    switch (II->getIntrinsicID()) { +    default: +      break; +      // FIXME: Add lifetime/invariant intrinsics (See: PR30807). +    case Intrinsic::assume: +    case Intrinsic::sideeffect: +      return; +    } +  } +  if (!Inst->mayReadOrWriteMemory()) +    return; // doesn't alias anything + +  AliasSet *AS = findAliasSetForUnknownInst(Inst); +  if (AS) { +    AS->addUnknownInst(Inst, AA); +    return; +  } +  AliasSets.push_back(new AliasSet()); +  AS = &AliasSets.back(); +  AS->addUnknownInst(Inst, AA); +} + +void AliasSetTracker::add(Instruction *I) { +  // Dispatch to one of the other add methods. 
+  if (LoadInst *LI = dyn_cast<LoadInst>(I)) +    return add(LI); +  if (StoreInst *SI = dyn_cast<StoreInst>(I)) +    return add(SI); +  if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) +    return add(VAAI); +  if (AnyMemSetInst *MSI = dyn_cast<AnyMemSetInst>(I)) +    return add(MSI); +  if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(I)) +    return add(MTI); +  return addUnknown(I); +} + +void AliasSetTracker::add(BasicBlock &BB) { +  for (auto &I : BB) +    add(&I); +} + +void AliasSetTracker::add(const AliasSetTracker &AST) { +  assert(&AA == &AST.AA && +         "Merging AliasSetTracker objects with different Alias Analyses!"); + +  // Loop over all of the alias sets in AST, adding the pointers contained +  // therein into the current alias sets.  This can cause alias sets to be +  // merged together in the current AST. +  for (const AliasSet &AS : AST) { +    if (AS.Forward) +      continue; // Ignore forwarding alias sets + +    // If there are any call sites in the alias set, add them to this AST. +    for (unsigned i = 0, e = AS.UnknownInsts.size(); i != e; ++i) +      if (auto *Inst = AS.getUnknownInst(i)) +        add(Inst); + +    // Loop over all of the pointers in this alias set. +    for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { +      AliasSet &NewAS = +          addPointer(ASI.getPointer(), ASI.getSize(), ASI.getAAInfo(), +                     (AliasSet::AccessLattice)AS.Access); +      if (AS.isVolatile()) NewAS.setVolatile(); +    } +  } +} + +// deleteValue method - This method is used to remove a pointer value from the +// AliasSetTracker entirely.  It should be used when an instruction is deleted +// from the program to update the AST.  If you don't use this, you would have +// dangling pointers to deleted instructions. +// +void AliasSetTracker::deleteValue(Value *PtrVal) { +  // First, look up the PointerRec for this pointer. +  PointerMapType::iterator I = PointerMap.find_as(PtrVal); +  if (I == PointerMap.end()) return;  // Noop + +  // If we found one, remove the pointer from the alias set it is in. +  AliasSet::PointerRec *PtrValEnt = I->second; +  AliasSet *AS = PtrValEnt->getAliasSet(*this); + +  // Unlink and delete from the list of values. +  PtrValEnt->eraseFromList(); + +  if (AS->Alias == AliasSet::SetMayAlias) { +    AS->SetSize--; +    TotalMayAliasSetSize--; +  } + +  // Stop using the alias set. +  AS->dropRef(*this); + +  PointerMap.erase(I); +} + +// copyValue - This method should be used whenever a preexisting value in the +// program is copied or cloned, introducing a new value.  Note that it is ok for +// clients that use this method to introduce the same value multiple times: if +// the tracker already knows about a value, it will ignore the request. +// +void AliasSetTracker::copyValue(Value *From, Value *To) { +  // First, look up the PointerRec for this pointer. +  PointerMapType::iterator I = PointerMap.find_as(From); +  if (I == PointerMap.end()) +    return;  // Noop +  assert(I->second->hasAliasSet() && "Dead entry?"); + +  AliasSet::PointerRec &Entry = getEntryFor(To); +  if (Entry.hasAliasSet()) return;    // Already in the tracker! + +  // getEntryFor above may invalidate iterator \c I, so reinitialize it. +  I = PointerMap.find_as(From); +  // Add it to the alias set it aliases... 
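+  // (Editorial note, added for review) KnownMustAlias is passed as true
+  // here: 'To' is by definition a copy of 'From', so addPointer can skip the
+  // usual downgrade-to-may-alias check.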
+  AliasSet *AS = I->second->getAliasSet(*this); +  AS->addPointer(*this, Entry, I->second->getSize(), +                 I->second->getAAInfo(), +                 true); +} + +AliasSet &AliasSetTracker::mergeAllAliasSets() { +  assert(!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold) && +         "Full merge should happen once, when the saturation threshold is " +         "reached"); + +  // Collect all alias sets, so that we can drop references with impunity +  // without worrying about iterator invalidation. +  std::vector<AliasSet *> ASVector; +  ASVector.reserve(SaturationThreshold); +  for (iterator I = begin(), E = end(); I != E; I++) +    ASVector.push_back(&*I); + +  // Copy all instructions and pointers into a new set, and forward all other +  // sets to it. +  AliasSets.push_back(new AliasSet()); +  AliasAnyAS = &AliasSets.back(); +  AliasAnyAS->Alias = AliasSet::SetMayAlias; +  AliasAnyAS->Access = AliasSet::ModRefAccess; +  AliasAnyAS->AliasAny = true; + +  for (auto Cur : ASVector) { +    // If Cur was already forwarding, just forward to the new AS instead. +    AliasSet *FwdTo = Cur->Forward; +    if (FwdTo) { +      Cur->Forward = AliasAnyAS; +      AliasAnyAS->addRef(); +      FwdTo->dropRef(*this); +      continue; +    } + +    // Otherwise, perform the actual merge. +    AliasAnyAS->mergeSetIn(*Cur, *this); +  } + +  return *AliasAnyAS; +} + +AliasSet &AliasSetTracker::addPointer(Value *P, LocationSize Size, +                                      const AAMDNodes &AAInfo, +                                      AliasSet::AccessLattice E) { +  AliasSet &AS = getAliasSetForPointer(P, Size, AAInfo); +  AS.Access |= E; + +  if (!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold)) { +    // The AST is now saturated. From here on, we conservatively consider all +    // pointers to alias each-other. +    return mergeAllAliasSets(); +  } + +  return AS; +} + +//===----------------------------------------------------------------------===// +//               AliasSet/AliasSetTracker Printing Support +//===----------------------------------------------------------------------===// + +void AliasSet::print(raw_ostream &OS) const { +  OS << "  AliasSet[" << (const void*)this << ", " << RefCount << "] "; +  OS << (Alias == SetMustAlias ? 
"must" : "may") << " alias, "; +  switch (Access) { +  case NoAccess:     OS << "No access "; break; +  case RefAccess:    OS << "Ref       "; break; +  case ModAccess:    OS << "Mod       "; break; +  case ModRefAccess: OS << "Mod/Ref   "; break; +  default: llvm_unreachable("Bad value for Access!"); +  } +  if (isVolatile()) OS << "[volatile] "; +  if (Forward) +    OS << " forwarding to " << (void*)Forward; + +  if (!empty()) { +    OS << "Pointers: "; +    for (iterator I = begin(), E = end(); I != E; ++I) { +      if (I != begin()) OS << ", "; +      I.getPointer()->printAsOperand(OS << "("); +      OS << ", " << I.getSize() << ")"; +    } +  } +  if (!UnknownInsts.empty()) { +    OS << "\n    " << UnknownInsts.size() << " Unknown instructions: "; +    for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) { +      if (i) OS << ", "; +      if (auto *I = getUnknownInst(i)) { +        if (I->hasName()) +          I->printAsOperand(OS); +        else +          I->print(OS); +      } +    } +  } +  OS << "\n"; +} + +void AliasSetTracker::print(raw_ostream &OS) const { +  OS << "Alias Set Tracker: " << AliasSets.size() << " alias sets for " +     << PointerMap.size() << " pointer values.\n"; +  for (const AliasSet &AS : *this) +    AS.print(OS); +  OS << "\n"; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void AliasSet::dump() const { print(dbgs()); } +LLVM_DUMP_METHOD void AliasSetTracker::dump() const { print(dbgs()); } +#endif + +//===----------------------------------------------------------------------===// +//                     ASTCallbackVH Class Implementation +//===----------------------------------------------------------------------===// + +void AliasSetTracker::ASTCallbackVH::deleted() { +  assert(AST && "ASTCallbackVH called with a null AliasSetTracker!"); +  AST->deleteValue(getValPtr()); +  // this now dangles! 
+} + +void AliasSetTracker::ASTCallbackVH::allUsesReplacedWith(Value *V) { +  AST->copyValue(getValPtr(), V); +} + +AliasSetTracker::ASTCallbackVH::ASTCallbackVH(Value *V, AliasSetTracker *ast) +  : CallbackVH(V), AST(ast) {} + +AliasSetTracker::ASTCallbackVH & +AliasSetTracker::ASTCallbackVH::operator=(Value *V) { +  return *this = ASTCallbackVH(V, AST); +} + +//===----------------------------------------------------------------------===// +//                            AliasSetPrinter Pass +//===----------------------------------------------------------------------===// + +namespace { + +  class AliasSetPrinter : public FunctionPass { +    AliasSetTracker *Tracker; + +  public: +    static char ID; // Pass identification, replacement for typeid + +    AliasSetPrinter() : FunctionPass(ID) { +      initializeAliasSetPrinterPass(*PassRegistry::getPassRegistry()); +    } + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +      AU.addRequired<AAResultsWrapperPass>(); +    } + +    bool runOnFunction(Function &F) override { +      auto &AAWP = getAnalysis<AAResultsWrapperPass>(); +      Tracker = new AliasSetTracker(AAWP.getAAResults()); +      errs() << "Alias sets for function '" << F.getName() << "':\n"; +      for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) +        Tracker->add(&*I); +      Tracker->print(errs()); +      delete Tracker; +      return false; +    } +  }; + +} // end anonymous namespace + +char AliasSetPrinter::ID = 0; + +INITIALIZE_PASS_BEGIN(AliasSetPrinter, "print-alias-sets", +                "Alias Set Printer", false, true) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(AliasSetPrinter, "print-alias-sets", +                "Alias Set Printer", false, true) diff --git a/contrib/llvm/lib/Analysis/Analysis.cpp b/contrib/llvm/lib/Analysis/Analysis.cpp new file mode 100644 index 000000000000..30576cf1ae10 --- /dev/null +++ b/contrib/llvm/lib/Analysis/Analysis.cpp @@ -0,0 +1,136 @@ +//===-- Analysis.cpp ------------------------------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm-c/Analysis.h" +#include "llvm-c/Initialization.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" +#include "llvm/PassRegistry.h" +#include "llvm/Support/raw_ostream.h" +#include <cstring> + +using namespace llvm; + +/// initializeAnalysis - Initialize all passes linked into the Analysis library. 
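+
+// Editorial sketch, added for this review (not part of the original file): a
+// hypothetical example of how a standalone tool might use the initializer
+// defined just below, either directly or via the C API wrappers later in this
+// file. initAnalysisPassesForTool is a made-up name and the block is guarded
+// out so it cannot affect builds.
+#if 0
+static void initAnalysisPassesForTool() {
+  PassRegistry &Registry = *PassRegistry::getPassRegistry();
+  initializeCore(Registry);     // declared in llvm/InitializePasses.h
+  initializeAnalysis(Registry); // defined below
+}
+#endif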
+void llvm::initializeAnalysis(PassRegistry &Registry) { +  initializeAAEvalLegacyPassPass(Registry); +  initializeAliasSetPrinterPass(Registry); +  initializeBasicAAWrapperPassPass(Registry); +  initializeBlockFrequencyInfoWrapperPassPass(Registry); +  initializeBranchProbabilityInfoWrapperPassPass(Registry); +  initializeCallGraphWrapperPassPass(Registry); +  initializeCallGraphDOTPrinterPass(Registry); +  initializeCallGraphPrinterLegacyPassPass(Registry); +  initializeCallGraphViewerPass(Registry); +  initializeCostModelAnalysisPass(Registry); +  initializeCFGViewerLegacyPassPass(Registry); +  initializeCFGPrinterLegacyPassPass(Registry); +  initializeCFGOnlyViewerLegacyPassPass(Registry); +  initializeCFGOnlyPrinterLegacyPassPass(Registry); +  initializeCFLAndersAAWrapperPassPass(Registry); +  initializeCFLSteensAAWrapperPassPass(Registry); +  initializeDependenceAnalysisWrapperPassPass(Registry); +  initializeDelinearizationPass(Registry); +  initializeDemandedBitsWrapperPassPass(Registry); +  initializeDivergenceAnalysisPass(Registry); +  initializeDominanceFrontierWrapperPassPass(Registry); +  initializeDomViewerPass(Registry); +  initializeDomPrinterPass(Registry); +  initializeDomOnlyViewerPass(Registry); +  initializePostDomViewerPass(Registry); +  initializeDomOnlyPrinterPass(Registry); +  initializePostDomPrinterPass(Registry); +  initializePostDomOnlyViewerPass(Registry); +  initializePostDomOnlyPrinterPass(Registry); +  initializeAAResultsWrapperPassPass(Registry); +  initializeGlobalsAAWrapperPassPass(Registry); +  initializeIVUsersWrapperPassPass(Registry); +  initializeInstCountPass(Registry); +  initializeIntervalPartitionPass(Registry); +  initializeLazyBranchProbabilityInfoPassPass(Registry); +  initializeLazyBlockFrequencyInfoPassPass(Registry); +  initializeLazyValueInfoWrapperPassPass(Registry); +  initializeLazyValueInfoPrinterPass(Registry); +  initializeLintPass(Registry); +  initializeLoopInfoWrapperPassPass(Registry); +  initializeMemDepPrinterPass(Registry); +  initializeMemDerefPrinterPass(Registry); +  initializeMemoryDependenceWrapperPassPass(Registry); +  initializeModuleDebugInfoPrinterPass(Registry); +  initializeModuleSummaryIndexWrapperPassPass(Registry); +  initializeMustExecutePrinterPass(Registry); +  initializeObjCARCAAWrapperPassPass(Registry); +  initializeOptimizationRemarkEmitterWrapperPassPass(Registry); +  initializePhiValuesWrapperPassPass(Registry); +  initializePostDominatorTreeWrapperPassPass(Registry); +  initializeRegionInfoPassPass(Registry); +  initializeRegionViewerPass(Registry); +  initializeRegionPrinterPass(Registry); +  initializeRegionOnlyViewerPass(Registry); +  initializeRegionOnlyPrinterPass(Registry); +  initializeSCEVAAWrapperPassPass(Registry); +  initializeScalarEvolutionWrapperPassPass(Registry); +  initializeTargetTransformInfoWrapperPassPass(Registry); +  initializeTypeBasedAAWrapperPassPass(Registry); +  initializeScopedNoAliasAAWrapperPassPass(Registry); +  initializeLCSSAVerificationPassPass(Registry); +  initializeMemorySSAWrapperPassPass(Registry); +  initializeMemorySSAPrinterLegacyPassPass(Registry); +} + +void LLVMInitializeAnalysis(LLVMPassRegistryRef R) { +  initializeAnalysis(*unwrap(R)); +} + +void LLVMInitializeIPA(LLVMPassRegistryRef R) { +  initializeAnalysis(*unwrap(R)); +} + +LLVMBool LLVMVerifyModule(LLVMModuleRef M, LLVMVerifierFailureAction Action, +                          char **OutMessages) { +  raw_ostream *DebugOS = Action != LLVMReturnStatusAction ? 
&errs() : nullptr; +  std::string Messages; +  raw_string_ostream MsgsOS(Messages); + +  LLVMBool Result = verifyModule(*unwrap(M), OutMessages ? &MsgsOS : DebugOS); + +  // Duplicate the output to stderr. +  if (DebugOS && OutMessages) +    *DebugOS << MsgsOS.str(); + +  if (Action == LLVMAbortProcessAction && Result) +    report_fatal_error("Broken module found, compilation aborted!"); + +  if (OutMessages) +    *OutMessages = strdup(MsgsOS.str().c_str()); + +  return Result; +} + +LLVMBool LLVMVerifyFunction(LLVMValueRef Fn, LLVMVerifierFailureAction Action) { +  LLVMBool Result = verifyFunction( +      *unwrap<Function>(Fn), Action != LLVMReturnStatusAction ? &errs() +                                                              : nullptr); + +  if (Action == LLVMAbortProcessAction && Result) +    report_fatal_error("Broken function found, compilation aborted!"); + +  return Result; +} + +void LLVMViewFunctionCFG(LLVMValueRef Fn) { +  Function *F = unwrap<Function>(Fn); +  F->viewCFG(); +} + +void LLVMViewFunctionCFGOnly(LLVMValueRef Fn) { +  Function *F = unwrap<Function>(Fn); +  F->viewCFGOnly(); +} diff --git a/contrib/llvm/lib/Analysis/AssumptionCache.cpp b/contrib/llvm/lib/Analysis/AssumptionCache.cpp new file mode 100644 index 000000000000..8bfd24ccf77b --- /dev/null +++ b/contrib/llvm/lib/Analysis/AssumptionCache.cpp @@ -0,0 +1,275 @@ +//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that keeps track of @llvm.assume intrinsics in +// the functions of a module. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <utility> + +using namespace llvm; +using namespace llvm::PatternMatch; + +static cl::opt<bool> +    VerifyAssumptionCache("verify-assumption-cache", cl::Hidden, +                          cl::desc("Enable verification of assumption cache"), +                          cl::init(false)); + +SmallVector<WeakTrackingVH, 1> & +AssumptionCache::getOrInsertAffectedValues(Value *V) { +  // Try using find_as first to avoid creating extra value handles just for the +  // purpose of doing the lookup. +  auto AVI = AffectedValues.find_as(V); +  if (AVI != AffectedValues.end()) +    return AVI->second; + +  auto AVIP = AffectedValues.insert( +      {AffectedValueCallbackVH(V, this), SmallVector<WeakTrackingVH, 1>()}); +  return AVIP.first->second; +} + +void AssumptionCache::updateAffectedValues(CallInst *CI) { +  // Note: This code must be kept in-sync with the code in +  // computeKnownBitsFromAssume in ValueTracking. 
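+
+  // (Editorial illustration, added for review) For IR such as
+  //   %cmp = icmp eq i32 %x, 0
+  //   call void @llvm.assume(i1 %cmp)
+  // both %cmp and %x are recorded below, so later queries about %x can find
+  // this assume through the AffectedValues map.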
+ +  SmallVector<Value *, 16> Affected; +  auto AddAffected = [&Affected](Value *V) { +    if (isa<Argument>(V)) { +      Affected.push_back(V); +    } else if (auto *I = dyn_cast<Instruction>(V)) { +      Affected.push_back(I); + +      // Peek through unary operators to find the source of the condition. +      Value *Op; +      if (match(I, m_BitCast(m_Value(Op))) || +          match(I, m_PtrToInt(m_Value(Op))) || +          match(I, m_Not(m_Value(Op)))) { +        if (isa<Instruction>(Op) || isa<Argument>(Op)) +          Affected.push_back(Op); +      } +    } +  }; + +  Value *Cond = CI->getArgOperand(0), *A, *B; +  AddAffected(Cond); + +  CmpInst::Predicate Pred; +  if (match(Cond, m_ICmp(Pred, m_Value(A), m_Value(B)))) { +    AddAffected(A); +    AddAffected(B); + +    if (Pred == ICmpInst::ICMP_EQ) { +      // For equality comparisons, we handle the case of bit inversion. +      auto AddAffectedFromEq = [&AddAffected](Value *V) { +        Value *A; +        if (match(V, m_Not(m_Value(A)))) { +          AddAffected(A); +          V = A; +        } + +        Value *B; +        ConstantInt *C; +        // (A & B) or (A | B) or (A ^ B). +        if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) { +          AddAffected(A); +          AddAffected(B); +        // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant. +        } else if (match(V, m_Shift(m_Value(A), m_ConstantInt(C)))) { +          AddAffected(A); +        } +      }; + +      AddAffectedFromEq(A); +      AddAffectedFromEq(B); +    } +  } + +  for (auto &AV : Affected) { +    auto &AVV = getOrInsertAffectedValues(AV); +    if (std::find(AVV.begin(), AVV.end(), CI) == AVV.end()) +      AVV.push_back(CI); +  } +} + +void AssumptionCache::AffectedValueCallbackVH::deleted() { +  auto AVI = AC->AffectedValues.find(getValPtr()); +  if (AVI != AC->AffectedValues.end()) +    AC->AffectedValues.erase(AVI); +  // 'this' now dangles! +} + +void AssumptionCache::copyAffectedValuesInCache(Value *OV, Value *NV) { +  auto &NAVV = getOrInsertAffectedValues(NV); +  auto AVI = AffectedValues.find(OV); +  if (AVI == AffectedValues.end()) +    return; + +  for (auto &A : AVI->second) +    if (std::find(NAVV.begin(), NAVV.end(), A) == NAVV.end()) +      NAVV.push_back(A); +} + +void AssumptionCache::AffectedValueCallbackVH::allUsesReplacedWith(Value *NV) { +  if (!isa<Instruction>(NV) && !isa<Argument>(NV)) +    return; + +  // Any assumptions that affected this value now affect the new value. + +  AC->copyAffectedValuesInCache(getValPtr(), NV); +  // 'this' now might dangle! If the AffectedValues map was resized to add an +  // entry for NV then this object might have been destroyed in favor of some +  // copy in the grown map. +} + +void AssumptionCache::scanFunction() { +  assert(!Scanned && "Tried to scan the function twice!"); +  assert(AssumeHandles.empty() && "Already have assumes when scanning!"); + +  // Go through all instructions in all blocks, add all calls to @llvm.assume +  // to this cache. +  for (BasicBlock &B : F) +    for (Instruction &II : B) +      if (match(&II, m_Intrinsic<Intrinsic::assume>())) +        AssumeHandles.push_back(&II); + +  // Mark the scan as complete. +  Scanned = true; + +  // Update affected values. 
+  for (auto &A : AssumeHandles) +    updateAffectedValues(cast<CallInst>(A)); +} + +void AssumptionCache::registerAssumption(CallInst *CI) { +  assert(match(CI, m_Intrinsic<Intrinsic::assume>()) && +         "Registered call does not call @llvm.assume"); + +  // If we haven't scanned the function yet, just drop this assumption. It will +  // be found when we scan later. +  if (!Scanned) +    return; + +  AssumeHandles.push_back(CI); + +#ifndef NDEBUG +  assert(CI->getParent() && +         "Cannot register @llvm.assume call not in a basic block"); +  assert(&F == CI->getParent()->getParent() && +         "Cannot register @llvm.assume call not in this function"); + +  // We expect the number of assumptions to be small, so in an asserts build +  // check that we don't accumulate duplicates and that all assumptions point +  // to the same function. +  SmallPtrSet<Value *, 16> AssumptionSet; +  for (auto &VH : AssumeHandles) { +    if (!VH) +      continue; + +    assert(&F == cast<Instruction>(VH)->getParent()->getParent() && +           "Cached assumption not inside this function!"); +    assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) && +           "Cached something other than a call to @llvm.assume!"); +    assert(AssumptionSet.insert(VH).second && +           "Cache contains multiple copies of a call!"); +  } +#endif + +  updateAffectedValues(CI); +} + +AnalysisKey AssumptionAnalysis::Key; + +PreservedAnalyses AssumptionPrinterPass::run(Function &F, +                                             FunctionAnalysisManager &AM) { +  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); + +  OS << "Cached assumptions for function: " << F.getName() << "\n"; +  for (auto &VH : AC.assumptions()) +    if (VH) +      OS << "  " << *cast<CallInst>(VH)->getArgOperand(0) << "\n"; + +  return PreservedAnalyses::all(); +} + +void AssumptionCacheTracker::FunctionCallbackVH::deleted() { +  auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr())); +  if (I != ACT->AssumptionCaches.end()) +    ACT->AssumptionCaches.erase(I); +  // 'this' now dangles! +} + +AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) { +  // We probe the function map twice to try and avoid creating a value handle +  // around the function in common cases. This makes insertion a bit slower, +  // but if we have to insert we're going to scan the whole function so that +  // shouldn't matter. +  auto I = AssumptionCaches.find_as(&F); +  if (I != AssumptionCaches.end()) +    return *I->second; + +  // Ok, build a new cache by scanning the function, insert it and the value +  // handle into our map, and return the newly populated cache. +  auto IP = AssumptionCaches.insert(std::make_pair( +      FunctionCallbackVH(&F, this), llvm::make_unique<AssumptionCache>(F))); +  assert(IP.second && "Scanning function already in the map?"); +  return *IP.first->second; +} + +void AssumptionCacheTracker::verifyAnalysis() const { +  // FIXME: In the long term the verifier should not be controllable with a +  // flag. We should either fix all passes to correctly update the assumption +  // cache and enable the verifier unconditionally or somehow arrange for the +  // assumption list to be updated automatically by passes. 
+  if (!VerifyAssumptionCache) +    return; + +  SmallPtrSet<const CallInst *, 4> AssumptionSet; +  for (const auto &I : AssumptionCaches) { +    for (auto &VH : I.second->assumptions()) +      if (VH) +        AssumptionSet.insert(cast<CallInst>(VH)); + +    for (const BasicBlock &B : cast<Function>(*I.first)) +      for (const Instruction &II : B) +        if (match(&II, m_Intrinsic<Intrinsic::assume>()) && +            !AssumptionSet.count(cast<CallInst>(&II))) +          report_fatal_error("Assumption in scanned function not in cache"); +  } +} + +AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) { +  initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry()); +} + +AssumptionCacheTracker::~AssumptionCacheTracker() = default; + +char AssumptionCacheTracker::ID = 0; + +INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker", +                "Assumption Cache Tracker", false, true) diff --git a/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp new file mode 100644 index 000000000000..f9ecbc043261 --- /dev/null +++ b/contrib/llvm/lib/Analysis/BasicAliasAnalysis.cpp @@ -0,0 +1,1974 @@ +//===- BasicAliasAnalysis.cpp - Stateless Alias Analysis Impl -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the primary stateless implementation of the +// Alias Analysis interface that implements identities (two different +// globals cannot alias, etc), but does no stateful analysis. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/PhiValues.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/KnownBits.h" +#include <cassert> +#include <cstdint> +#include <cstdlib> +#include <utility> + +#define DEBUG_TYPE "basicaa" + +using namespace llvm; + +/// Enable analysis of recursive PHI nodes. 
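+//
+// (Editorial note, added for review) The hidden flag below is off by default;
+// it can be flipped from the command line for experiments, e.g.
+//   opt -basicaa-recphi -aa-eval -disable-output test.ll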
+static cl::opt<bool> EnableRecPhiAnalysis("basicaa-recphi", cl::Hidden, +                                          cl::init(false)); +/// SearchLimitReached / SearchTimes shows how often the limit of +/// to decompose GEPs is reached. It will affect the precision +/// of basic alias analysis. +STATISTIC(SearchLimitReached, "Number of times the limit to " +                              "decompose GEPs is reached"); +STATISTIC(SearchTimes, "Number of times a GEP is decomposed"); + +/// Cutoff after which to stop analysing a set of phi nodes potentially involved +/// in a cycle. Because we are analysing 'through' phi nodes, we need to be +/// careful with value equivalence. We use reachability to make sure a value +/// cannot be involved in a cycle. +const unsigned MaxNumPhiBBsValueReachabilityCheck = 20; + +// The max limit of the search depth in DecomposeGEPExpression() and +// GetUnderlyingObject(), both functions need to use the same search +// depth otherwise the algorithm in aliasGEP will assert. +static const unsigned MaxLookupSearchDepth = 6; + +bool BasicAAResult::invalidate(Function &Fn, const PreservedAnalyses &PA, +                               FunctionAnalysisManager::Invalidator &Inv) { +  // We don't care if this analysis itself is preserved, it has no state. But +  // we need to check that the analyses it depends on have been. Note that we +  // may be created without handles to some analyses and in that case don't +  // depend on them. +  if (Inv.invalidate<AssumptionAnalysis>(Fn, PA) || +      (DT && Inv.invalidate<DominatorTreeAnalysis>(Fn, PA)) || +      (LI && Inv.invalidate<LoopAnalysis>(Fn, PA)) || +      (PV && Inv.invalidate<PhiValuesAnalysis>(Fn, PA))) +    return true; + +  // Otherwise this analysis result remains valid. +  return false; +} + +//===----------------------------------------------------------------------===// +// Useful predicates +//===----------------------------------------------------------------------===// + +/// Returns true if the pointer is to a function-local object that never +/// escapes from the function. +static bool isNonEscapingLocalObject(const Value *V) { +  // If this is a local allocation, check to see if it escapes. +  if (isa<AllocaInst>(V) || isNoAliasCall(V)) +    // Set StoreCaptures to True so that we can assume in our callers that the +    // pointer is not the result of a load instruction. Currently +    // PointerMayBeCaptured doesn't have any special analysis for the +    // StoreCaptures=false case; if it did, our callers could be refined to be +    // more precise. +    return !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true); + +  // If this is an argument that corresponds to a byval or noalias argument, +  // then it has not escaped before entering the function.  Check if it escapes +  // inside the function. +  if (const Argument *A = dyn_cast<Argument>(V)) +    if (A->hasByValAttr() || A->hasNoAliasAttr()) +      // Note even if the argument is marked nocapture, we still need to check +      // for copies made inside the function. The nocapture attribute only +      // specifies that there are no copies made that outlive the function. +      return !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true); + +  return false; +} + +/// Returns true if the pointer is one which would have been considered an +/// escape by isNonEscapingLocalObject. 
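+// (Editorial note, added for review) Concretely, the definition below treats
+// three kinds of values as escape sources: the result of a call or invoke, an
+// argument of the current function, and the result of a load. Each of these
+// can only yield a pointer to a local allocation if that allocation has
+// already escaped, which is what isNonEscapingLocalObject above relies on.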
+static bool isEscapeSource(const Value *V) { +  if (ImmutableCallSite(V)) +    return true; + +  if (isa<Argument>(V)) +    return true; + +  // The load case works because isNonEscapingLocalObject considers all +  // stores to be escapes (it passes true for the StoreCaptures argument +  // to PointerMayBeCaptured). +  if (isa<LoadInst>(V)) +    return true; + +  return false; +} + +/// Returns the size of the object specified by V or UnknownSize if unknown. +static uint64_t getObjectSize(const Value *V, const DataLayout &DL, +                              const TargetLibraryInfo &TLI, +                              bool NullIsValidLoc, +                              bool RoundToAlign = false) { +  uint64_t Size; +  ObjectSizeOpts Opts; +  Opts.RoundToAlign = RoundToAlign; +  Opts.NullIsUnknownSize = NullIsValidLoc; +  if (getObjectSize(V, Size, DL, &TLI, Opts)) +    return Size; +  return MemoryLocation::UnknownSize; +} + +/// Returns true if we can prove that the object specified by V is smaller than +/// Size. +static bool isObjectSmallerThan(const Value *V, uint64_t Size, +                                const DataLayout &DL, +                                const TargetLibraryInfo &TLI, +                                bool NullIsValidLoc) { +  // Note that the meanings of the "object" are slightly different in the +  // following contexts: +  //    c1: llvm::getObjectSize() +  //    c2: llvm.objectsize() intrinsic +  //    c3: isObjectSmallerThan() +  // c1 and c2 share the same meaning; however, the meaning of "object" in c3 +  // refers to the "entire object". +  // +  //  Consider this example: +  //     char *p = (char*)malloc(100) +  //     char *q = p+80; +  // +  //  In the context of c1 and c2, the "object" pointed by q refers to the +  // stretch of memory of q[0:19]. So, getObjectSize(q) should return 20. +  // +  //  However, in the context of c3, the "object" refers to the chunk of memory +  // being allocated. So, the "object" has 100 bytes, and q points to the middle +  // the "object". In case q is passed to isObjectSmallerThan() as the 1st +  // parameter, before the llvm::getObjectSize() is called to get the size of +  // entire object, we should: +  //    - either rewind the pointer q to the base-address of the object in +  //      question (in this case rewind to p), or +  //    - just give up. It is up to caller to make sure the pointer is pointing +  //      to the base address the object. +  // +  // We go for 2nd option for simplicity. +  if (!isIdentifiedObject(V)) +    return false; + +  // This function needs to use the aligned object size because we allow +  // reads a bit past the end given sufficient alignment. +  uint64_t ObjectSize = getObjectSize(V, DL, TLI, NullIsValidLoc, +                                      /*RoundToAlign*/ true); + +  return ObjectSize != MemoryLocation::UnknownSize && ObjectSize < Size; +} + +/// Returns true if we can prove that the object specified by V has size Size. 
+static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, +                         const TargetLibraryInfo &TLI, bool NullIsValidLoc) { +  uint64_t ObjectSize = getObjectSize(V, DL, TLI, NullIsValidLoc); +  return ObjectSize != MemoryLocation::UnknownSize && ObjectSize == Size; +} + +//===----------------------------------------------------------------------===// +// GetElementPtr Instruction Decomposition and Analysis +//===----------------------------------------------------------------------===// + +/// Analyzes the specified value as a linear expression: "A*V + B", where A and +/// B are constant integers. +/// +/// Returns the scale and offset values as APInts and return V as a Value*, and +/// return whether we looked through any sign or zero extends.  The incoming +/// Value is known to have IntegerType, and it may already be sign or zero +/// extended. +/// +/// Note that this looks through extends, so the high bits may not be +/// represented in the result. +/*static*/ const Value *BasicAAResult::GetLinearExpression( +    const Value *V, APInt &Scale, APInt &Offset, unsigned &ZExtBits, +    unsigned &SExtBits, const DataLayout &DL, unsigned Depth, +    AssumptionCache *AC, DominatorTree *DT, bool &NSW, bool &NUW) { +  assert(V->getType()->isIntegerTy() && "Not an integer value"); + +  // Limit our recursion depth. +  if (Depth == 6) { +    Scale = 1; +    Offset = 0; +    return V; +  } + +  if (const ConstantInt *Const = dyn_cast<ConstantInt>(V)) { +    // If it's a constant, just convert it to an offset and remove the variable. +    // If we've been called recursively, the Offset bit width will be greater +    // than the constant's (the Offset's always as wide as the outermost call), +    // so we'll zext here and process any extension in the isa<SExtInst> & +    // isa<ZExtInst> cases below. +    Offset += Const->getValue().zextOrSelf(Offset.getBitWidth()); +    assert(Scale == 0 && "Constant values don't have a scale"); +    return V; +  } + +  if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) { +    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) { +      // If we've been called recursively, then Offset and Scale will be wider +      // than the BOp operands. We'll always zext it here as we'll process sign +      // extensions below (see the isa<SExtInst> / isa<ZExtInst> cases). +      APInt RHS = RHSC->getValue().zextOrSelf(Offset.getBitWidth()); + +      switch (BOp->getOpcode()) { +      default: +        // We don't understand this instruction, so we can't decompose it any +        // further. +        Scale = 1; +        Offset = 0; +        return V; +      case Instruction::Or: +        // X|C == X+C if all the bits in C are unset in X.  Otherwise we can't +        // analyze it. 
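+        // (Editorial illustration, added for review) For instance, if %x is
+        // known to have its low two bits clear (say it was left-shifted by
+        // two), then (%x | 3) is exactly %x + 3 and fits the linear form;
+        // MaskedValueIsZero below proves precisely that the bits set in the
+        // constant are zero in the other operand.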
+        if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), DL, 0, AC, +                               BOp, DT)) { +          Scale = 1; +          Offset = 0; +          return V; +        } +        LLVM_FALLTHROUGH; +      case Instruction::Add: +        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits, +                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW); +        Offset += RHS; +        break; +      case Instruction::Sub: +        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits, +                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW); +        Offset -= RHS; +        break; +      case Instruction::Mul: +        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits, +                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW); +        Offset *= RHS; +        Scale *= RHS; +        break; +      case Instruction::Shl: +        V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, ZExtBits, +                                SExtBits, DL, Depth + 1, AC, DT, NSW, NUW); + +        // We're trying to linearize an expression of the kind: +        //   shl i8 -128, 36 +        // where the shift count exceeds the bitwidth of the type. +        // We can't decompose this further (the expression would return +        // a poison value). +        if (Offset.getBitWidth() < RHS.getLimitedValue() || +            Scale.getBitWidth() < RHS.getLimitedValue()) { +          Scale = 1; +          Offset = 0; +          return V; +        } + +        Offset <<= RHS.getLimitedValue(); +        Scale <<= RHS.getLimitedValue(); +        // the semantics of nsw and nuw for left shifts don't match those of +        // multiplications, so we won't propagate them. +        NSW = NUW = false; +        return V; +      } + +      if (isa<OverflowingBinaryOperator>(BOp)) { +        NUW &= BOp->hasNoUnsignedWrap(); +        NSW &= BOp->hasNoSignedWrap(); +      } +      return V; +    } +  } + +  // Since GEP indices are sign extended anyway, we don't care about the high +  // bits of a sign or zero extended value - just scales and offsets.  The +  // extensions have to be consistent though. +  if (isa<SExtInst>(V) || isa<ZExtInst>(V)) { +    Value *CastOp = cast<CastInst>(V)->getOperand(0); +    unsigned NewWidth = V->getType()->getPrimitiveSizeInBits(); +    unsigned SmallWidth = CastOp->getType()->getPrimitiveSizeInBits(); +    unsigned OldZExtBits = ZExtBits, OldSExtBits = SExtBits; +    const Value *Result = +        GetLinearExpression(CastOp, Scale, Offset, ZExtBits, SExtBits, DL, +                            Depth + 1, AC, DT, NSW, NUW); + +    // zext(zext(%x)) == zext(%x), and similarly for sext; we'll handle this +    // by just incrementing the number of bits we've extended by. +    unsigned ExtendedBy = NewWidth - SmallWidth; + +    if (isa<SExtInst>(V) && ZExtBits == 0) { +      // sext(sext(%x, a), b) == sext(%x, a + b) + +      if (NSW) { +        // We haven't sign-wrapped, so it's valid to decompose sext(%x + c) +        // into sext(%x) + sext(c). 
We'll sext the Offset ourselves:
+        unsigned OldWidth = Offset.getBitWidth();
+        Offset = Offset.trunc(SmallWidth).sext(NewWidth).zextOrSelf(OldWidth);
+      } else {
+        // We may have signed-wrapped, so don't decompose sext(%x + c) into
+        // sext(%x) + sext(c)
+        Scale = 1;
+        Offset = 0;
+        Result = CastOp;
+        ZExtBits = OldZExtBits;
+        SExtBits = OldSExtBits;
+      }
+      SExtBits += ExtendedBy;
+    } else {
+      // sext(zext(%x, a), b) = zext(zext(%x, a), b) = zext(%x, a + b)
+
+      if (!NUW) {
+        // We may have unsigned-wrapped, so don't decompose zext(%x + c) into
+        // zext(%x) + zext(c)
+        Scale = 1;
+        Offset = 0;
+        Result = CastOp;
+        ZExtBits = OldZExtBits;
+        SExtBits = OldSExtBits;
+      }
+      ZExtBits += ExtendedBy;
+    }
+
+    return Result;
+  }
+
+  Scale = 1;
+  Offset = 0;
+  return V;
+}
+
+/// Ensures that a pointer offset fits in an integer of size PointerSize
+/// (in bits) when that size is smaller than 64. This is an issue in
+/// particular for 32b programs with negative indices that rely on two's
+/// complement wrap-arounds for precise alias information.
+static int64_t adjustToPointerSize(int64_t Offset, unsigned PointerSize) {
+  assert(PointerSize <= 64 && "Invalid PointerSize!");
+  unsigned ShiftBits = 64 - PointerSize;
+  return (int64_t)((uint64_t)Offset << ShiftBits) >> ShiftBits;
+}
+
+/// If V is a symbolic pointer expression, decompose it into a base pointer
+/// with a constant offset and a number of scaled symbolic offsets.
+///
+/// The scaled symbolic offsets (represented by pairs of a Value* and a scale
+/// in the VarIndices vector) are Value*'s that are known to be scaled by the
+/// specified amount, but which may have other unrepresented high bits. As
+/// such, the gep cannot necessarily be reconstructed from its decomposed form.
+///
+/// When DataLayout is around, this function is capable of analyzing everything
+/// that GetUnderlyingObject can look through. To be able to do that
+/// GetUnderlyingObject and DecomposeGEPExpression must use the same search
+/// depth (MaxLookupSearchDepth). When DataLayout is not around, it just looks
+/// through pointer casts.
+bool BasicAAResult::DecomposeGEPExpression(const Value *V,
+       DecomposedGEP &Decomposed, const DataLayout &DL, AssumptionCache *AC,
+       DominatorTree *DT) {
+  // Limit recursion depth to limit compile time in crazy cases.
+  unsigned MaxLookup = MaxLookupSearchDepth;
+  SearchTimes++;
+
+  Decomposed.StructOffset = 0;
+  Decomposed.OtherOffset = 0;
+  Decomposed.VarIndices.clear();
+  do {
+    // See if this is a bitcast or GEP.
+    const Operator *Op = dyn_cast<Operator>(V);
+    if (!Op) {
+      // The only non-operator cases we can handle are GlobalAliases.
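+      // For example (an illustrative sketch):
+      //   @g = global i32 0
+      //   @a = alias i32, i32* @g
+      // A non-interposable alias such as @a is transparently replaced by its
+      // aliasee @g and decomposition continues from there.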
+      if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { +        if (!GA->isInterposable()) { +          V = GA->getAliasee(); +          continue; +        } +      } +      Decomposed.Base = V; +      return false; +    } + +    if (Op->getOpcode() == Instruction::BitCast || +        Op->getOpcode() == Instruction::AddrSpaceCast) { +      V = Op->getOperand(0); +      continue; +    } + +    const GEPOperator *GEPOp = dyn_cast<GEPOperator>(Op); +    if (!GEPOp) { +      if (auto CS = ImmutableCallSite(V)) { +        // CaptureTracking can know about special capturing properties of some +        // intrinsics like launder.invariant.group, that can't be expressed with +        // the attributes, but have properties like returning aliasing pointer. +        // Because some analysis may assume that nocaptured pointer is not +        // returned from some special intrinsic (because function would have to +        // be marked with returns attribute), it is crucial to use this function +        // because it should be in sync with CaptureTracking. Not using it may +        // cause weird miscompilations where 2 aliasing pointers are assumed to +        // noalias. +        if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) { +          V = RP; +          continue; +        } +      } + +      // If it's not a GEP, hand it off to SimplifyInstruction to see if it +      // can come up with something. This matches what GetUnderlyingObject does. +      if (const Instruction *I = dyn_cast<Instruction>(V)) +        // TODO: Get a DominatorTree and AssumptionCache and use them here +        // (these are both now available in this function, but this should be +        // updated when GetUnderlyingObject is updated). TLI should be +        // provided also. +        if (const Value *Simplified = +                SimplifyInstruction(const_cast<Instruction *>(I), DL)) { +          V = Simplified; +          continue; +        } + +      Decomposed.Base = V; +      return false; +    } + +    // Don't attempt to analyze GEPs over unsized objects. +    if (!GEPOp->getSourceElementType()->isSized()) { +      Decomposed.Base = V; +      return false; +    } + +    unsigned AS = GEPOp->getPointerAddressSpace(); +    // Walk the indices of the GEP, accumulating them into BaseOff/VarIndices. +    gep_type_iterator GTI = gep_type_begin(GEPOp); +    unsigned PointerSize = DL.getPointerSizeInBits(AS); +    // Assume all GEP operands are constants until proven otherwise. +    bool GepHasConstantOffset = true; +    for (User::const_op_iterator I = GEPOp->op_begin() + 1, E = GEPOp->op_end(); +         I != E; ++I, ++GTI) { +      const Value *Index = *I; +      // Compute the (potentially symbolic) offset in bytes for this index. +      if (StructType *STy = GTI.getStructTypeOrNull()) { +        // For a struct, add the member offset. +        unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue(); +        if (FieldNo == 0) +          continue; + +        Decomposed.StructOffset += +          DL.getStructLayout(STy)->getElementOffset(FieldNo); +        continue; +      } + +      // For an array/pointer, add the element offset, explicitly scaled. 
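+      // For example (illustrative, assuming a 4-byte i32): a constant index of
+      // 3 in `getelementptr i32, i32* %p, i64 3` adds 3 * 4 = 12 bytes to
+      // OtherOffset, while a variable index %i instead produces a VarIndices
+      // entry with Scale = 4.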
+      if (const ConstantInt *CIdx = dyn_cast<ConstantInt>(Index)) { +        if (CIdx->isZero()) +          continue; +        Decomposed.OtherOffset += +          DL.getTypeAllocSize(GTI.getIndexedType()) * CIdx->getSExtValue(); +        continue; +      } + +      GepHasConstantOffset = false; + +      uint64_t Scale = DL.getTypeAllocSize(GTI.getIndexedType()); +      unsigned ZExtBits = 0, SExtBits = 0; + +      // If the integer type is smaller than the pointer size, it is implicitly +      // sign extended to pointer size. +      unsigned Width = Index->getType()->getIntegerBitWidth(); +      if (PointerSize > Width) +        SExtBits += PointerSize - Width; + +      // Use GetLinearExpression to decompose the index into a C1*V+C2 form. +      APInt IndexScale(Width, 0), IndexOffset(Width, 0); +      bool NSW = true, NUW = true; +      Index = GetLinearExpression(Index, IndexScale, IndexOffset, ZExtBits, +                                  SExtBits, DL, 0, AC, DT, NSW, NUW); + +      // All GEP math happens in the width of the pointer type, +      // so we can truncate the value to 64-bits as we don't handle +      // currently pointers larger than 64 bits and we would crash +      // later. TODO: Make `Scale` an APInt to avoid this problem. +      if (IndexScale.getBitWidth() > 64) +        IndexScale = IndexScale.sextOrTrunc(64); + +      // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. +      // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. +      Decomposed.OtherOffset += IndexOffset.getSExtValue() * Scale; +      Scale *= IndexScale.getSExtValue(); + +      // If we already had an occurrence of this index variable, merge this +      // scale into it.  For example, we want to handle: +      //   A[x][x] -> x*16 + x*4 -> x*20 +      // This also ensures that 'x' only appears in the index list once. +      for (unsigned i = 0, e = Decomposed.VarIndices.size(); i != e; ++i) { +        if (Decomposed.VarIndices[i].V == Index && +            Decomposed.VarIndices[i].ZExtBits == ZExtBits && +            Decomposed.VarIndices[i].SExtBits == SExtBits) { +          Scale += Decomposed.VarIndices[i].Scale; +          Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i); +          break; +        } +      } + +      // Make sure that we have a scale that makes sense for this target's +      // pointer size. +      Scale = adjustToPointerSize(Scale, PointerSize); + +      if (Scale) { +        VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, +                                  static_cast<int64_t>(Scale)}; +        Decomposed.VarIndices.push_back(Entry); +      } +    } + +    // Take care of wrap-arounds +    if (GepHasConstantOffset) { +      Decomposed.StructOffset = +          adjustToPointerSize(Decomposed.StructOffset, PointerSize); +      Decomposed.OtherOffset = +          adjustToPointerSize(Decomposed.OtherOffset, PointerSize); +    } + +    // Analyze the base pointer next. +    V = GEPOp->getOperand(0); +  } while (--MaxLookup); + +  // If the chain of expressions is too deep, just return early. +  Decomposed.Base = V; +  SearchLimitReached++; +  return true; +} + +/// Returns whether the given pointer value points to memory that is local to +/// the function, with global constants being considered local to all +/// functions. 
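+/// For example (an illustrative sketch): a pointer into
+///   @str = private unnamed_addr constant [4 x i8] c"abc\00"
+/// is treated as constant memory, whereas an alloca is only accepted when
+/// OrLocal is true.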
+bool BasicAAResult::pointsToConstantMemory(const MemoryLocation &Loc, +                                           bool OrLocal) { +  assert(Visited.empty() && "Visited must be cleared after use!"); + +  unsigned MaxLookup = 8; +  SmallVector<const Value *, 16> Worklist; +  Worklist.push_back(Loc.Ptr); +  do { +    const Value *V = GetUnderlyingObject(Worklist.pop_back_val(), DL); +    if (!Visited.insert(V).second) { +      Visited.clear(); +      return AAResultBase::pointsToConstantMemory(Loc, OrLocal); +    } + +    // An alloca instruction defines local memory. +    if (OrLocal && isa<AllocaInst>(V)) +      continue; + +    // A global constant counts as local memory for our purposes. +    if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) { +      // Note: this doesn't require GV to be "ODR" because it isn't legal for a +      // global to be marked constant in some modules and non-constant in +      // others.  GV may even be a declaration, not a definition. +      if (!GV->isConstant()) { +        Visited.clear(); +        return AAResultBase::pointsToConstantMemory(Loc, OrLocal); +      } +      continue; +    } + +    // If both select values point to local memory, then so does the select. +    if (const SelectInst *SI = dyn_cast<SelectInst>(V)) { +      Worklist.push_back(SI->getTrueValue()); +      Worklist.push_back(SI->getFalseValue()); +      continue; +    } + +    // If all values incoming to a phi node point to local memory, then so does +    // the phi. +    if (const PHINode *PN = dyn_cast<PHINode>(V)) { +      // Don't bother inspecting phi nodes with many operands. +      if (PN->getNumIncomingValues() > MaxLookup) { +        Visited.clear(); +        return AAResultBase::pointsToConstantMemory(Loc, OrLocal); +      } +      for (Value *IncValue : PN->incoming_values()) +        Worklist.push_back(IncValue); +      continue; +    } + +    // Otherwise be conservative. +    Visited.clear(); +    return AAResultBase::pointsToConstantMemory(Loc, OrLocal); +  } while (!Worklist.empty() && --MaxLookup); + +  Visited.clear(); +  return Worklist.empty(); +} + +/// Returns the behavior when calling the given call site. +FunctionModRefBehavior BasicAAResult::getModRefBehavior(ImmutableCallSite CS) { +  if (CS.doesNotAccessMemory()) +    // Can't do better than this. +    return FMRB_DoesNotAccessMemory; + +  FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + +  // If the callsite knows it only reads memory, don't return worse +  // than that. +  if (CS.onlyReadsMemory()) +    Min = FMRB_OnlyReadsMemory; +  else if (CS.doesNotReadMemory()) +    Min = FMRB_DoesNotReadMemory; + +  if (CS.onlyAccessesArgMemory()) +    Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); +  else if (CS.onlyAccessesInaccessibleMemory()) +    Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem); +  else if (CS.onlyAccessesInaccessibleMemOrArgMem()) +    Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem); + +  // If CS has operand bundles then aliasing attributes from the function it +  // calls do not directly apply to the CallSite.  This can be made more +  // precise in the future. +  if (!CS.hasOperandBundles()) +    if (const Function *F = CS.getCalledFunction()) +      Min = +          FunctionModRefBehavior(Min & getBestAAResults().getModRefBehavior(F)); + +  return Min; +} + +/// Returns the behavior when calling the given function. For use when the call +/// site is not known. 
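+/// For example (an illustrative sketch, assuming the usual attribute-to-enum
+/// mapping): a declaration carrying both `readonly` and `argmemonly` is
+/// folded down to FMRB_OnlyReadsArgumentPointees by the intersections below.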
+FunctionModRefBehavior BasicAAResult::getModRefBehavior(const Function *F) { +  // If the function declares it doesn't access memory, we can't do better. +  if (F->doesNotAccessMemory()) +    return FMRB_DoesNotAccessMemory; + +  FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + +  // If the function declares it only reads memory, go with that. +  if (F->onlyReadsMemory()) +    Min = FMRB_OnlyReadsMemory; +  else if (F->doesNotReadMemory()) +    Min = FMRB_DoesNotReadMemory; + +  if (F->onlyAccessesArgMemory()) +    Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); +  else if (F->onlyAccessesInaccessibleMemory()) +    Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem); +  else if (F->onlyAccessesInaccessibleMemOrArgMem()) +    Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem); + +  return Min; +} + +/// Returns true if this is a writeonly (i.e Mod only) parameter. +static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, +                             const TargetLibraryInfo &TLI) { +  if (CS.paramHasAttr(ArgIdx, Attribute::WriteOnly)) +    return true; + +  // We can bound the aliasing properties of memset_pattern16 just as we can +  // for memcpy/memset.  This is particularly important because the +  // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 +  // whenever possible. +  // FIXME Consider handling this in InferFunctionAttr.cpp together with other +  // attributes. +  LibFunc F; +  if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && +      F == LibFunc_memset_pattern16 && TLI.has(F)) +    if (ArgIdx == 0) +      return true; + +  // TODO: memset_pattern4, memset_pattern8 +  // TODO: _chk variants +  // TODO: strcmp, strcpy + +  return false; +} + +ModRefInfo BasicAAResult::getArgModRefInfo(ImmutableCallSite CS, +                                           unsigned ArgIdx) { +  // Checking for known builtin intrinsics and target library functions. +  if (isWriteOnlyParam(CS, ArgIdx, TLI)) +    return ModRefInfo::Mod; + +  if (CS.paramHasAttr(ArgIdx, Attribute::ReadOnly)) +    return ModRefInfo::Ref; + +  if (CS.paramHasAttr(ArgIdx, Attribute::ReadNone)) +    return ModRefInfo::NoModRef; + +  return AAResultBase::getArgModRefInfo(CS, ArgIdx); +} + +static bool isIntrinsicCall(ImmutableCallSite CS, Intrinsic::ID IID) { +  const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction()); +  return II && II->getIntrinsicID() == IID; +} + +#ifndef NDEBUG +static const Function *getParent(const Value *V) { +  if (const Instruction *inst = dyn_cast<Instruction>(V)) { +    if (!inst->getParent()) +      return nullptr; +    return inst->getParent()->getParent(); +  } + +  if (const Argument *arg = dyn_cast<Argument>(V)) +    return arg->getParent(); + +  return nullptr; +} + +static bool notDifferentParent(const Value *O1, const Value *O2) { + +  const Function *F1 = getParent(O1); +  const Function *F2 = getParent(O2); + +  return !F1 || !F2 || F1 == F2; +} +#endif + +AliasResult BasicAAResult::alias(const MemoryLocation &LocA, +                                 const MemoryLocation &LocB) { +  assert(notDifferentParent(LocA.Ptr, LocB.Ptr) && +         "BasicAliasAnalysis doesn't support interprocedural queries."); + +  // If we have a directly cached entry for these locations, we have recursed +  // through this once, so just return the cached results. Notably, when this +  // happens, we don't clear the cache. 
+  auto CacheIt = AliasCache.find(LocPair(LocA, LocB));
+  if (CacheIt != AliasCache.end())
+    return CacheIt->second;
+
+  AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.AATags, LocB.Ptr,
+                                 LocB.Size, LocB.AATags);
+  // AliasCache rarely has more than 1 or 2 elements, always use
+  // shrink_and_clear so it quickly returns to the inline capacity of the
+  // SmallDenseMap if it ever grows larger.
+  // FIXME: This should really be shrink_to_inline_capacity_and_clear().
+  AliasCache.shrink_and_clear();
+  VisitedPhiBBs.clear();
+  return Alias;
+}
+
+/// Checks to see if the specified callsite can clobber the specified memory
+/// object.
+///
+/// Since we only look at local properties of this function, we really can't
+/// say much about this query.  We do, however, use simple "address taken"
+/// analysis on local objects.
+ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS,
+                                        const MemoryLocation &Loc) {
+  assert(notDifferentParent(CS.getInstruction(), Loc.Ptr) &&
+         "AliasAnalysis query involving multiple functions!");
+
+  const Value *Object = GetUnderlyingObject(Loc.Ptr, DL);
+
+  // Calls marked 'tail' cannot read or write allocas from the current frame
+  // because the current frame might be destroyed by the time they run. However,
+  // a tail call may use an alloca with byval. Calling with byval copies the
+  // contents of the alloca into argument registers or stack slots, so there is
+  // no lifetime issue.
+  if (isa<AllocaInst>(Object))
+    if (const CallInst *CI = dyn_cast<CallInst>(CS.getInstruction()))
+      if (CI->isTailCall() &&
+          !CI->getAttributes().hasAttrSomewhere(Attribute::ByVal))
+        return ModRefInfo::NoModRef;
+
+  // If the pointer is to a locally allocated object that does not escape,
+  // then the call cannot mod/ref the pointer unless the call takes the pointer
+  // as an argument, and itself doesn't capture it.
+  if (!isa<Constant>(Object) && CS.getInstruction() != Object &&
+      isNonEscapingLocalObject(Object)) {
+
+    // Optimistically assume that call doesn't touch Object and check this
+    // assumption in the following loop.
+    ModRefInfo Result = ModRefInfo::NoModRef;
+    bool IsMustAlias = true;
+
+    unsigned OperandNo = 0;
+    for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end();
+         CI != CE; ++CI, ++OperandNo) {
+      // Only look at the no-capture or byval pointer arguments.  If this
+      // pointer were passed to arguments that were neither of these, then it
+      // couldn't be no-capture.
+      if (!(*CI)->getType()->isPointerTy() ||
+          (!CS.doesNotCapture(OperandNo) &&
+           OperandNo < CS.getNumArgOperands() && !CS.isByValArgument(OperandNo)))
+        continue;
+
+      // Call doesn't access memory through this operand, so we don't care
+      // if it aliases with Object.
+      if (CS.doesNotAccessMemory(OperandNo))
+        continue;
+
+      // If this is a no-capture pointer argument, see if we can tell that it
+      // is impossible to alias the pointer we're checking.
+      AliasResult AR =
+          getBestAAResults().alias(MemoryLocation(*CI), MemoryLocation(Object));
+      if (AR != MustAlias)
+        IsMustAlias = false;
+      // Operand doesn't alias 'Object'; continue looking for other aliases.
+      if (AR == NoAlias)
+        continue;
+      // Operand aliases 'Object', but call doesn't modify it. Strengthen the
+      // initial assumption and keep looking in case there are more aliases.
+      if (CS.onlyReadsMemory(OperandNo)) {
+        Result = setRef(Result);
+        continue;
+      }
+      // Operand aliases 'Object' but call only writes into it.
+      if (CS.doesNotReadMemory(OperandNo)) {
+        Result = setMod(Result);
+        continue;
+      }
+      // This operand aliases 'Object' and call reads and writes into it.
+      // Setting ModRef will not yield an early return below, MustAlias is not
+      // used further.
+      Result = ModRefInfo::ModRef;
+      break;
+    }
+
+    // No operand aliases; reset the Must bit. It is added back below if at
+    // least one operand aliases and all aliases found are MustAlias.
+    if (isNoModRef(Result))
+      IsMustAlias = false;
+
+    // Early return if we improved the mod/ref information.
+    if (!isModAndRefSet(Result)) {
+      if (isNoModRef(Result))
+        return ModRefInfo::NoModRef;
+      return IsMustAlias ? setMust(Result) : clearMust(Result);
+    }
+  }
+
+  // If the CallSite is to malloc or calloc, we can assume that it doesn't
+  // modify any IR visible value.  This is only valid because we assume these
+  // routines do not read values visible in the IR.  TODO: Consider special
+  // casing realloc and strdup routines which access only their arguments as
+  // well.  Or alternatively, replace all of this with inaccessiblememonly once
+  // that's implemented fully.
+  auto *Inst = CS.getInstruction();
+  if (isMallocOrCallocLikeFn(Inst, &TLI)) {
+    // Be conservative if the accessed pointer may alias the allocation -
+    // fallback to the generic handling below.
+    if (getBestAAResults().alias(MemoryLocation(Inst), Loc) == NoAlias)
+      return ModRefInfo::NoModRef;
+  }
+
+  // The semantics of memcpy intrinsics forbid overlap between their respective
+  // operands, i.e., source and destination of any given memcpy must no-alias.
+  // If Loc must-aliases either one of these two locations, then it necessarily
+  // no-aliases the other.
+  if (auto *Inst = dyn_cast<AnyMemCpyInst>(CS.getInstruction())) {
+    AliasResult SrcAA, DestAA;
+
+    if ((SrcAA = getBestAAResults().alias(MemoryLocation::getForSource(Inst),
+                                          Loc)) == MustAlias)
+      // Loc is exactly the memcpy source thus disjoint from memcpy dest.
+      return ModRefInfo::Ref;
+    if ((DestAA = getBestAAResults().alias(MemoryLocation::getForDest(Inst),
+                                           Loc)) == MustAlias)
+      // The converse case.
+      return ModRefInfo::Mod;
+
+    // It's also possible for Loc to alias both src and dest, or neither.
+    ModRefInfo rv = ModRefInfo::NoModRef;
+    if (SrcAA != NoAlias)
+      rv = setRef(rv);
+    if (DestAA != NoAlias)
+      rv = setMod(rv);
+    return rv;
+  }
+
+  // While the assume intrinsic is marked as arbitrarily writing so that
+  // proper control dependencies will be maintained, it never aliases any
+  // particular memory location.
+  if (isIntrinsicCall(CS, Intrinsic::assume))
+    return ModRefInfo::NoModRef;
+
+  // Like assumes, guard intrinsics are also marked as arbitrarily writing so
+  // that proper control dependencies are maintained but they never mod any
+  // particular memory location.
+  //
+  // *Unlike* assumes, guard intrinsics are modeled as reading memory since the
+  // heap state at the point the guard is issued needs to be consistent in case
+  // the guard invokes the "deopt" continuation.
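+  //
+  // An illustrative guard call (a sketch; see the LangRef for the exact form):
+  //   call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]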
+  if (isIntrinsicCall(CS, Intrinsic::experimental_guard))
+    return ModRefInfo::Ref;
+
+  // Like assumes, invariant.start intrinsics were also marked as arbitrarily
+  // writing so that proper control dependencies are maintained but they never
+  // mod any particular memory location visible to the IR.
+  // *Unlike* assumes (which are now modeled as NoModRef), invariant.start
+  // intrinsic is now modeled as reading memory. This prevents hoisting the
+  // invariant.start intrinsic over stores. Consider:
+  // *ptr = 40;
+  // *ptr = 50;
+  // invariant_start(ptr)
+  // int val = *ptr;
+  // print(val);
+  //
+  // This cannot be transformed to:
+  //
+  // *ptr = 40;
+  // invariant_start(ptr)
+  // *ptr = 50;
+  // int val = *ptr;
+  // print(val);
+  //
+  // The transformation will cause the second store to be ignored (based on
+  // rules of invariant.start) and print 40, while the first program always
+  // prints 50.
+  if (isIntrinsicCall(CS, Intrinsic::invariant_start))
+    return ModRefInfo::Ref;
+
+  // The AAResultBase base class has some smarts, let's use them.
+  return AAResultBase::getModRefInfo(CS, Loc);
+}
+
+ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS1,
+                                        ImmutableCallSite CS2) {
+  // While the assume intrinsic is marked as arbitrarily writing so that
+  // proper control dependencies will be maintained, it never aliases any
+  // particular memory location.
+  if (isIntrinsicCall(CS1, Intrinsic::assume) ||
+      isIntrinsicCall(CS2, Intrinsic::assume))
+    return ModRefInfo::NoModRef;
+
+  // Like assumes, guard intrinsics are also marked as arbitrarily writing so
+  // that proper control dependencies are maintained but they never mod any
+  // particular memory location.
+  //
+  // *Unlike* assumes, guard intrinsics are modeled as reading memory since the
+  // heap state at the point the guard is issued needs to be consistent in case
+  // the guard invokes the "deopt" continuation.
+
+  // NB! This function is *not* commutative, so we special-case two
+  // possibilities for guard intrinsics.
+
+  if (isIntrinsicCall(CS1, Intrinsic::experimental_guard))
+    return isModSet(createModRefInfo(getModRefBehavior(CS2)))
+               ? ModRefInfo::Ref
+               : ModRefInfo::NoModRef;
+
+  if (isIntrinsicCall(CS2, Intrinsic::experimental_guard))
+    return isModSet(createModRefInfo(getModRefBehavior(CS1)))
+               ? ModRefInfo::Mod
+               : ModRefInfo::NoModRef;
+
+  // The AAResultBase base class has some smarts, let's use them.
+  return AAResultBase::getModRefInfo(CS1, CS2);
+}
+
+/// Provide ad-hoc rules to disambiguate accesses through two GEP operators,
+/// both having the exact same pointer operand.
+static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1,
+                                            LocationSize V1Size,
+                                            const GEPOperator *GEP2,
+                                            LocationSize V2Size,
+                                            const DataLayout &DL) {
+  assert(GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() ==
+             GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() &&
+         GEP1->getPointerOperandType() == GEP2->getPointerOperandType() &&
+         "Expected GEPs with the same pointer operand");
+
+  // Try to determine whether GEP1 and GEP2 index through arrays, into structs,
+  // such that the struct field accesses provably cannot alias.
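+  // For example (an illustrative sketch), with %struct.S = type { i32, i32 }:
+  //   %f0 = getelementptr inbounds %struct.S, %struct.S* %p, i64 0, i32 0
+  //   %f1 = getelementptr inbounds %struct.S, %struct.S* %p, i64 0, i32 1
+  // index disjoint fields of the same struct, so 4-byte accesses through them
+  // cannot alias.
+  //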
+  // We also need at least two indices (the pointer, and the struct field).
+  if (GEP1->getNumIndices() != GEP2->getNumIndices() ||
+      GEP1->getNumIndices() < 2)
+    return MayAlias;
+
+  // If we don't know the size of the accesses through both GEPs, we can't
+  // determine whether the struct fields accessed can't alias.
+  if (V1Size == MemoryLocation::UnknownSize ||
+      V2Size == MemoryLocation::UnknownSize)
+    return MayAlias;
+
+  ConstantInt *C1 =
+      dyn_cast<ConstantInt>(GEP1->getOperand(GEP1->getNumOperands() - 1));
+  ConstantInt *C2 =
+      dyn_cast<ConstantInt>(GEP2->getOperand(GEP2->getNumOperands() - 1));
+
+  // If the last (struct) indices are constants and are equal, the other indices
+  // might also be dynamically equal, so the GEPs can alias.
+  if (C1 && C2 && C1->getSExtValue() == C2->getSExtValue())
+    return MayAlias;
+
+  // Find the last-indexed type of the GEP, i.e., the type you'd get if
+  // you stripped the last index.
+  // On the way, look at each indexed type.  If there's something other
+  // than an array, different indices can lead to different final types.
+  SmallVector<Value *, 8> IntermediateIndices;
+
+  // Insert the first index; we don't need to check the type indexed
+  // through it as it only drops the pointer indirection.
+  assert(GEP1->getNumIndices() > 1 && "Not enough GEP indices to examine");
+  IntermediateIndices.push_back(GEP1->getOperand(1));
+
+  // Insert all the remaining indices but the last one.
+  // Also, check that they all index through arrays.
+  for (unsigned i = 1, e = GEP1->getNumIndices() - 1; i != e; ++i) {
+    if (!isa<ArrayType>(GetElementPtrInst::getIndexedType(
+            GEP1->getSourceElementType(), IntermediateIndices)))
+      return MayAlias;
+    IntermediateIndices.push_back(GEP1->getOperand(i + 1));
+  }
+
+  auto *Ty = GetElementPtrInst::getIndexedType(
+    GEP1->getSourceElementType(), IntermediateIndices);
+  StructType *LastIndexedStruct = dyn_cast<StructType>(Ty);
+
+  if (isa<SequentialType>(Ty)) {
+    // We know that:
+    // - both GEPs begin indexing from the exact same pointer;
+    // - the last indices in both GEPs are constants, indexing into a sequential
+    //   type (array or pointer);
+    // - both GEPs only index through arrays prior to that.
+    //
+    // Because array indices greater than the number of elements are valid in
+    // GEPs, unless we know the intermediate indices are identical between
+    // GEP1 and GEP2 we cannot guarantee that the last indexed arrays don't
+    // partially overlap. We also need to check that the loaded size matches
+    // the element size, otherwise we could still have overlap.
+    const uint64_t ElementSize =
+        DL.getTypeStoreSize(cast<SequentialType>(Ty)->getElementType());
+    if (V1Size != ElementSize || V2Size != ElementSize)
+      return MayAlias;
+
+    for (unsigned i = 0, e = GEP1->getNumIndices() - 1; i != e; ++i)
+      if (GEP1->getOperand(i + 1) != GEP2->getOperand(i + 1))
+        return MayAlias;
+
+    // Now we know that the array/pointer that GEP1 indexes into and the one
+    // that GEP2 indexes into must either precisely overlap or be disjoint.
+    // Because they cannot partially overlap and because fields in an array
+    // cannot overlap, if we can prove the final indices are different between
+    // GEP1 and GEP2, we can conclude GEP1 and GEP2 don't alias.
+
+    // If the last indices are constants, we've already checked they don't
+    // equal each other so we can exit early.
+    if (C1 && C2)
+      return NoAlias;
+    {
+      Value *GEP1LastIdx = GEP1->getOperand(GEP1->getNumOperands() - 1);
+      Value *GEP2LastIdx = GEP2->getOperand(GEP2->getNumOperands() - 1);
+      if (isa<PHINode>(GEP1LastIdx) || isa<PHINode>(GEP2LastIdx)) {
+        // If one of the indices is a PHI node, be safe and only use
+        // computeKnownBits so we don't make any assumptions about the
+        // relationships between the two indices. This is important if we're
+        // asking about values from different loop iterations. See PR32314.
+        // TODO: We may be able to change the check so we only do this when
+        // we definitely looked through a PHINode.
+        if (GEP1LastIdx != GEP2LastIdx &&
+            GEP1LastIdx->getType() == GEP2LastIdx->getType()) {
+          KnownBits Known1 = computeKnownBits(GEP1LastIdx, DL);
+          KnownBits Known2 = computeKnownBits(GEP2LastIdx, DL);
+          if (Known1.Zero.intersects(Known2.One) ||
+              Known1.One.intersects(Known2.Zero))
+            return NoAlias;
+        }
+      } else if (isKnownNonEqual(GEP1LastIdx, GEP2LastIdx, DL))
+        return NoAlias;
+    }
+    return MayAlias;
+  } else if (!LastIndexedStruct || !C1 || !C2) {
+    return MayAlias;
+  }
+
+  // We know that:
+  // - both GEPs begin indexing from the exact same pointer;
+  // - the last indices in both GEPs are constants, indexing into a struct;
+  // - said indices are different, hence, the pointed-to fields are different;
+  // - both GEPs only index through arrays prior to that.
+  //
+  // This lets us determine that the struct that GEP1 indexes into and the
+  // struct that GEP2 indexes into must either precisely overlap or be
+  // completely disjoint.  Because they cannot partially overlap, indexing into
+  // different non-overlapping fields of the struct will never alias.
+
+  // Therefore, the only remaining thing needed to show that both GEPs can't
+  // alias is that the fields are not overlapping.
+  const StructLayout *SL = DL.getStructLayout(LastIndexedStruct);
+  const uint64_t StructSize = SL->getSizeInBytes();
+  const uint64_t V1Off = SL->getElementOffset(C1->getZExtValue());
+  const uint64_t V2Off = SL->getElementOffset(C2->getZExtValue());
+
+  auto EltsDontOverlap = [StructSize](uint64_t V1Off, uint64_t V1Size,
+                                      uint64_t V2Off, uint64_t V2Size) {
+    return V1Off < V2Off && V1Off + V1Size <= V2Off &&
+           ((V2Off + V2Size <= StructSize) ||
+            (V2Off + V2Size - StructSize <= V1Off));
+  };
+
+  if (EltsDontOverlap(V1Off, V1Size, V2Off, V2Size) ||
+      EltsDontOverlap(V2Off, V2Size, V1Off, V1Size))
+    return NoAlias;
+
+  return MayAlias;
+}
+
+// If we have (a) a GEP and (b) a pointer based on an alloca, and the
+// beginning of the object the GEP points to would have a negative offset with
+// respect to the alloca, that means the GEP cannot alias pointer (b).
+// Note that the pointer based on the alloca may not be a GEP. For
+// example, it may be the alloca itself.
+// The same applies if (b) is based on a GlobalVariable. Note that just being
+// based on isIdentifiedObject() is not enough - we need an identified object
+// that does not permit access to negative offsets. For example, a negative
+// offset from a noalias argument or call can be inbounds w.r.t. the actual
+// underlying object.
+// +// For example, consider: +// +//   struct { int f0, int f1, ...} foo; +//   foo alloca; +//   foo* random = bar(alloca); +//   int *f0 = &alloca.f0 +//   int *f1 = &random->f1; +// +// Which is lowered, approximately, to: +// +//  %alloca = alloca %struct.foo +//  %random = call %struct.foo* @random(%struct.foo* %alloca) +//  %f0 = getelementptr inbounds %struct, %struct.foo* %alloca, i32 0, i32 0 +//  %f1 = getelementptr inbounds %struct, %struct.foo* %random, i32 0, i32 1 +// +// Assume %f1 and %f0 alias. Then %f1 would point into the object allocated +// by %alloca. Since the %f1 GEP is inbounds, that means %random must also +// point into the same object. But since %f0 points to the beginning of %alloca, +// the highest %f1 can be is (%alloca + 3). This means %random can not be higher +// than (%alloca - 1), and so is not inbounds, a contradiction. +bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp, +      const DecomposedGEP &DecompGEP, const DecomposedGEP &DecompObject, +      LocationSize ObjectAccessSize) { +  // If the object access size is unknown, or the GEP isn't inbounds, bail. +  if (ObjectAccessSize == MemoryLocation::UnknownSize || !GEPOp->isInBounds()) +    return false; + +  // We need the object to be an alloca or a globalvariable, and want to know +  // the offset of the pointer from the object precisely, so no variable +  // indices are allowed. +  if (!(isa<AllocaInst>(DecompObject.Base) || +        isa<GlobalVariable>(DecompObject.Base)) || +      !DecompObject.VarIndices.empty()) +    return false; + +  int64_t ObjectBaseOffset = DecompObject.StructOffset + +                             DecompObject.OtherOffset; + +  // If the GEP has no variable indices, we know the precise offset +  // from the base, then use it. If the GEP has variable indices, +  // we can't get exact GEP offset to identify pointer alias. So return +  // false in that case. +  if (!DecompGEP.VarIndices.empty()) +    return false; +  int64_t GEPBaseOffset = DecompGEP.StructOffset; +  GEPBaseOffset += DecompGEP.OtherOffset; + +  return (GEPBaseOffset >= ObjectBaseOffset + (int64_t)ObjectAccessSize); +} + +/// Provides a bunch of ad-hoc rules to disambiguate a GEP instruction against +/// another pointer. +/// +/// We know that V1 is a GEP, but we don't know anything about V2. +/// UnderlyingV1 is GetUnderlyingObject(GEP1, DL), UnderlyingV2 is the same for +/// V2. +AliasResult +BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, +                        const AAMDNodes &V1AAInfo, const Value *V2, +                        LocationSize V2Size, const AAMDNodes &V2AAInfo, +                        const Value *UnderlyingV1, const Value *UnderlyingV2) { +  DecomposedGEP DecompGEP1, DecompGEP2; +  bool GEP1MaxLookupReached = +    DecomposeGEPExpression(GEP1, DecompGEP1, DL, &AC, DT); +  bool GEP2MaxLookupReached = +    DecomposeGEPExpression(V2, DecompGEP2, DL, &AC, DT); + +  int64_t GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset; +  int64_t GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset; + +  assert(DecompGEP1.Base == UnderlyingV1 && DecompGEP2.Base == UnderlyingV2 && +         "DecomposeGEPExpression returned a result different from " +         "GetUnderlyingObject"); + +  // If the GEP's offset relative to its base is such that the base would +  // fall below the start of the object underlying V2, then the GEP and V2 +  // cannot alias. 
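+  // Illustrative numbers (a sketch, assuming byte-sized elements): if GEP1 is
+  //   %g = getelementptr inbounds i8, i8* %p, i64 4
+  // its offset from %p is 4; if V2 is a 4-byte access at offset 0 of an
+  // alloca, then 4 >= 0 + 4, so an inbounds %p would have to start before the
+  // alloca, and the two cannot alias.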
+  if (!GEP1MaxLookupReached && !GEP2MaxLookupReached && +      isGEPBaseAtNegativeOffset(GEP1, DecompGEP1, DecompGEP2, V2Size)) +    return NoAlias; +  // If we have two gep instructions with must-alias or not-alias'ing base +  // pointers, figure out if the indexes to the GEP tell us anything about the +  // derived pointer. +  if (const GEPOperator *GEP2 = dyn_cast<GEPOperator>(V2)) { +    // Check for the GEP base being at a negative offset, this time in the other +    // direction. +    if (!GEP1MaxLookupReached && !GEP2MaxLookupReached && +        isGEPBaseAtNegativeOffset(GEP2, DecompGEP2, DecompGEP1, V1Size)) +      return NoAlias; +    // Do the base pointers alias? +    AliasResult BaseAlias = +        aliasCheck(UnderlyingV1, MemoryLocation::UnknownSize, AAMDNodes(), +                   UnderlyingV2, MemoryLocation::UnknownSize, AAMDNodes()); + +    // Check for geps of non-aliasing underlying pointers where the offsets are +    // identical. +    if ((BaseAlias == MayAlias) && V1Size == V2Size) { +      // Do the base pointers alias assuming type and size. +      AliasResult PreciseBaseAlias = aliasCheck(UnderlyingV1, V1Size, V1AAInfo, +                                                UnderlyingV2, V2Size, V2AAInfo); +      if (PreciseBaseAlias == NoAlias) { +        // See if the computed offset from the common pointer tells us about the +        // relation of the resulting pointer. +        // If the max search depth is reached the result is undefined +        if (GEP2MaxLookupReached || GEP1MaxLookupReached) +          return MayAlias; + +        // Same offsets. +        if (GEP1BaseOffset == GEP2BaseOffset && +            DecompGEP1.VarIndices == DecompGEP2.VarIndices) +          return NoAlias; +      } +    } + +    // If we get a No or May, then return it immediately, no amount of analysis +    // will improve this situation. +    if (BaseAlias != MustAlias) { +      assert(BaseAlias == NoAlias || BaseAlias == MayAlias); +      return BaseAlias; +    } + +    // Otherwise, we have a MustAlias.  Since the base pointers alias each other +    // exactly, see if the computed offset from the common pointer tells us +    // about the relation of the resulting pointer. +    // If we know the two GEPs are based off of the exact same pointer (and not +    // just the same underlying object), see if that tells us anything about +    // the resulting pointers. +    if (GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() == +            GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() && +        GEP1->getPointerOperandType() == GEP2->getPointerOperandType()) { +      AliasResult R = aliasSameBasePointerGEPs(GEP1, V1Size, GEP2, V2Size, DL); +      // If we couldn't find anything interesting, don't abandon just yet. +      if (R != MayAlias) +        return R; +    } + +    // If the max search depth is reached, the result is undefined +    if (GEP2MaxLookupReached || GEP1MaxLookupReached) +      return MayAlias; + +    // Subtract the GEP2 pointer from the GEP1 pointer to find out their +    // symbolic difference. +    GEP1BaseOffset -= GEP2BaseOffset; +    GetIndexDifference(DecompGEP1.VarIndices, DecompGEP2.VarIndices); + +  } else { +    // Check to see if these two pointers are related by the getelementptr +    // instruction.  If one pointer is a GEP with a non-zero index of the other +    // pointer, we know they cannot alias. + +    // If both accesses are unknown size, we can't do anything useful here. 
+    if (V1Size == MemoryLocation::UnknownSize && +        V2Size == MemoryLocation::UnknownSize) +      return MayAlias; + +    AliasResult R = aliasCheck(UnderlyingV1, MemoryLocation::UnknownSize, +                               AAMDNodes(), V2, MemoryLocation::UnknownSize, +                               V2AAInfo, nullptr, UnderlyingV2); +    if (R != MustAlias) { +      // If V2 may alias GEP base pointer, conservatively returns MayAlias. +      // If V2 is known not to alias GEP base pointer, then the two values +      // cannot alias per GEP semantics: "Any memory access must be done through +      // a pointer value associated with an address range of the memory access, +      // otherwise the behavior is undefined.". +      assert(R == NoAlias || R == MayAlias); +      return R; +    } + +    // If the max search depth is reached the result is undefined +    if (GEP1MaxLookupReached) +      return MayAlias; +  } + +  // In the two GEP Case, if there is no difference in the offsets of the +  // computed pointers, the resultant pointers are a must alias.  This +  // happens when we have two lexically identical GEP's (for example). +  // +  // In the other case, if we have getelementptr <ptr>, 0, 0, 0, 0, ... and V2 +  // must aliases the GEP, the end result is a must alias also. +  if (GEP1BaseOffset == 0 && DecompGEP1.VarIndices.empty()) +    return MustAlias; + +  // If there is a constant difference between the pointers, but the difference +  // is less than the size of the associated memory object, then we know +  // that the objects are partially overlapping.  If the difference is +  // greater, we know they do not overlap. +  if (GEP1BaseOffset != 0 && DecompGEP1.VarIndices.empty()) { +    if (GEP1BaseOffset >= 0) { +      if (V2Size != MemoryLocation::UnknownSize) { +        if ((uint64_t)GEP1BaseOffset < V2Size) +          return PartialAlias; +        return NoAlias; +      } +    } else { +      // We have the situation where: +      // +                + +      // | BaseOffset     | +      // ---------------->| +      // |-->V1Size       |-------> V2Size +      // GEP1             V2 +      // We need to know that V2Size is not unknown, otherwise we might have +      // stripped a gep with negative index ('gep <ptr>, -1, ...). +      if (V1Size != MemoryLocation::UnknownSize && +          V2Size != MemoryLocation::UnknownSize) { +        if (-(uint64_t)GEP1BaseOffset < V1Size) +          return PartialAlias; +        return NoAlias; +      } +    } +  } + +  if (!DecompGEP1.VarIndices.empty()) { +    uint64_t Modulo = 0; +    bool AllPositive = true; +    for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) { + +      // Try to distinguish something like &A[i][1] against &A[42][0]. +      // Grab the least significant bit set in any of the scales. We +      // don't need std::abs here (even if the scale's negative) as we'll +      // be ^'ing Modulo with itself later. +      Modulo |= (uint64_t)DecompGEP1.VarIndices[i].Scale; + +      if (AllPositive) { +        // If the Value could change between cycles, then any reasoning about +        // the Value this cycle may not hold in the next cycle. 
We'll just +        // give up if we can't determine conditions that hold for every cycle: +        const Value *V = DecompGEP1.VarIndices[i].V; + +        KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, DT); +        bool SignKnownZero = Known.isNonNegative(); +        bool SignKnownOne = Known.isNegative(); + +        // Zero-extension widens the variable, and so forces the sign +        // bit to zero. +        bool IsZExt = DecompGEP1.VarIndices[i].ZExtBits > 0 || isa<ZExtInst>(V); +        SignKnownZero |= IsZExt; +        SignKnownOne &= !IsZExt; + +        // If the variable begins with a zero then we know it's +        // positive, regardless of whether the value is signed or +        // unsigned. +        int64_t Scale = DecompGEP1.VarIndices[i].Scale; +        AllPositive = +            (SignKnownZero && Scale >= 0) || (SignKnownOne && Scale < 0); +      } +    } + +    Modulo = Modulo ^ (Modulo & (Modulo - 1)); + +    // We can compute the difference between the two addresses +    // mod Modulo. Check whether that difference guarantees that the +    // two locations do not alias. +    uint64_t ModOffset = (uint64_t)GEP1BaseOffset & (Modulo - 1); +    if (V1Size != MemoryLocation::UnknownSize && +        V2Size != MemoryLocation::UnknownSize && ModOffset >= V2Size && +        V1Size <= Modulo - ModOffset) +      return NoAlias; + +    // If we know all the variables are positive, then GEP1 >= GEP1BasePtr. +    // If GEP1BasePtr > V2 (GEP1BaseOffset > 0) then we know the pointers +    // don't alias if V2Size can fit in the gap between V2 and GEP1BasePtr. +    if (AllPositive && GEP1BaseOffset > 0 && V2Size <= (uint64_t)GEP1BaseOffset) +      return NoAlias; + +    if (constantOffsetHeuristic(DecompGEP1.VarIndices, V1Size, V2Size, +                                GEP1BaseOffset, &AC, DT)) +      return NoAlias; +  } + +  // Statically, we can see that the base objects are the same, but the +  // pointers have dynamic offsets which we can't resolve. And none of our +  // little tricks above worked. +  return MayAlias; +} + +static AliasResult MergeAliasResults(AliasResult A, AliasResult B) { +  // If the results agree, take it. +  if (A == B) +    return A; +  // A mix of PartialAlias and MustAlias is PartialAlias. +  if ((A == PartialAlias && B == MustAlias) || +      (B == PartialAlias && A == MustAlias)) +    return PartialAlias; +  // Otherwise, we don't know anything. +  return MayAlias; +} + +/// Provides a bunch of ad-hoc rules to disambiguate a Select instruction +/// against another. +AliasResult BasicAAResult::aliasSelect(const SelectInst *SI, +                                       LocationSize SISize, +                                       const AAMDNodes &SIAAInfo, +                                       const Value *V2, LocationSize V2Size, +                                       const AAMDNodes &V2AAInfo, +                                       const Value *UnderV2) { +  // If the values are Selects with the same condition, we can do a more precise +  // check: just check for aliases between the values on corresponding arms. 
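+  // For example (illustrative):
+  //   %a = select i1 %c, i32* %p, i32* %q
+  //   %b = select i1 %c, i32* %x, i32* %y
+  // Only %p vs. %x and %q vs. %y need to be compared, since the mixed pairs
+  // can never be chosen together.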
+  if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2)) +    if (SI->getCondition() == SI2->getCondition()) { +      AliasResult Alias = aliasCheck(SI->getTrueValue(), SISize, SIAAInfo, +                                     SI2->getTrueValue(), V2Size, V2AAInfo); +      if (Alias == MayAlias) +        return MayAlias; +      AliasResult ThisAlias = +          aliasCheck(SI->getFalseValue(), SISize, SIAAInfo, +                     SI2->getFalseValue(), V2Size, V2AAInfo); +      return MergeAliasResults(ThisAlias, Alias); +    } + +  // If both arms of the Select node NoAlias or MustAlias V2, then returns +  // NoAlias / MustAlias. Otherwise, returns MayAlias. +  AliasResult Alias = +      aliasCheck(V2, V2Size, V2AAInfo, SI->getTrueValue(), +                 SISize, SIAAInfo, UnderV2); +  if (Alias == MayAlias) +    return MayAlias; + +  AliasResult ThisAlias = +      aliasCheck(V2, V2Size, V2AAInfo, SI->getFalseValue(), SISize, SIAAInfo, +                 UnderV2); +  return MergeAliasResults(ThisAlias, Alias); +} + +/// Provide a bunch of ad-hoc rules to disambiguate a PHI instruction against +/// another. +AliasResult BasicAAResult::aliasPHI(const PHINode *PN, LocationSize PNSize, +                                    const AAMDNodes &PNAAInfo, const Value *V2, +                                    LocationSize V2Size, +                                    const AAMDNodes &V2AAInfo, +                                    const Value *UnderV2) { +  // Track phi nodes we have visited. We use this information when we determine +  // value equivalence. +  VisitedPhiBBs.insert(PN->getParent()); + +  // If the values are PHIs in the same block, we can do a more precise +  // as well as efficient check: just check for aliases between the values +  // on corresponding edges. +  if (const PHINode *PN2 = dyn_cast<PHINode>(V2)) +    if (PN2->getParent() == PN->getParent()) { +      LocPair Locs(MemoryLocation(PN, PNSize, PNAAInfo), +                   MemoryLocation(V2, V2Size, V2AAInfo)); +      if (PN > V2) +        std::swap(Locs.first, Locs.second); +      // Analyse the PHIs' inputs under the assumption that the PHIs are +      // NoAlias. +      // If the PHIs are May/MustAlias there must be (recursively) an input +      // operand from outside the PHIs' cycle that is MayAlias/MustAlias or +      // there must be an operation on the PHIs within the PHIs' value cycle +      // that causes a MayAlias. +      // Pretend the phis do not alias. +      AliasResult Alias = NoAlias; +      assert(AliasCache.count(Locs) && +             "There must exist an entry for the phi node"); +      AliasResult OrigAliasResult = AliasCache[Locs]; +      AliasCache[Locs] = NoAlias; + +      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { +        AliasResult ThisAlias = +            aliasCheck(PN->getIncomingValue(i), PNSize, PNAAInfo, +                       PN2->getIncomingValueForBlock(PN->getIncomingBlock(i)), +                       V2Size, V2AAInfo); +        Alias = MergeAliasResults(ThisAlias, Alias); +        if (Alias == MayAlias) +          break; +      } + +      // Reset if speculation failed. +      if (Alias != NoAlias) +        AliasCache[Locs] = OrigAliasResult; + +      return Alias; +    } + +  SmallVector<Value *, 4> V1Srcs; +  bool isRecursive = false; +  if (PV)  { +    // If we have PhiValues then use it to get the underlying phi values. 
+    const PhiValues::ValueSet &PhiValueSet = PV->getValuesForPhi(PN); +    // If we have more phi values than the search depth then return MayAlias +    // conservatively to avoid compile time explosion. The worst possible case +    // is if both sides are PHI nodes. In which case, this is O(m x n) time +    // where 'm' and 'n' are the number of PHI sources. +    if (PhiValueSet.size() > MaxLookupSearchDepth) +      return MayAlias; +    // Add the values to V1Srcs +    for (Value *PV1 : PhiValueSet) { +      if (EnableRecPhiAnalysis) { +        if (GEPOperator *PV1GEP = dyn_cast<GEPOperator>(PV1)) { +          // Check whether the incoming value is a GEP that advances the pointer +          // result of this PHI node (e.g. in a loop). If this is the case, we +          // would recurse and always get a MayAlias. Handle this case specially +          // below. +          if (PV1GEP->getPointerOperand() == PN && PV1GEP->getNumIndices() == 1 && +              isa<ConstantInt>(PV1GEP->idx_begin())) { +            isRecursive = true; +            continue; +          } +        } +      } +      V1Srcs.push_back(PV1); +    } +  } else { +    // If we don't have PhiInfo then just look at the operands of the phi itself +    // FIXME: Remove this once we can guarantee that we have PhiInfo always +    SmallPtrSet<Value *, 4> UniqueSrc; +    for (Value *PV1 : PN->incoming_values()) { +      if (isa<PHINode>(PV1)) +        // If any of the source itself is a PHI, return MayAlias conservatively +        // to avoid compile time explosion. The worst possible case is if both +        // sides are PHI nodes. In which case, this is O(m x n) time where 'm' +        // and 'n' are the number of PHI sources. +        return MayAlias; + +      if (EnableRecPhiAnalysis) +        if (GEPOperator *PV1GEP = dyn_cast<GEPOperator>(PV1)) { +          // Check whether the incoming value is a GEP that advances the pointer +          // result of this PHI node (e.g. in a loop). If this is the case, we +          // would recurse and always get a MayAlias. Handle this case specially +          // below. +          if (PV1GEP->getPointerOperand() == PN && PV1GEP->getNumIndices() == 1 && +              isa<ConstantInt>(PV1GEP->idx_begin())) { +            isRecursive = true; +            continue; +          } +        } + +      if (UniqueSrc.insert(PV1).second) +        V1Srcs.push_back(PV1); +    } +  } + +  // If V1Srcs is empty then that means that the phi has no underlying non-phi +  // value. This should only be possible in blocks unreachable from the entry +  // block, but return MayAlias just in case. +  if (V1Srcs.empty()) +    return MayAlias; + +  // If this PHI node is recursive, set the size of the accessed memory to +  // unknown to represent all the possible values the GEP could advance the +  // pointer to. +  if (isRecursive) +    PNSize = MemoryLocation::UnknownSize; + +  AliasResult Alias = +      aliasCheck(V2, V2Size, V2AAInfo, V1Srcs[0], +                 PNSize, PNAAInfo, UnderV2); + +  // Early exit if the check of the first PHI source against V2 is MayAlias. +  // Other results are not possible. +  if (Alias == MayAlias) +    return MayAlias; + +  // If all sources of the PHI node NoAlias or MustAlias V2, then returns +  // NoAlias / MustAlias. Otherwise, returns MayAlias. 
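+  // For example (illustrative): per-source results {NoAlias, NoAlias} merge to
+  // NoAlias, {MustAlias, PartialAlias} merge to PartialAlias, and
+  // {NoAlias, MustAlias} merge to MayAlias, at which point the loop below
+  // exits early.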
+  for (unsigned i = 1, e = V1Srcs.size(); i != e; ++i) { +    Value *V = V1Srcs[i]; + +    AliasResult ThisAlias = +        aliasCheck(V2, V2Size, V2AAInfo, V, PNSize, PNAAInfo, UnderV2); +    Alias = MergeAliasResults(ThisAlias, Alias); +    if (Alias == MayAlias) +      break; +  } + +  return Alias; +} + +/// Provides a bunch of ad-hoc rules to disambiguate in common cases, such as +/// array references. +AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size, +                                      AAMDNodes V1AAInfo, const Value *V2, +                                      LocationSize V2Size, AAMDNodes V2AAInfo, +                                      const Value *O1, const Value *O2) { +  // If either of the memory references is empty, it doesn't matter what the +  // pointer values are. +  if (V1Size == 0 || V2Size == 0) +    return NoAlias; + +  // Strip off any casts if they exist. +  V1 = V1->stripPointerCastsAndInvariantGroups(); +  V2 = V2->stripPointerCastsAndInvariantGroups(); + +  // If V1 or V2 is undef, the result is NoAlias because we can always pick a +  // value for undef that aliases nothing in the program. +  if (isa<UndefValue>(V1) || isa<UndefValue>(V2)) +    return NoAlias; + +  // Are we checking for alias of the same value? +  // Because we look 'through' phi nodes, we could look at "Value" pointers from +  // different iterations. We must therefore make sure that this is not the +  // case. The function isValueEqualInPotentialCycles ensures that this cannot +  // happen by looking at the visited phi nodes and making sure they cannot +  // reach the value. +  if (isValueEqualInPotentialCycles(V1, V2)) +    return MustAlias; + +  if (!V1->getType()->isPointerTy() || !V2->getType()->isPointerTy()) +    return NoAlias; // Scalars cannot alias each other + +  // Figure out what objects these things are pointing to if we can. +  if (O1 == nullptr) +    O1 = GetUnderlyingObject(V1, DL, MaxLookupSearchDepth); + +  if (O2 == nullptr) +    O2 = GetUnderlyingObject(V2, DL, MaxLookupSearchDepth); + +  // Null values in the default address space don't point to any object, so they +  // don't alias any other pointer. +  if (const ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(O1)) +    if (!NullPointerIsDefined(&F, CPN->getType()->getAddressSpace())) +      return NoAlias; +  if (const ConstantPointerNull *CPN = dyn_cast<ConstantPointerNull>(O2)) +    if (!NullPointerIsDefined(&F, CPN->getType()->getAddressSpace())) +      return NoAlias; + +  if (O1 != O2) { +    // If V1/V2 point to two different objects, we know that we have no alias. +    if (isIdentifiedObject(O1) && isIdentifiedObject(O2)) +      return NoAlias; + +    // Constant pointers can't alias with non-const isIdentifiedObject objects. +    if ((isa<Constant>(O1) && isIdentifiedObject(O2) && !isa<Constant>(O2)) || +        (isa<Constant>(O2) && isIdentifiedObject(O1) && !isa<Constant>(O1))) +      return NoAlias; + +    // Function arguments can't alias with things that are known to be +    // unambigously identified at the function level. +    if ((isa<Argument>(O1) && isIdentifiedFunctionLocal(O2)) || +        (isa<Argument>(O2) && isIdentifiedFunctionLocal(O1))) +      return NoAlias; + +    // If one pointer is the result of a call/invoke or load and the other is a +    // non-escaping local object within the same function, then we know the +    // object couldn't escape to a point where the call could return it. 
+    // +    // Note that if the pointers are in different functions, there are a +    // variety of complications. A call with a nocapture argument may still +    // temporary store the nocapture argument's value in a temporary memory +    // location if that memory location doesn't escape. Or it may pass a +    // nocapture value to other functions as long as they don't capture it. +    if (isEscapeSource(O1) && isNonEscapingLocalObject(O2)) +      return NoAlias; +    if (isEscapeSource(O2) && isNonEscapingLocalObject(O1)) +      return NoAlias; +  } + +  // If the size of one access is larger than the entire object on the other +  // side, then we know such behavior is undefined and can assume no alias. +  bool NullIsValidLocation = NullPointerIsDefined(&F); +  if ((V1Size != MemoryLocation::UnknownSize && +       isObjectSmallerThan(O2, V1Size, DL, TLI, NullIsValidLocation)) || +      (V2Size != MemoryLocation::UnknownSize && +       isObjectSmallerThan(O1, V2Size, DL, TLI, NullIsValidLocation))) +    return NoAlias; + +  // Check the cache before climbing up use-def chains. This also terminates +  // otherwise infinitely recursive queries. +  LocPair Locs(MemoryLocation(V1, V1Size, V1AAInfo), +               MemoryLocation(V2, V2Size, V2AAInfo)); +  if (V1 > V2) +    std::swap(Locs.first, Locs.second); +  std::pair<AliasCacheTy::iterator, bool> Pair = +      AliasCache.insert(std::make_pair(Locs, MayAlias)); +  if (!Pair.second) +    return Pair.first->second; + +  // FIXME: This isn't aggressively handling alias(GEP, PHI) for example: if the +  // GEP can't simplify, we don't even look at the PHI cases. +  if (!isa<GEPOperator>(V1) && isa<GEPOperator>(V2)) { +    std::swap(V1, V2); +    std::swap(V1Size, V2Size); +    std::swap(O1, O2); +    std::swap(V1AAInfo, V2AAInfo); +  } +  if (const GEPOperator *GV1 = dyn_cast<GEPOperator>(V1)) { +    AliasResult Result = +        aliasGEP(GV1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O1, O2); +    if (Result != MayAlias) +      return AliasCache[Locs] = Result; +  } + +  if (isa<PHINode>(V2) && !isa<PHINode>(V1)) { +    std::swap(V1, V2); +    std::swap(O1, O2); +    std::swap(V1Size, V2Size); +    std::swap(V1AAInfo, V2AAInfo); +  } +  if (const PHINode *PN = dyn_cast<PHINode>(V1)) { +    AliasResult Result = aliasPHI(PN, V1Size, V1AAInfo, +                                  V2, V2Size, V2AAInfo, O2); +    if (Result != MayAlias) +      return AliasCache[Locs] = Result; +  } + +  if (isa<SelectInst>(V2) && !isa<SelectInst>(V1)) { +    std::swap(V1, V2); +    std::swap(O1, O2); +    std::swap(V1Size, V2Size); +    std::swap(V1AAInfo, V2AAInfo); +  } +  if (const SelectInst *S1 = dyn_cast<SelectInst>(V1)) { +    AliasResult Result = +        aliasSelect(S1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O2); +    if (Result != MayAlias) +      return AliasCache[Locs] = Result; +  } + +  // If both pointers are pointing into the same object and one of them +  // accesses the entire object, then the accesses must overlap in some way. +  if (O1 == O2) +    if (V1Size != MemoryLocation::UnknownSize && +        V2Size != MemoryLocation::UnknownSize && +        (isObjectSize(O1, V1Size, DL, TLI, NullIsValidLocation) || +         isObjectSize(O2, V2Size, DL, TLI, NullIsValidLocation))) +      return AliasCache[Locs] = PartialAlias; + +  // Recurse back into the best AA results we have, potentially with refined +  // memory locations. 
We have already ensured that BasicAA has a MayAlias
+  // cache result for these, so any recursion back into BasicAA won't loop.
+  AliasResult Result = getBestAAResults().alias(Locs.first, Locs.second);
+  return AliasCache[Locs] = Result;
+}
+
+/// Check whether two Values can be considered equivalent.
+///
+/// In addition to pointer equivalence of \p V1 and \p V2 this checks whether
+/// they cannot be part of a cycle in the value graph by looking at all
+/// visited phi nodes and making sure that the phis cannot reach the value. We
+/// have to do this because we are looking through phi nodes (that is, we say
+/// noalias(V, phi(VA, VB)) if noalias(V, VA) and noalias(V, VB)).
+bool BasicAAResult::isValueEqualInPotentialCycles(const Value *V,
+                                                  const Value *V2) {
+  if (V != V2)
+    return false;
+
+  const Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return true;
+
+  if (VisitedPhiBBs.empty())
+    return true;
+
+  if (VisitedPhiBBs.size() > MaxNumPhiBBsValueReachabilityCheck)
+    return false;
+
+  // Make sure that the visited phis cannot reach the Value. This ensures that
+  // the Values cannot come from different iterations of a potential cycle the
+  // phi nodes could be involved in.
+  for (auto *P : VisitedPhiBBs)
+    if (isPotentiallyReachable(&P->front(), Inst, DT, LI))
+      return false;
+
+  return true;
+}
+
+/// Computes the symbolic difference between two decomposed GEPs.
+///
+/// Dest and Src are the variable indices from two decomposed GetElementPtr
+/// instructions GEP1 and GEP2 which have common base pointers.
+void BasicAAResult::GetIndexDifference(
+    SmallVectorImpl<VariableGEPIndex> &Dest,
+    const SmallVectorImpl<VariableGEPIndex> &Src) {
+  if (Src.empty())
+    return;
+
+  for (unsigned i = 0, e = Src.size(); i != e; ++i) {
+    const Value *V = Src[i].V;
+    unsigned ZExtBits = Src[i].ZExtBits, SExtBits = Src[i].SExtBits;
+    int64_t Scale = Src[i].Scale;
+
+    // Find V in Dest.  This is N^2, but pointer indices almost never have more
+    // than a few variable indexes.
+    for (unsigned j = 0, e = Dest.size(); j != e; ++j) {
+      if (!isValueEqualInPotentialCycles(Dest[j].V, V) ||
+          Dest[j].ZExtBits != ZExtBits || Dest[j].SExtBits != SExtBits)
+        continue;
+
+      // If we found it, subtract off Scale V's from the entry in Dest.  If it
+      // goes to zero, remove the entry.
+      if (Dest[j].Scale != Scale)
+        Dest[j].Scale -= Scale;
+      else
+        Dest.erase(Dest.begin() + j);
+      Scale = 0;
+      break;
+    }
+
+    // If we didn't consume this entry, add it to the end of the Dest list.
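    // For instance (a simplified sketch that ignores the ZExt/SExt bits and
    // the cycle check): subtracting Src = { 4*%x } from Dest = { 4*%x, 2*%y }
    // cancels the %x entry and leaves Dest = { 2*%y }, while subtracting the
    // same Src from Dest = { 2*%y } finds no match, so { %x, scale -4 } is
    // appended below.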
+    if (Scale) { +      VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale}; +      Dest.push_back(Entry); +    } +  } +} + +bool BasicAAResult::constantOffsetHeuristic( +    const SmallVectorImpl<VariableGEPIndex> &VarIndices, LocationSize V1Size, +    LocationSize V2Size, int64_t BaseOffset, AssumptionCache *AC, +    DominatorTree *DT) { +  if (VarIndices.size() != 2 || V1Size == MemoryLocation::UnknownSize || +      V2Size == MemoryLocation::UnknownSize) +    return false; + +  const VariableGEPIndex &Var0 = VarIndices[0], &Var1 = VarIndices[1]; + +  if (Var0.ZExtBits != Var1.ZExtBits || Var0.SExtBits != Var1.SExtBits || +      Var0.Scale != -Var1.Scale) +    return false; + +  unsigned Width = Var1.V->getType()->getIntegerBitWidth(); + +  // We'll strip off the Extensions of Var0 and Var1 and do another round +  // of GetLinearExpression decomposition. In the example above, if Var0 +  // is zext(%x + 1) we should get V1 == %x and V1Offset == 1. + +  APInt V0Scale(Width, 0), V0Offset(Width, 0), V1Scale(Width, 0), +      V1Offset(Width, 0); +  bool NSW = true, NUW = true; +  unsigned V0ZExtBits = 0, V0SExtBits = 0, V1ZExtBits = 0, V1SExtBits = 0; +  const Value *V0 = GetLinearExpression(Var0.V, V0Scale, V0Offset, V0ZExtBits, +                                        V0SExtBits, DL, 0, AC, DT, NSW, NUW); +  NSW = true; +  NUW = true; +  const Value *V1 = GetLinearExpression(Var1.V, V1Scale, V1Offset, V1ZExtBits, +                                        V1SExtBits, DL, 0, AC, DT, NSW, NUW); + +  if (V0Scale != V1Scale || V0ZExtBits != V1ZExtBits || +      V0SExtBits != V1SExtBits || !isValueEqualInPotentialCycles(V0, V1)) +    return false; + +  // We have a hit - Var0 and Var1 only differ by a constant offset! + +  // If we've been sext'ed then zext'd the maximum difference between Var0 and +  // Var1 is possible to calculate, but we're just interested in the absolute +  // minimum difference between the two. The minimum distance may occur due to +  // wrapping; consider "add i3 %i, 5": if %i == 7 then 7 + 5 mod 8 == 4, and so +  // the minimum distance between %i and %i + 5 is 3. +  APInt MinDiff = V0Offset - V1Offset, Wrapped = -MinDiff; +  MinDiff = APIntOps::umin(MinDiff, Wrapped); +  uint64_t MinDiffBytes = MinDiff.getZExtValue() * std::abs(Var0.Scale); + +  // We can't definitely say whether GEP1 is before or after V2 due to wrapping +  // arithmetic (i.e. for some values of GEP1 and V2 GEP1 < V2, and for other +  // values GEP1 > V2). We'll therefore only declare NoAlias if both V1Size and +  // V2Size can fit in the MinDiffBytes gap. 
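A self-contained sketch of the wrapped minimum-distance computation described above, with masked 64-bit arithmetic standing in for APInt (an assumption made purely for illustration):

#include <cstdint>

// Minimum distance between A and B in Bits-bit modular arithmetic.
static uint64_t minWrappedDistance(uint64_t A, uint64_t B, unsigned Bits) {
  const uint64_t Mask = Bits >= 64 ? ~UINT64_C(0) : (UINT64_C(1) << Bits) - 1;
  uint64_t Fwd = (A - B) & Mask;  // steps from B up to A, with wrap-around
  uint64_t Back = (B - A) & Mask; // steps from A up to B, with wrap-around
  return Fwd < Back ? Fwd : Back;
}

// With 3-bit values, minWrappedDistance(7 + 5, 7, 3) == 3, matching the
// "add i3 %i, 5" example in the comment above.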
+  return V1Size + std::abs(BaseOffset) <= MinDiffBytes && +         V2Size + std::abs(BaseOffset) <= MinDiffBytes; +} + +//===----------------------------------------------------------------------===// +// BasicAliasAnalysis Pass +//===----------------------------------------------------------------------===// + +AnalysisKey BasicAA::Key; + +BasicAAResult BasicAA::run(Function &F, FunctionAnalysisManager &AM) { +  return BasicAAResult(F.getParent()->getDataLayout(), +                       F, +                       AM.getResult<TargetLibraryAnalysis>(F), +                       AM.getResult<AssumptionAnalysis>(F), +                       &AM.getResult<DominatorTreeAnalysis>(F), +                       AM.getCachedResult<LoopAnalysis>(F), +                       AM.getCachedResult<PhiValuesAnalysis>(F)); +} + +BasicAAWrapperPass::BasicAAWrapperPass() : FunctionPass(ID) { +    initializeBasicAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +char BasicAAWrapperPass::ID = 0; + +void BasicAAWrapperPass::anchor() {} + +INITIALIZE_PASS_BEGIN(BasicAAWrapperPass, "basicaa", +                      "Basic Alias Analysis (stateless AA impl)", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(BasicAAWrapperPass, "basicaa", +                    "Basic Alias Analysis (stateless AA impl)", false, true) + +FunctionPass *llvm::createBasicAAWrapperPass() { +  return new BasicAAWrapperPass(); +} + +bool BasicAAWrapperPass::runOnFunction(Function &F) { +  auto &ACT = getAnalysis<AssumptionCacheTracker>(); +  auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>(); +  auto &DTWP = getAnalysis<DominatorTreeWrapperPass>(); +  auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); +  auto *PVWP = getAnalysisIfAvailable<PhiValuesWrapperPass>(); + +  Result.reset(new BasicAAResult(F.getParent()->getDataLayout(), F, TLIWP.getTLI(), +                                 ACT.getAssumptionCache(F), &DTWP.getDomTree(), +                                 LIWP ? &LIWP->getLoopInfo() : nullptr, +                                 PVWP ? &PVWP->getResult() : nullptr)); + +  return false; +} + +void BasicAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<AssumptionCacheTracker>(); +  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +  AU.addUsedIfAvailable<PhiValuesWrapperPass>(); +} + +BasicAAResult llvm::createLegacyPMBasicAAResult(Pass &P, Function &F) { +  return BasicAAResult( +      F.getParent()->getDataLayout(), +      F, +      P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), +      P.getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F)); +} diff --git a/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp b/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp new file mode 100644 index 000000000000..41c295895213 --- /dev/null +++ b/contrib/llvm/lib/Analysis/BlockFrequencyInfo.cpp @@ -0,0 +1,342 @@ +//===- BlockFrequencyInfo.cpp - Block Frequency Analysis ------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Loops should be simplified before this analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/iterator.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <string> + +using namespace llvm; + +#define DEBUG_TYPE "block-freq" + +static cl::opt<GVDAGType> ViewBlockFreqPropagationDAG( +    "view-block-freq-propagation-dags", cl::Hidden, +    cl::desc("Pop up a window to show a dag displaying how block " +             "frequencies propagation through the CFG."), +    cl::values(clEnumValN(GVDT_None, "none", "do not display graphs."), +               clEnumValN(GVDT_Fraction, "fraction", +                          "display a graph using the " +                          "fractional block frequency representation."), +               clEnumValN(GVDT_Integer, "integer", +                          "display a graph using the raw " +                          "integer fractional block frequency representation."), +               clEnumValN(GVDT_Count, "count", "display a graph using the real " +                                               "profile count if available."))); + +cl::opt<std::string> +    ViewBlockFreqFuncName("view-bfi-func-name", cl::Hidden, +                          cl::desc("The option to specify " +                                   "the name of the function " +                                   "whose CFG will be displayed.")); + +cl::opt<unsigned> +    ViewHotFreqPercent("view-hot-freq-percent", cl::init(10), cl::Hidden, +                       cl::desc("An integer in percent used to specify " +                                "the hot blocks/edges to be displayed " +                                "in red: a block or edge whose frequency " +                                "is no less than the max frequency of the " +                                "function multiplied by this percent.")); + +// Command line option to turn on CFG dot or text dump after profile annotation. +cl::opt<PGOViewCountsType> PGOViewCounts( +    "pgo-view-counts", cl::Hidden, +    cl::desc("A boolean option to show CFG dag or text with " +             "block profile counts and branch probabilities " +             "right after PGO profile annotation step. The " +             "profile counts are computed using branch " +             "probabilities from the runtime profile data and " +             "block frequency propagation algorithm. To view " +             "the raw counts from the profile, use option " +             "-pgo-view-raw-counts instead. 
To limit graph " +             "display to only one function, use filtering option " +             "-view-bfi-func-name."), +    cl::values(clEnumValN(PGOVCT_None, "none", "do not show."), +               clEnumValN(PGOVCT_Graph, "graph", "show a graph."), +               clEnumValN(PGOVCT_Text, "text", "show in text."))); + +static cl::opt<bool> PrintBlockFreq( +    "print-bfi", cl::init(false), cl::Hidden, +    cl::desc("Print the block frequency info.")); + +cl::opt<std::string> PrintBlockFreqFuncName( +    "print-bfi-func-name", cl::Hidden, +    cl::desc("The option to specify the name of the function " +             "whose block frequency info is printed.")); + +namespace llvm { + +static GVDAGType getGVDT() { +  if (PGOViewCounts == PGOVCT_Graph) +    return GVDT_Count; +  return ViewBlockFreqPropagationDAG; +} + +template <> +struct GraphTraits<BlockFrequencyInfo *> { +  using NodeRef = const BasicBlock *; +  using ChildIteratorType = succ_const_iterator; +  using nodes_iterator = pointer_iterator<Function::const_iterator>; + +  static NodeRef getEntryNode(const BlockFrequencyInfo *G) { +    return &G->getFunction()->front(); +  } + +  static ChildIteratorType child_begin(const NodeRef N) { +    return succ_begin(N); +  } + +  static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); } + +  static nodes_iterator nodes_begin(const BlockFrequencyInfo *G) { +    return nodes_iterator(G->getFunction()->begin()); +  } + +  static nodes_iterator nodes_end(const BlockFrequencyInfo *G) { +    return nodes_iterator(G->getFunction()->end()); +  } +}; + +using BFIDOTGTraitsBase = +    BFIDOTGraphTraitsBase<BlockFrequencyInfo, BranchProbabilityInfo>; + +template <> +struct DOTGraphTraits<BlockFrequencyInfo *> : public BFIDOTGTraitsBase { +  explicit DOTGraphTraits(bool isSimple = false) +      : BFIDOTGTraitsBase(isSimple) {} + +  std::string getNodeLabel(const BasicBlock *Node, +                           const BlockFrequencyInfo *Graph) { + +    return BFIDOTGTraitsBase::getNodeLabel(Node, Graph, getGVDT()); +  } + +  std::string getNodeAttributes(const BasicBlock *Node, +                                const BlockFrequencyInfo *Graph) { +    return BFIDOTGTraitsBase::getNodeAttributes(Node, Graph, +                                                ViewHotFreqPercent); +  } + +  std::string getEdgeAttributes(const BasicBlock *Node, EdgeIter EI, +                                const BlockFrequencyInfo *BFI) { +    return BFIDOTGTraitsBase::getEdgeAttributes(Node, EI, BFI, BFI->getBPI(), +                                                ViewHotFreqPercent); +  } +}; + +} // end namespace llvm + +BlockFrequencyInfo::BlockFrequencyInfo() = default; + +BlockFrequencyInfo::BlockFrequencyInfo(const Function &F, +                                       const BranchProbabilityInfo &BPI, +                                       const LoopInfo &LI) { +  calculate(F, BPI, LI); +} + +BlockFrequencyInfo::BlockFrequencyInfo(BlockFrequencyInfo &&Arg) +    : BFI(std::move(Arg.BFI)) {} + +BlockFrequencyInfo &BlockFrequencyInfo::operator=(BlockFrequencyInfo &&RHS) { +  releaseMemory(); +  BFI = std::move(RHS.BFI); +  return *this; +} + +// Explicitly define the default constructor otherwise it would be implicitly +// defined at the first ODR-use which is the BFI member in the +// LazyBlockFrequencyInfo header.  The dtor needs the BlockFrequencyInfoImpl +// template instantiated which is not available in the header. 
+BlockFrequencyInfo::~BlockFrequencyInfo() = default; + +bool BlockFrequencyInfo::invalidate(Function &F, const PreservedAnalyses &PA, +                                    FunctionAnalysisManager::Invalidator &) { +  // Check whether the analysis, all analyses on functions, or the function's +  // CFG have been preserved. +  auto PAC = PA.getChecker<BlockFrequencyAnalysis>(); +  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || +           PAC.preservedSet<CFGAnalyses>()); +} + +void BlockFrequencyInfo::calculate(const Function &F, +                                   const BranchProbabilityInfo &BPI, +                                   const LoopInfo &LI) { +  if (!BFI) +    BFI.reset(new ImplType); +  BFI->calculate(F, BPI, LI); +  if (ViewBlockFreqPropagationDAG != GVDT_None && +      (ViewBlockFreqFuncName.empty() || +       F.getName().equals(ViewBlockFreqFuncName))) { +    view(); +  } +  if (PrintBlockFreq && +      (PrintBlockFreqFuncName.empty() || +       F.getName().equals(PrintBlockFreqFuncName))) { +    print(dbgs()); +  } +} + +BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const { +  return BFI ? BFI->getBlockFreq(BB) : 0; +} + +Optional<uint64_t> +BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB) const { +  if (!BFI) +    return None; + +  return BFI->getBlockProfileCount(*getFunction(), BB); +} + +Optional<uint64_t> +BlockFrequencyInfo::getProfileCountFromFreq(uint64_t Freq) const { +  if (!BFI) +    return None; +  return BFI->getProfileCountFromFreq(*getFunction(), Freq); +} + +bool BlockFrequencyInfo::isIrrLoopHeader(const BasicBlock *BB) { +  assert(BFI && "Expected analysis to be available"); +  return BFI->isIrrLoopHeader(BB); +} + +void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) { +  assert(BFI && "Expected analysis to be available"); +  BFI->setBlockFreq(BB, Freq); +} + +void BlockFrequencyInfo::setBlockFreqAndScale( +    const BasicBlock *ReferenceBB, uint64_t Freq, +    SmallPtrSetImpl<BasicBlock *> &BlocksToScale) { +  assert(BFI && "Expected analysis to be available"); +  // Use 128 bits APInt to avoid overflow. +  APInt NewFreq(128, Freq); +  APInt OldFreq(128, BFI->getBlockFreq(ReferenceBB).getFrequency()); +  APInt BBFreq(128, 0); +  for (auto *BB : BlocksToScale) { +    BBFreq = BFI->getBlockFreq(BB).getFrequency(); +    // Multiply first by NewFreq and then divide by OldFreq +    // to minimize loss of precision. +    BBFreq *= NewFreq; +    // udiv is an expensive operation in the general case. If this ends up being +    // a hot spot, one of the options proposed in +    // https://reviews.llvm.org/D28535#650071 could be used to avoid this. +    BBFreq = BBFreq.udiv(OldFreq); +    BFI->setBlockFreq(BB, BBFreq.getLimitedValue()); +  } +  BFI->setBlockFreq(ReferenceBB, Freq); +} + +/// Pop up a ghostview window with the current block frequency propagation +/// rendered using dot. +void BlockFrequencyInfo::view() const { +  ViewGraph(const_cast<BlockFrequencyInfo *>(this), "BlockFrequencyDAGs"); +} + +const Function *BlockFrequencyInfo::getFunction() const { +  return BFI ? BFI->getFunction() : nullptr; +} + +const BranchProbabilityInfo *BlockFrequencyInfo::getBPI() const { +  return BFI ? &BFI->getBPI() : nullptr; +} + +raw_ostream &BlockFrequencyInfo:: +printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const { +  return BFI ? 
BFI->printBlockFreq(OS, Freq) : OS; +} + +raw_ostream & +BlockFrequencyInfo::printBlockFreq(raw_ostream &OS, +                                   const BasicBlock *BB) const { +  return BFI ? BFI->printBlockFreq(OS, BB) : OS; +} + +uint64_t BlockFrequencyInfo::getEntryFreq() const { +  return BFI ? BFI->getEntryFreq() : 0; +} + +void BlockFrequencyInfo::releaseMemory() { BFI.reset(); } + +void BlockFrequencyInfo::print(raw_ostream &OS) const { +  if (BFI) +    BFI->print(OS); +} + +INITIALIZE_PASS_BEGIN(BlockFrequencyInfoWrapperPass, "block-freq", +                      "Block Frequency Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(BlockFrequencyInfoWrapperPass, "block-freq", +                    "Block Frequency Analysis", true, true) + +char BlockFrequencyInfoWrapperPass::ID = 0; + +BlockFrequencyInfoWrapperPass::BlockFrequencyInfoWrapperPass() +    : FunctionPass(ID) { +  initializeBlockFrequencyInfoWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +BlockFrequencyInfoWrapperPass::~BlockFrequencyInfoWrapperPass() = default; + +void BlockFrequencyInfoWrapperPass::print(raw_ostream &OS, +                                          const Module *) const { +  BFI.print(OS); +} + +void BlockFrequencyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.addRequired<BranchProbabilityInfoWrapperPass>(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.setPreservesAll(); +} + +void BlockFrequencyInfoWrapperPass::releaseMemory() { BFI.releaseMemory(); } + +bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) { +  BranchProbabilityInfo &BPI = +      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); +  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  BFI.calculate(F, BPI, LI); +  return false; +} + +AnalysisKey BlockFrequencyAnalysis::Key; +BlockFrequencyInfo BlockFrequencyAnalysis::run(Function &F, +                                               FunctionAnalysisManager &AM) { +  BlockFrequencyInfo BFI; +  BFI.calculate(F, AM.getResult<BranchProbabilityAnalysis>(F), +                AM.getResult<LoopAnalysis>(F)); +  return BFI; +} + +PreservedAnalyses +BlockFrequencyPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { +  OS << "Printing analysis results of BFI for function " +     << "'" << F.getName() << "':" +     << "\n"; +  AM.getResult<BlockFrequencyAnalysis>(F).print(OS); +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp b/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp new file mode 100644 index 000000000000..3d095068e7ff --- /dev/null +++ b/contrib/llvm/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -0,0 +1,847 @@ +//===- BlockFrequencyImplInfo.cpp - Block Frequency Info Implementation ---===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Loops should be simplified before this analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ScaledNumber.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <list> +#include <numeric> +#include <utility> +#include <vector> + +using namespace llvm; +using namespace llvm::bfi_detail; + +#define DEBUG_TYPE "block-freq" + +ScaledNumber<uint64_t> BlockMass::toScaled() const { +  if (isFull()) +    return ScaledNumber<uint64_t>(1, 0); +  return ScaledNumber<uint64_t>(getMass() + 1, -64); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void BlockMass::dump() const { print(dbgs()); } +#endif + +static char getHexDigit(int N) { +  assert(N < 16); +  if (N < 10) +    return '0' + N; +  return 'a' + N - 10; +} + +raw_ostream &BlockMass::print(raw_ostream &OS) const { +  for (int Digits = 0; Digits < 16; ++Digits) +    OS << getHexDigit(Mass >> (60 - Digits * 4) & 0xf); +  return OS; +} + +namespace { + +using BlockNode = BlockFrequencyInfoImplBase::BlockNode; +using Distribution = BlockFrequencyInfoImplBase::Distribution; +using WeightList = BlockFrequencyInfoImplBase::Distribution::WeightList; +using Scaled64 = BlockFrequencyInfoImplBase::Scaled64; +using LoopData = BlockFrequencyInfoImplBase::LoopData; +using Weight = BlockFrequencyInfoImplBase::Weight; +using FrequencyData = BlockFrequencyInfoImplBase::FrequencyData; + +/// Dithering mass distributer. +/// +/// This class splits up a single mass into portions by weight, dithering to +/// spread out error.  No mass is lost.  The dithering precision depends on the +/// precision of the product of \a BlockMass and \a BranchProbability. +/// +/// The distribution algorithm follows. +/// +///  1. Initialize by saving the sum of the weights in \a RemWeight and the +///     mass to distribute in \a RemMass. +/// +///  2. For each portion: +/// +///      1. Construct a branch probability, P, as the portion's weight divided +///         by the current value of \a RemWeight. +///      2. Calculate the portion's mass as \a RemMass times P. +///      3. Update \a RemWeight and \a RemMass at each portion by subtracting +///         the current portion's weight and mass. +struct DitheringDistributer { +  uint32_t RemWeight; +  BlockMass RemMass; + +  DitheringDistributer(Distribution &Dist, const BlockMass &Mass); + +  BlockMass takeMass(uint32_t Weight); +}; + +} // end anonymous namespace + +DitheringDistributer::DitheringDistributer(Distribution &Dist, +                                           const BlockMass &Mass) { +  Dist.normalize(); +  RemWeight = Dist.Total; +  RemMass = Mass; +} + +BlockMass DitheringDistributer::takeMass(uint32_t Weight) { +  assert(Weight && "invalid weight"); +  assert(Weight <= RemWeight); +  BlockMass Mass = RemMass * BranchProbability(Weight, RemWeight); + +  // Decrement totals (dither). 
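  // For example, with plain integer division standing in for the fixed-point
  // BlockMass * BranchProbability product (an assumption for illustration),
  // distributing a mass of 10 over weights {1, 1, 1} takes
  //   10 * 1/3 = 3  (leaving RemMass 7, RemWeight 2),
  //    7 * 1/2 = 3  (leaving RemMass 4, RemWeight 1),
  //    4 * 1/1 = 4  (leaving RemMass 0, RemWeight 0),
  // so rounding error is pushed onto later portions and no mass is lost.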
+  RemWeight -= Weight; +  RemMass -= Mass; +  return Mass; +} + +void Distribution::add(const BlockNode &Node, uint64_t Amount, +                       Weight::DistType Type) { +  assert(Amount && "invalid weight of 0"); +  uint64_t NewTotal = Total + Amount; + +  // Check for overflow.  It should be impossible to overflow twice. +  bool IsOverflow = NewTotal < Total; +  assert(!(DidOverflow && IsOverflow) && "unexpected repeated overflow"); +  DidOverflow |= IsOverflow; + +  // Update the total. +  Total = NewTotal; + +  // Save the weight. +  Weights.push_back(Weight(Type, Node, Amount)); +} + +static void combineWeight(Weight &W, const Weight &OtherW) { +  assert(OtherW.TargetNode.isValid()); +  if (!W.Amount) { +    W = OtherW; +    return; +  } +  assert(W.Type == OtherW.Type); +  assert(W.TargetNode == OtherW.TargetNode); +  assert(OtherW.Amount && "Expected non-zero weight"); +  if (W.Amount > W.Amount + OtherW.Amount) +    // Saturate on overflow. +    W.Amount = UINT64_MAX; +  else +    W.Amount += OtherW.Amount; +} + +static void combineWeightsBySorting(WeightList &Weights) { +  // Sort so edges to the same node are adjacent. +  llvm::sort(Weights.begin(), Weights.end(), +             [](const Weight &L, +                const Weight &R) { return L.TargetNode < R.TargetNode; }); + +  // Combine adjacent edges. +  WeightList::iterator O = Weights.begin(); +  for (WeightList::const_iterator I = O, L = O, E = Weights.end(); I != E; +       ++O, (I = L)) { +    *O = *I; + +    // Find the adjacent weights to the same node. +    for (++L; L != E && I->TargetNode == L->TargetNode; ++L) +      combineWeight(*O, *L); +  } + +  // Erase extra entries. +  Weights.erase(O, Weights.end()); +} + +static void combineWeightsByHashing(WeightList &Weights) { +  // Collect weights into a DenseMap. +  using HashTable = DenseMap<BlockNode::IndexType, Weight>; + +  HashTable Combined(NextPowerOf2(2 * Weights.size())); +  for (const Weight &W : Weights) +    combineWeight(Combined[W.TargetNode.Index], W); + +  // Check whether anything changed. +  if (Weights.size() == Combined.size()) +    return; + +  // Fill in the new weights. +  Weights.clear(); +  Weights.reserve(Combined.size()); +  for (const auto &I : Combined) +    Weights.push_back(I.second); +} + +static void combineWeights(WeightList &Weights) { +  // Use a hash table for many successors to keep this linear. +  if (Weights.size() > 128) { +    combineWeightsByHashing(Weights); +    return; +  } + +  combineWeightsBySorting(Weights); +} + +static uint64_t shiftRightAndRound(uint64_t N, int Shift) { +  assert(Shift >= 0); +  assert(Shift < 64); +  if (!Shift) +    return N; +  return (N >> Shift) + (UINT64_C(1) & N >> (Shift - 1)); +} + +void Distribution::normalize() { +  // Early exit for termination nodes. +  if (Weights.empty()) +    return; + +  // Only bother if there are multiple successors. +  if (Weights.size() > 1) +    combineWeights(Weights); + +  // Early exit when combined into a single successor. +  if (Weights.size() == 1) { +    Total = 1; +    Weights.front().Amount = 1; +    return; +  } + +  // Determine how much to shift right so that the total fits into 32-bits. +  // +  // If we shift at all, shift by 1 extra.  Otherwise, the lower limit of 1 +  // for each weight can cause a 32-bit overflow. +  int Shift = 0; +  if (DidOverflow) +    Shift = 33; +  else if (Total > UINT32_MAX) +    Shift = 33 - countLeadingZeros(Total); + +  // Early exit if nothing needs to be scaled. 
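  // (When scaling is needed below: for example, if Total is 0x180000000,
  // about 1.5 * UINT32_MAX, and no overflow occurred, countLeadingZeros(Total)
  // is 31, so Shift is 2 and every weight is divided by 4 with rounding, which
  // brings the recomputed Total back under UINT32_MAX with a bit of headroom.)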
+  if (!Shift) { +    // If we didn't overflow then combineWeights() shouldn't have changed the +    // sum of the weights, but let's double-check. +    assert(Total == std::accumulate(Weights.begin(), Weights.end(), UINT64_C(0), +                                    [](uint64_t Sum, const Weight &W) { +                      return Sum + W.Amount; +                    }) && +           "Expected total to be correct"); +    return; +  } + +  // Recompute the total through accumulation (rather than shifting it) so that +  // it's accurate after shifting and any changes combineWeights() made above. +  Total = 0; + +  // Sum the weights to each node and shift right if necessary. +  for (Weight &W : Weights) { +    // Scale down below UINT32_MAX.  Since Shift is larger than necessary, we +    // can round here without concern about overflow. +    assert(W.TargetNode.isValid()); +    W.Amount = std::max(UINT64_C(1), shiftRightAndRound(W.Amount, Shift)); +    assert(W.Amount <= UINT32_MAX); + +    // Update the total. +    Total += W.Amount; +  } +  assert(Total <= UINT32_MAX); +} + +void BlockFrequencyInfoImplBase::clear() { +  // Swap with a default-constructed std::vector, since std::vector<>::clear() +  // does not actually clear heap storage. +  std::vector<FrequencyData>().swap(Freqs); +  IsIrrLoopHeader.clear(); +  std::vector<WorkingData>().swap(Working); +  Loops.clear(); +} + +/// Clear all memory not needed downstream. +/// +/// Releases all memory not used downstream.  In particular, saves Freqs. +static void cleanup(BlockFrequencyInfoImplBase &BFI) { +  std::vector<FrequencyData> SavedFreqs(std::move(BFI.Freqs)); +  SparseBitVector<> SavedIsIrrLoopHeader(std::move(BFI.IsIrrLoopHeader)); +  BFI.clear(); +  BFI.Freqs = std::move(SavedFreqs); +  BFI.IsIrrLoopHeader = std::move(SavedIsIrrLoopHeader); +} + +bool BlockFrequencyInfoImplBase::addToDist(Distribution &Dist, +                                           const LoopData *OuterLoop, +                                           const BlockNode &Pred, +                                           const BlockNode &Succ, +                                           uint64_t Weight) { +  if (!Weight) +    Weight = 1; + +  auto isLoopHeader = [&OuterLoop](const BlockNode &Node) { +    return OuterLoop && OuterLoop->isHeader(Node); +  }; + +  BlockNode Resolved = Working[Succ.Index].getResolvedNode(); + +#ifndef NDEBUG +  auto debugSuccessor = [&](const char *Type) { +    dbgs() << "  =>" +           << " [" << Type << "] weight = " << Weight; +    if (!isLoopHeader(Resolved)) +      dbgs() << ", succ = " << getBlockName(Succ); +    if (Resolved != Succ) +      dbgs() << ", resolved = " << getBlockName(Resolved); +    dbgs() << "\n"; +  }; +  (void)debugSuccessor; +#endif + +  if (isLoopHeader(Resolved)) { +    LLVM_DEBUG(debugSuccessor("backedge")); +    Dist.addBackedge(Resolved, Weight); +    return true; +  } + +  if (Working[Resolved.Index].getContainingLoop() != OuterLoop) { +    LLVM_DEBUG(debugSuccessor("  exit  ")); +    Dist.addExit(Resolved, Weight); +    return true; +  } + +  if (Resolved < Pred) { +    if (!isLoopHeader(Pred)) { +      // If OuterLoop is an irreducible loop, we can't actually handle this. +      assert((!OuterLoop || !OuterLoop->isIrreducible()) && +             "unhandled irreducible control flow"); + +      // Irreducible backedge.  Abort. 
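      // (Reaching this point means mass flowed backwards in the RPO numbering
      // without hitting the loop's recognized header, e.g. a branch from one
      // loop body into the middle of another. Giving up here lets the caller
      // fall back to the irreducible-CFG handling; see analyzeIrreducible
      // further down.)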
+      LLVM_DEBUG(debugSuccessor("abort!!!")); +      return false; +    } + +    // If "Pred" is a loop header, then this isn't really a backedge; rather, +    // OuterLoop must be irreducible.  These false backedges can come only from +    // secondary loop headers. +    assert(OuterLoop && OuterLoop->isIrreducible() && !isLoopHeader(Resolved) && +           "unhandled irreducible control flow"); +  } + +  LLVM_DEBUG(debugSuccessor(" local  ")); +  Dist.addLocal(Resolved, Weight); +  return true; +} + +bool BlockFrequencyInfoImplBase::addLoopSuccessorsToDist( +    const LoopData *OuterLoop, LoopData &Loop, Distribution &Dist) { +  // Copy the exit map into Dist. +  for (const auto &I : Loop.Exits) +    if (!addToDist(Dist, OuterLoop, Loop.getHeader(), I.first, +                   I.second.getMass())) +      // Irreducible backedge. +      return false; + +  return true; +} + +/// Compute the loop scale for a loop. +void BlockFrequencyInfoImplBase::computeLoopScale(LoopData &Loop) { +  // Compute loop scale. +  LLVM_DEBUG(dbgs() << "compute-loop-scale: " << getLoopName(Loop) << "\n"); + +  // Infinite loops need special handling. If we give the back edge an infinite +  // mass, they may saturate all the other scales in the function down to 1, +  // making all the other region temperatures look exactly the same. Choose an +  // arbitrary scale to avoid these issues. +  // +  // FIXME: An alternate way would be to select a symbolic scale which is later +  // replaced to be the maximum of all computed scales plus 1. This would +  // appropriately describe the loop as having a large scale, without skewing +  // the final frequency computation. +  const Scaled64 InfiniteLoopScale(1, 12); + +  // LoopScale == 1 / ExitMass +  // ExitMass == HeadMass - BackedgeMass +  BlockMass TotalBackedgeMass; +  for (auto &Mass : Loop.BackedgeMass) +    TotalBackedgeMass += Mass; +  BlockMass ExitMass = BlockMass::getFull() - TotalBackedgeMass; + +  // Block scale stores the inverse of the scale. If this is an infinite loop, +  // its exit mass will be zero. In this case, use an arbitrary scale for the +  // loop scale. +  Loop.Scale = +      ExitMass.isEmpty() ? InfiniteLoopScale : ExitMass.toScaled().inverse(); + +  LLVM_DEBUG(dbgs() << " - exit-mass = " << ExitMass << " (" +                    << BlockMass::getFull() << " - " << TotalBackedgeMass +                    << ")\n" +                    << " - scale = " << Loop.Scale << "\n"); +} + +/// Package up a loop. +void BlockFrequencyInfoImplBase::packageLoop(LoopData &Loop) { +  LLVM_DEBUG(dbgs() << "packaging-loop: " << getLoopName(Loop) << "\n"); + +  // Clear the subloop exits to prevent quadratic memory usage. 
+  for (const BlockNode &M : Loop.Nodes) { +    if (auto *Loop = Working[M.Index].getPackagedLoop()) +      Loop->Exits.clear(); +    LLVM_DEBUG(dbgs() << " - node: " << getBlockName(M.Index) << "\n"); +  } +  Loop.IsPackaged = true; +} + +#ifndef NDEBUG +static void debugAssign(const BlockFrequencyInfoImplBase &BFI, +                        const DitheringDistributer &D, const BlockNode &T, +                        const BlockMass &M, const char *Desc) { +  dbgs() << "  => assign " << M << " (" << D.RemMass << ")"; +  if (Desc) +    dbgs() << " [" << Desc << "]"; +  if (T.isValid()) +    dbgs() << " to " << BFI.getBlockName(T); +  dbgs() << "\n"; +} +#endif + +void BlockFrequencyInfoImplBase::distributeMass(const BlockNode &Source, +                                                LoopData *OuterLoop, +                                                Distribution &Dist) { +  BlockMass Mass = Working[Source.Index].getMass(); +  LLVM_DEBUG(dbgs() << "  => mass:  " << Mass << "\n"); + +  // Distribute mass to successors as laid out in Dist. +  DitheringDistributer D(Dist, Mass); + +  for (const Weight &W : Dist.Weights) { +    // Check for a local edge (non-backedge and non-exit). +    BlockMass Taken = D.takeMass(W.Amount); +    if (W.Type == Weight::Local) { +      Working[W.TargetNode.Index].getMass() += Taken; +      LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr)); +      continue; +    } + +    // Backedges and exits only make sense if we're processing a loop. +    assert(OuterLoop && "backedge or exit outside of loop"); + +    // Check for a backedge. +    if (W.Type == Weight::Backedge) { +      OuterLoop->BackedgeMass[OuterLoop->getHeaderIndex(W.TargetNode)] += Taken; +      LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, "back")); +      continue; +    } + +    // This must be an exit. +    assert(W.Type == Weight::Exit); +    OuterLoop->Exits.push_back(std::make_pair(W.TargetNode, Taken)); +    LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, "exit")); +  } +} + +static void convertFloatingToInteger(BlockFrequencyInfoImplBase &BFI, +                                     const Scaled64 &Min, const Scaled64 &Max) { +  // Scale the Factor to a size that creates integers.  Ideally, integers would +  // be scaled so that Max == UINT64_MAX so that they can be best +  // differentiated.  However, in the presence of large frequency values, small +  // frequencies are scaled down to 1, making it impossible to differentiate +  // small, unequal numbers. When the spread between Min and Max frequencies +  // fits well within MaxBits, we make the scale be at least 8. +  const unsigned MaxBits = 64; +  const unsigned SpreadBits = (Max / Min).lg(); +  Scaled64 ScalingFactor; +  if (SpreadBits <= MaxBits - 3) { +    // If the values are small enough, make the scaling factor at least 8 to +    // allow distinguishing small values. +    ScalingFactor = Min.inverse(); +    ScalingFactor <<= 3; +  } else { +    // If the values need more than MaxBits to be represented, saturate small +    // frequency values down to 1 by using a scaling factor that benefits large +    // frequency values. +    ScalingFactor = Scaled64(1, MaxBits) / Max; +  } + +  // Translate the floats to integers. 
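  // For example, with Min == 1/8 and Max == 4 the spread is lg(32) == 5 bits,
  // well under 61, so ScalingFactor == (1/Min) << 3 == 64 and the scaled
  // frequencies land in [8, 256]. Only when the spread approaches 64 bits does
  // the second branch apply and small frequencies saturate down to 1.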
+  LLVM_DEBUG(dbgs() << "float-to-int: min = " << Min << ", max = " << Max +                    << ", factor = " << ScalingFactor << "\n"); +  for (size_t Index = 0; Index < BFI.Freqs.size(); ++Index) { +    Scaled64 Scaled = BFI.Freqs[Index].Scaled * ScalingFactor; +    BFI.Freqs[Index].Integer = std::max(UINT64_C(1), Scaled.toInt<uint64_t>()); +    LLVM_DEBUG(dbgs() << " - " << BFI.getBlockName(Index) << ": float = " +                      << BFI.Freqs[Index].Scaled << ", scaled = " << Scaled +                      << ", int = " << BFI.Freqs[Index].Integer << "\n"); +  } +} + +/// Unwrap a loop package. +/// +/// Visits all the members of a loop, adjusting their BlockData according to +/// the loop's pseudo-node. +static void unwrapLoop(BlockFrequencyInfoImplBase &BFI, LoopData &Loop) { +  LLVM_DEBUG(dbgs() << "unwrap-loop-package: " << BFI.getLoopName(Loop) +                    << ": mass = " << Loop.Mass << ", scale = " << Loop.Scale +                    << "\n"); +  Loop.Scale *= Loop.Mass.toScaled(); +  Loop.IsPackaged = false; +  LLVM_DEBUG(dbgs() << "  => combined-scale = " << Loop.Scale << "\n"); + +  // Propagate the head scale through the loop.  Since members are visited in +  // RPO, the head scale will be updated by the loop scale first, and then the +  // final head scale will be used for updated the rest of the members. +  for (const BlockNode &N : Loop.Nodes) { +    const auto &Working = BFI.Working[N.Index]; +    Scaled64 &F = Working.isAPackage() ? Working.getPackagedLoop()->Scale +                                       : BFI.Freqs[N.Index].Scaled; +    Scaled64 New = Loop.Scale * F; +    LLVM_DEBUG(dbgs() << " - " << BFI.getBlockName(N) << ": " << F << " => " +                      << New << "\n"); +    F = New; +  } +} + +void BlockFrequencyInfoImplBase::unwrapLoops() { +  // Set initial frequencies from loop-local masses. +  for (size_t Index = 0; Index < Working.size(); ++Index) +    Freqs[Index].Scaled = Working[Index].Mass.toScaled(); + +  for (LoopData &Loop : Loops) +    unwrapLoop(*this, Loop); +} + +void BlockFrequencyInfoImplBase::finalizeMetrics() { +  // Unwrap loop packages in reverse post-order, tracking min and max +  // frequencies. +  auto Min = Scaled64::getLargest(); +  auto Max = Scaled64::getZero(); +  for (size_t Index = 0; Index < Working.size(); ++Index) { +    // Update min/max scale. +    Min = std::min(Min, Freqs[Index].Scaled); +    Max = std::max(Max, Freqs[Index].Scaled); +  } + +  // Convert to integers. +  convertFloatingToInteger(*this, Min, Max); + +  // Clean up data structures. +  cleanup(*this); + +  // Print out the final stats. +  LLVM_DEBUG(dump()); +} + +BlockFrequency +BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const { +  if (!Node.isValid()) +    return 0; +  return Freqs[Node.Index].Integer; +} + +Optional<uint64_t> +BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F, +                                                 const BlockNode &Node) const { +  return getProfileCountFromFreq(F, getBlockFreq(Node).getFrequency()); +} + +Optional<uint64_t> +BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F, +                                                    uint64_t Freq) const { +  auto EntryCount = F.getEntryCount(); +  if (!EntryCount) +    return None; +  // Use 128 bit APInt to do the arithmetic to avoid overflow. 
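The 64-bit product really can overflow here: an entry count of 10^9 times a block frequency of 2^40 is already past UINT64_MAX, hence the widening. A minimal sketch of the same computation, with unsigned __int128 standing in for the 128-bit APInt (an assumption: it relies on a compiler extension rather than the in-tree APInt):

#include <cstdint>

static uint64_t scaleProfileCount(uint64_t EntryCount, uint64_t BlockFreq,
                                  uint64_t EntryFreq) {
  // Widen before multiplying so EntryCount * BlockFreq cannot wrap.
  unsigned __int128 Scaled = (unsigned __int128)EntryCount * BlockFreq;
  unsigned __int128 Count = Scaled / EntryFreq;
  // Saturate to 64 bits, mirroring APInt::getLimitedValue().
  return Count > UINT64_MAX ? UINT64_MAX : (uint64_t)Count;
}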
+  APInt BlockCount(128, EntryCount.getCount()); +  APInt BlockFreq(128, Freq); +  APInt EntryFreq(128, getEntryFreq()); +  BlockCount *= BlockFreq; +  BlockCount = BlockCount.udiv(EntryFreq); +  return BlockCount.getLimitedValue(); +} + +bool +BlockFrequencyInfoImplBase::isIrrLoopHeader(const BlockNode &Node) { +  if (!Node.isValid()) +    return false; +  return IsIrrLoopHeader.test(Node.Index); +} + +Scaled64 +BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const { +  if (!Node.isValid()) +    return Scaled64::getZero(); +  return Freqs[Node.Index].Scaled; +} + +void BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node, +                                              uint64_t Freq) { +  assert(Node.isValid() && "Expected valid node"); +  assert(Node.Index < Freqs.size() && "Expected legal index"); +  Freqs[Node.Index].Integer = Freq; +} + +std::string +BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const { +  return {}; +} + +std::string +BlockFrequencyInfoImplBase::getLoopName(const LoopData &Loop) const { +  return getBlockName(Loop.getHeader()) + (Loop.isIrreducible() ? "**" : "*"); +} + +raw_ostream & +BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS, +                                           const BlockNode &Node) const { +  return OS << getFloatingBlockFreq(Node); +} + +raw_ostream & +BlockFrequencyInfoImplBase::printBlockFreq(raw_ostream &OS, +                                           const BlockFrequency &Freq) const { +  Scaled64 Block(Freq.getFrequency(), 0); +  Scaled64 Entry(getEntryFreq(), 0); + +  return OS << Block / Entry; +} + +void IrreducibleGraph::addNodesInLoop(const BFIBase::LoopData &OuterLoop) { +  Start = OuterLoop.getHeader(); +  Nodes.reserve(OuterLoop.Nodes.size()); +  for (auto N : OuterLoop.Nodes) +    addNode(N); +  indexNodes(); +} + +void IrreducibleGraph::addNodesInFunction() { +  Start = 0; +  for (uint32_t Index = 0; Index < BFI.Working.size(); ++Index) +    if (!BFI.Working[Index].isPackaged()) +      addNode(Index); +  indexNodes(); +} + +void IrreducibleGraph::indexNodes() { +  for (auto &I : Nodes) +    Lookup[I.Node.Index] = &I; +} + +void IrreducibleGraph::addEdge(IrrNode &Irr, const BlockNode &Succ, +                               const BFIBase::LoopData *OuterLoop) { +  if (OuterLoop && OuterLoop->isHeader(Succ)) +    return; +  auto L = Lookup.find(Succ.Index); +  if (L == Lookup.end()) +    return; +  IrrNode &SuccIrr = *L->second; +  Irr.Edges.push_back(&SuccIrr); +  SuccIrr.Edges.push_front(&Irr); +  ++SuccIrr.NumIn; +} + +namespace llvm { + +template <> struct GraphTraits<IrreducibleGraph> { +  using GraphT = bfi_detail::IrreducibleGraph; +  using NodeRef = const GraphT::IrrNode *; +  using ChildIteratorType = GraphT::IrrNode::iterator; + +  static NodeRef getEntryNode(const GraphT &G) { return G.StartIrr; } +  static ChildIteratorType child_begin(NodeRef N) { return N->succ_begin(); } +  static ChildIteratorType child_end(NodeRef N) { return N->succ_end(); } +}; + +} // end namespace llvm + +/// Find extra irreducible headers. +/// +/// Find entry blocks and other blocks with backedges, which exist when \c G +/// contains irreducible sub-SCCs. +static void findIrreducibleHeaders( +    const BlockFrequencyInfoImplBase &BFI, +    const IrreducibleGraph &G, +    const std::vector<const IrreducibleGraph::IrrNode *> &SCC, +    LoopData::NodeList &Headers, LoopData::NodeList &Others) { +  // Map from nodes in the SCC to whether it's an entry block. 
+  SmallDenseMap<const IrreducibleGraph::IrrNode *, bool, 8> InSCC; + +  // InSCC also acts the set of nodes in the graph.  Seed it. +  for (const auto *I : SCC) +    InSCC[I] = false; + +  for (auto I = InSCC.begin(), E = InSCC.end(); I != E; ++I) { +    auto &Irr = *I->first; +    for (const auto *P : make_range(Irr.pred_begin(), Irr.pred_end())) { +      if (InSCC.count(P)) +        continue; + +      // This is an entry block. +      I->second = true; +      Headers.push_back(Irr.Node); +      LLVM_DEBUG(dbgs() << "  => entry = " << BFI.getBlockName(Irr.Node) +                        << "\n"); +      break; +    } +  } +  assert(Headers.size() >= 2 && +         "Expected irreducible CFG; -loop-info is likely invalid"); +  if (Headers.size() == InSCC.size()) { +    // Every block is a header. +    llvm::sort(Headers.begin(), Headers.end()); +    return; +  } + +  // Look for extra headers from irreducible sub-SCCs. +  for (const auto &I : InSCC) { +    // Entry blocks are already headers. +    if (I.second) +      continue; + +    auto &Irr = *I.first; +    for (const auto *P : make_range(Irr.pred_begin(), Irr.pred_end())) { +      // Skip forward edges. +      if (P->Node < Irr.Node) +        continue; + +      // Skip predecessors from entry blocks.  These can have inverted +      // ordering. +      if (InSCC.lookup(P)) +        continue; + +      // Store the extra header. +      Headers.push_back(Irr.Node); +      LLVM_DEBUG(dbgs() << "  => extra = " << BFI.getBlockName(Irr.Node) +                        << "\n"); +      break; +    } +    if (Headers.back() == Irr.Node) +      // Added this as a header. +      continue; + +    // This is not a header. +    Others.push_back(Irr.Node); +    LLVM_DEBUG(dbgs() << "  => other = " << BFI.getBlockName(Irr.Node) << "\n"); +  } +  llvm::sort(Headers.begin(), Headers.end()); +  llvm::sort(Others.begin(), Others.end()); +} + +static void createIrreducibleLoop( +    BlockFrequencyInfoImplBase &BFI, const IrreducibleGraph &G, +    LoopData *OuterLoop, std::list<LoopData>::iterator Insert, +    const std::vector<const IrreducibleGraph::IrrNode *> &SCC) { +  // Translate the SCC into RPO. +  LLVM_DEBUG(dbgs() << " - found-scc\n"); + +  LoopData::NodeList Headers; +  LoopData::NodeList Others; +  findIrreducibleHeaders(BFI, G, SCC, Headers, Others); + +  auto Loop = BFI.Loops.emplace(Insert, OuterLoop, Headers.begin(), +                                Headers.end(), Others.begin(), Others.end()); + +  // Update loop hierarchy. +  for (const auto &N : Loop->Nodes) +    if (BFI.Working[N.Index].isLoopHeader()) +      BFI.Working[N.Index].Loop->Parent = &*Loop; +    else +      BFI.Working[N.Index].Loop = &*Loop; +} + +iterator_range<std::list<LoopData>::iterator> +BlockFrequencyInfoImplBase::analyzeIrreducible( +    const IrreducibleGraph &G, LoopData *OuterLoop, +    std::list<LoopData>::iterator Insert) { +  assert((OuterLoop == nullptr) == (Insert == Loops.begin())); +  auto Prev = OuterLoop ? std::prev(Insert) : Loops.end(); + +  for (auto I = scc_begin(G); !I.isAtEnd(); ++I) { +    if (I->size() < 2) +      continue; + +    // Translate the SCC into RPO. 
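    // (For instance, if block A branches to both B and C while B and C branch
    // to each other, the SCC {B, C} has two entry blocks; findIrreducibleHeaders
    // marks both B and C as headers of the synthetic loop created here.)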
+    createIrreducibleLoop(*this, G, OuterLoop, Insert, *I); +  } + +  if (OuterLoop) +    return make_range(std::next(Prev), Insert); +  return make_range(Loops.begin(), Insert); +} + +void +BlockFrequencyInfoImplBase::updateLoopWithIrreducible(LoopData &OuterLoop) { +  OuterLoop.Exits.clear(); +  for (auto &Mass : OuterLoop.BackedgeMass) +    Mass = BlockMass::getEmpty(); +  auto O = OuterLoop.Nodes.begin() + 1; +  for (auto I = O, E = OuterLoop.Nodes.end(); I != E; ++I) +    if (!Working[I->Index].isPackaged()) +      *O++ = *I; +  OuterLoop.Nodes.erase(O, OuterLoop.Nodes.end()); +} + +void BlockFrequencyInfoImplBase::adjustLoopHeaderMass(LoopData &Loop) { +  assert(Loop.isIrreducible() && "this only makes sense on irreducible loops"); + +  // Since the loop has more than one header block, the mass flowing back into +  // each header will be different. Adjust the mass in each header loop to +  // reflect the masses flowing through back edges. +  // +  // To do this, we distribute the initial mass using the backedge masses +  // as weights for the distribution. +  BlockMass LoopMass = BlockMass::getFull(); +  Distribution Dist; + +  LLVM_DEBUG(dbgs() << "adjust-loop-header-mass:\n"); +  for (uint32_t H = 0; H < Loop.NumHeaders; ++H) { +    auto &HeaderNode = Loop.Nodes[H]; +    auto &BackedgeMass = Loop.BackedgeMass[Loop.getHeaderIndex(HeaderNode)]; +    LLVM_DEBUG(dbgs() << " - Add back edge mass for node " +                      << getBlockName(HeaderNode) << ": " << BackedgeMass +                      << "\n"); +    if (BackedgeMass.getMass() > 0) +      Dist.addLocal(HeaderNode, BackedgeMass.getMass()); +    else +      LLVM_DEBUG(dbgs() << "   Nothing added. Back edge mass is zero\n"); +  } + +  DitheringDistributer D(Dist, LoopMass); + +  LLVM_DEBUG(dbgs() << " Distribute loop mass " << LoopMass +                    << " to headers using above weights\n"); +  for (const Weight &W : Dist.Weights) { +    BlockMass Taken = D.takeMass(W.Amount); +    assert(W.Type == Weight::Local && "all weights should be local"); +    Working[W.TargetNode.Index].getMass() = Taken; +    LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr)); +  } +} + +void BlockFrequencyInfoImplBase::distributeIrrLoopHeaderMass(Distribution &Dist) { +  BlockMass LoopMass = BlockMass::getFull(); +  DitheringDistributer D(Dist, LoopMass); +  for (const Weight &W : Dist.Weights) { +    BlockMass Taken = D.takeMass(W.Amount); +    assert(W.Type == Weight::Local && "all weights should be local"); +    Working[W.TargetNode.Index].getMass() = Taken; +    LLVM_DEBUG(debugAssign(*this, D, W.TargetNode, Taken, nullptr)); +  } +} diff --git a/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp new file mode 100644 index 000000000000..54a657073f0f --- /dev/null +++ b/contrib/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -0,0 +1,1039 @@ +//===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Loops should be simplified before this analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "branch-prob" + +static cl::opt<bool> PrintBranchProb( +    "print-bpi", cl::init(false), cl::Hidden, +    cl::desc("Print the branch probability info.")); + +cl::opt<std::string> PrintBranchProbFuncName( +    "print-bpi-func-name", cl::Hidden, +    cl::desc("The option to specify the name of the function " +             "whose branch probability info is printed.")); + +INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob", +                      "Branch Probability Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob", +                    "Branch Probability Analysis", false, true) + +char BranchProbabilityInfoWrapperPass::ID = 0; + +// Weights are for internal use only. They are used by heuristics to help to +// estimate edges' probability. Example: +// +// Using "Loop Branch Heuristics" we predict weights of edges for the +// block BB2. +//         ... +//          | +//          V +//         BB1<-+ +//          |   | +//          |   | (Weight = 124) +//          V   | +//         BB2--+ +//          | +//          | (Weight = 4) +//          V +//         BB3 +// +// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875 +// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125 +static const uint32_t LBH_TAKEN_WEIGHT = 124; +static const uint32_t LBH_NONTAKEN_WEIGHT = 4; +// Unlikely edges within a loop are half as likely as other edges +static const uint32_t LBH_UNLIKELY_WEIGHT = 62; + +/// Unreachable-terminating branch taken probability. +/// +/// This is the probability for a branch being taken to a block that terminates +/// (eventually) in unreachable. These are predicted as unlikely as possible. +/// All reachable probability will equally share the remaining part. +static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1); + +/// Weight for a branch taken going into a cold block. +/// +/// This is the weight for a branch taken toward a block marked +/// cold.  A block is marked cold if it's postdominated by a +/// block containing a call to a cold function.  Cold functions +/// are those marked with attribute 'cold'. +static const uint32_t CC_TAKEN_WEIGHT = 4; + +/// Weight for a branch not-taken into a cold block. 
+/// +/// This is the weight for a branch not taken toward a block marked +/// cold. +static const uint32_t CC_NONTAKEN_WEIGHT = 64; + +static const uint32_t PH_TAKEN_WEIGHT = 20; +static const uint32_t PH_NONTAKEN_WEIGHT = 12; + +static const uint32_t ZH_TAKEN_WEIGHT = 20; +static const uint32_t ZH_NONTAKEN_WEIGHT = 12; + +static const uint32_t FPH_TAKEN_WEIGHT = 20; +static const uint32_t FPH_NONTAKEN_WEIGHT = 12; + +/// Invoke-terminating normal branch taken weight +/// +/// This is the weight for branching to the normal destination of an invoke +/// instruction. We expect this to happen most of the time. Set the weight to an +/// absurdly high value so that nested loops subsume it. +static const uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1; + +/// Invoke-terminating normal branch not-taken weight. +/// +/// This is the weight for branching to the unwind destination of an invoke +/// instruction. This is essentially never taken. +static const uint32_t IH_NONTAKEN_WEIGHT = 1; + +/// Add \p BB to PostDominatedByUnreachable set if applicable. +void +BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) { +  const TerminatorInst *TI = BB->getTerminator(); +  if (TI->getNumSuccessors() == 0) { +    if (isa<UnreachableInst>(TI) || +        // If this block is terminated by a call to +        // @llvm.experimental.deoptimize then treat it like an unreachable since +        // the @llvm.experimental.deoptimize call is expected to practically +        // never execute. +        BB->getTerminatingDeoptimizeCall()) +      PostDominatedByUnreachable.insert(BB); +    return; +  } + +  // If the terminator is an InvokeInst, check only the normal destination block +  // as the unwind edge of InvokeInst is also very unlikely taken. +  if (auto *II = dyn_cast<InvokeInst>(TI)) { +    if (PostDominatedByUnreachable.count(II->getNormalDest())) +      PostDominatedByUnreachable.insert(BB); +    return; +  } + +  for (auto *I : successors(BB)) +    // If any of successor is not post dominated then BB is also not. +    if (!PostDominatedByUnreachable.count(I)) +      return; + +  PostDominatedByUnreachable.insert(BB); +} + +/// Add \p BB to PostDominatedByColdCall set if applicable. +void +BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) { +  assert(!PostDominatedByColdCall.count(BB)); +  const TerminatorInst *TI = BB->getTerminator(); +  if (TI->getNumSuccessors() == 0) +    return; + +  // If all of successor are post dominated then BB is also done. +  if (llvm::all_of(successors(BB), [&](const BasicBlock *SuccBB) { +        return PostDominatedByColdCall.count(SuccBB); +      })) { +    PostDominatedByColdCall.insert(BB); +    return; +  } + +  // If the terminator is an InvokeInst, check only the normal destination +  // block as the unwind edge of InvokeInst is also very unlikely taken. +  if (auto *II = dyn_cast<InvokeInst>(TI)) +    if (PostDominatedByColdCall.count(II->getNormalDest())) { +      PostDominatedByColdCall.insert(BB); +      return; +    } + +  // Otherwise, if the block itself contains a cold function, add it to the +  // set of blocks post-dominated by a cold call. +  for (auto &I : *BB) +    if (const CallInst *CI = dyn_cast<CallInst>(&I)) +      if (CI->hasFnAttr(Attribute::Cold)) { +        PostDominatedByColdCall.insert(BB); +        return; +      } +} + +/// Calculate edge weights for successors lead to unreachable. 
+/// +/// Predict that a successor which leads necessarily to an +/// unreachable-terminated block as extremely unlikely. +bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { +  const TerminatorInst *TI = BB->getTerminator(); +  (void) TI; +  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); +  assert(!isa<InvokeInst>(TI) && +         "Invokes should have already been handled by calcInvokeHeuristics"); + +  SmallVector<unsigned, 4> UnreachableEdges; +  SmallVector<unsigned, 4> ReachableEdges; + +  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) +    if (PostDominatedByUnreachable.count(*I)) +      UnreachableEdges.push_back(I.getSuccessorIndex()); +    else +      ReachableEdges.push_back(I.getSuccessorIndex()); + +  // Skip probabilities if all were reachable. +  if (UnreachableEdges.empty()) +    return false; + +  if (ReachableEdges.empty()) { +    BranchProbability Prob(1, UnreachableEdges.size()); +    for (unsigned SuccIdx : UnreachableEdges) +      setEdgeProbability(BB, SuccIdx, Prob); +    return true; +  } + +  auto UnreachableProb = UR_TAKEN_PROB; +  auto ReachableProb = +      (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) / +      ReachableEdges.size(); + +  for (unsigned SuccIdx : UnreachableEdges) +    setEdgeProbability(BB, SuccIdx, UnreachableProb); +  for (unsigned SuccIdx : ReachableEdges) +    setEdgeProbability(BB, SuccIdx, ReachableProb); + +  return true; +} + +// Propagate existing explicit probabilities from either profile data or +// 'expect' intrinsic processing. Examine metadata against unreachable +// heuristic. The probability of the edge coming to unreachable block is +// set to min of metadata and unreachable heuristic. +bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { +  const TerminatorInst *TI = BB->getTerminator(); +  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); +  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI))) +    return false; + +  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); +  if (!WeightsNode) +    return false; + +  // Check that the number of successors is manageable. +  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors"); + +  // Ensure there are weights for all of the successors. Note that the first +  // operand to the metadata node is a name, not a weight. +  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1) +    return false; + +  // Build up the final weights that will be used in a temporary buffer. +  // Compute the sum of all weights to later decide whether they need to +  // be scaled to fit in 32 bits. 
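
The scaling step described in the comment above is plain integer division by a factor chosen so that the new sum fits back into 32 bits. A minimal standalone sketch of that arithmetic, using hypothetical profile weights rather than the metadata and types used in the pass:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // Hypothetical profile weights whose sum does not fit in 32 bits.
      std::vector<uint64_t> Weights = {3000000000ull, 2500000000ull, 7ull};

      uint64_t WeightSum = 0;
      for (uint64_t W : Weights)
        WeightSum += W;

      // Same rule as the code below: pick a factor that brings the sum back
      // under UINT32_MAX and divide every weight by it.
      uint64_t ScalingFactor =
          (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;

      WeightSum = 0;
      for (uint64_t &W : Weights) {
        W /= ScalingFactor;
        WeightSum += W;
      }
      std::cout << "factor = " << ScalingFactor << ", scaled sum = " << WeightSum
                << "\n"; // 2 and 2750000003, which fits in 32 bits
      return 0;
    }
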
+  uint64_t WeightSum = 0; +  SmallVector<uint32_t, 2> Weights; +  SmallVector<unsigned, 2> UnreachableIdxs; +  SmallVector<unsigned, 2> ReachableIdxs; +  Weights.reserve(TI->getNumSuccessors()); +  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) { +    ConstantInt *Weight = +        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i)); +    if (!Weight) +      return false; +    assert(Weight->getValue().getActiveBits() <= 32 && +           "Too many bits for uint32_t"); +    Weights.push_back(Weight->getZExtValue()); +    WeightSum += Weights.back(); +    if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1))) +      UnreachableIdxs.push_back(i - 1); +    else +      ReachableIdxs.push_back(i - 1); +  } +  assert(Weights.size() == TI->getNumSuccessors() && "Checked above"); + +  // If the sum of weights does not fit in 32 bits, scale every weight down +  // accordingly. +  uint64_t ScalingFactor = +      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1; + +  if (ScalingFactor > 1) { +    WeightSum = 0; +    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { +      Weights[i] /= ScalingFactor; +      WeightSum += Weights[i]; +    } +  } +  assert(WeightSum <= UINT32_MAX && +         "Expected weights to scale down to 32 bits"); + +  if (WeightSum == 0 || ReachableIdxs.size() == 0) { +    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) +      Weights[i] = 1; +    WeightSum = TI->getNumSuccessors(); +  } + +  // Set the probability. +  SmallVector<BranchProbability, 2> BP; +  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) +    BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) }); + +  // Examine the metadata against unreachable heuristic. +  // If the unreachable heuristic is more strong then we use it for this edge. +  if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) { +    auto ToDistribute = BranchProbability::getZero(); +    auto UnreachableProb = UR_TAKEN_PROB; +    for (auto i : UnreachableIdxs) +      if (UnreachableProb < BP[i]) { +        ToDistribute += BP[i] - UnreachableProb; +        BP[i] = UnreachableProb; +      } + +    // If we modified the probability of some edges then we must distribute +    // the difference between reachable blocks. +    if (ToDistribute > BranchProbability::getZero()) { +      BranchProbability PerEdge = ToDistribute / ReachableIdxs.size(); +      for (auto i : ReachableIdxs) +        BP[i] += PerEdge; +    } +  } + +  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) +    setEdgeProbability(BB, i, BP[i]); + +  return true; +} + +/// Calculate edge weights for edges leading to cold blocks. +/// +/// A cold block is one post-dominated by  a block with a call to a +/// cold function.  Those edges are unlikely to be taken, so we give +/// them relatively low weight. +/// +/// Return true if we could compute the weights for cold edges. +/// Return false, otherwise. +bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) { +  const TerminatorInst *TI = BB->getTerminator(); +  (void) TI; +  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); +  assert(!isa<InvokeInst>(TI) && +         "Invokes should have already been handled by calcInvokeHeuristics"); + +  // Determine which successors are post-dominated by a cold block. 
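
With the weights used below, each cold edge receives CC_TAKEN_WEIGHT over (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) times the number of cold edges, and likewise for normal edges, so the probabilities leaving the block still sum to one. A rough standalone sketch of that arithmetic in plain doubles (the pass keeps exact BranchProbability fractions; the successor counts here are hypothetical):

    #include <cstdint>
    #include <iostream>

    int main() {
      const uint32_t CC_TAKEN_WEIGHT = 4, CC_NONTAKEN_WEIGHT = 64;
      // Hypothetical block with one cold successor and two normal successors.
      const uint64_t NumCold = 1, NumNormal = 2;

      double ColdProb = double(CC_TAKEN_WEIGHT) /
                        ((CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * NumCold);
      double NormalProb = double(CC_NONTAKEN_WEIGHT) /
                          ((CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * NumNormal);

      // Roughly 0.059 per cold edge and 0.47 per normal edge; the total is 1.
      std::cout << ColdProb << " " << NormalProb << " "
                << NumCold * ColdProb + NumNormal * NormalProb << "\n";
      return 0;
    }
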
+  SmallVector<unsigned, 4> ColdEdges; +  SmallVector<unsigned, 4> NormalEdges; +  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) +    if (PostDominatedByColdCall.count(*I)) +      ColdEdges.push_back(I.getSuccessorIndex()); +    else +      NormalEdges.push_back(I.getSuccessorIndex()); + +  // Skip probabilities if no cold edges. +  if (ColdEdges.empty()) +    return false; + +  if (NormalEdges.empty()) { +    BranchProbability Prob(1, ColdEdges.size()); +    for (unsigned SuccIdx : ColdEdges) +      setEdgeProbability(BB, SuccIdx, Prob); +    return true; +  } + +  auto ColdProb = BranchProbability::getBranchProbability( +      CC_TAKEN_WEIGHT, +      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size())); +  auto NormalProb = BranchProbability::getBranchProbability( +      CC_NONTAKEN_WEIGHT, +      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size())); + +  for (unsigned SuccIdx : ColdEdges) +    setEdgeProbability(BB, SuccIdx, ColdProb); +  for (unsigned SuccIdx : NormalEdges) +    setEdgeProbability(BB, SuccIdx, NormalProb); + +  return true; +} + +// Calculate Edge Weights using "Pointer Heuristics". Predict a comparison +// between two pointer or pointer and NULL will fail. +bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) { +  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); +  if (!BI || !BI->isConditional()) +    return false; + +  Value *Cond = BI->getCondition(); +  ICmpInst *CI = dyn_cast<ICmpInst>(Cond); +  if (!CI || !CI->isEquality()) +    return false; + +  Value *LHS = CI->getOperand(0); + +  if (!LHS->getType()->isPointerTy()) +    return false; + +  assert(CI->getOperand(1)->getType()->isPointerTy()); + +  // p != 0   ->   isProb = true +  // p == 0   ->   isProb = false +  // p != q   ->   isProb = true +  // p == q   ->   isProb = false; +  unsigned TakenIdx = 0, NonTakenIdx = 1; +  bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE; +  if (!isProb) +    std::swap(TakenIdx, NonTakenIdx); + +  BranchProbability TakenProb(PH_TAKEN_WEIGHT, +                              PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT); +  setEdgeProbability(BB, TakenIdx, TakenProb); +  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl()); +  return true; +} + +static int getSCCNum(const BasicBlock *BB, +                     const BranchProbabilityInfo::SccInfo &SccI) { +  auto SccIt = SccI.SccNums.find(BB); +  if (SccIt == SccI.SccNums.end()) +    return -1; +  return SccIt->second; +} + +// Consider any block that is an entry point to the SCC as a header. +static bool isSCCHeader(const BasicBlock *BB, int SccNum, +                        BranchProbabilityInfo::SccInfo &SccI) { +  assert(getSCCNum(BB, SccI) == SccNum); + +  // Lazily compute the set of headers for a given SCC and cache the results +  // in the SccHeaderMap. 
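
The caching below relies on the usual insert-and-test idiom: the bool returned alongside the iterator says whether the entry is new and still needs to be computed. A small standalone sketch of the same idiom with std::unordered_map and a made-up predicate standing in for the header check:

    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Stand-in for the expensive per-query computation being cached.
    static bool isExpensiveProperty(const std::string &Key) {
      return Key.size() % 2 == 0;
    }

    int main() {
      std::unordered_map<std::string, bool> Cache;
      for (std::string Key : {"alpha", "beta", "alpha"}) {
        // insert() reports via the second member whether the entry is new and
        // still needs to be computed; otherwise the cached value is reused.
        auto Result = Cache.insert({Key, false});
        if (Result.second)
          Result.first->second = isExpensiveProperty(Key);
        std::cout << Key << " -> " << Result.first->second << "\n";
      }
      return 0;
    }
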
+  if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum)) +    SccI.SccHeaders.resize(SccNum + 1); +  auto &HeaderMap = SccI.SccHeaders[SccNum]; +  bool Inserted; +  BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt; +  std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false)); +  if (Inserted) { +    bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)), +                                 [&](const BasicBlock *Pred) { +                                   return getSCCNum(Pred, SccI) != SccNum; +                                 }); +    HeaderMapIt->second = IsHeader; +    return IsHeader; +  } else +    return HeaderMapIt->second; +} + +// Compute the unlikely successors to the block BB in the loop L, specifically +// those that are unlikely because this is a loop, and add them to the +// UnlikelyBlocks set. +static void +computeUnlikelySuccessors(const BasicBlock *BB, Loop *L, +                          SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) { +  // Sometimes in a loop we have a branch whose condition is made false by +  // taking it. This is typically something like +  //  int n = 0; +  //  while (...) { +  //    if (++n >= MAX) { +  //      n = 0; +  //    } +  //  } +  // In this sort of situation taking the branch means that at the very least it +  // won't be taken again in the next iteration of the loop, so we should +  // consider it less likely than a typical branch. +  // +  // We detect this by looking back through the graph of PHI nodes that sets the +  // value that the condition depends on, and seeing if we can reach a successor +  // block which can be determined to make the condition false. +  // +  // FIXME: We currently consider unlikely blocks to be half as likely as other +  // blocks, but if we consider the example above the likelyhood is actually +  // 1/MAX. We could therefore be more precise in how unlikely we consider +  // blocks to be, but it would require more careful examination of the form +  // of the comparison expression. +  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); +  if (!BI || !BI->isConditional()) +    return; + +  // Check if the branch is based on an instruction compared with a constant +  CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition()); +  if (!CI || !isa<Instruction>(CI->getOperand(0)) || +      !isa<Constant>(CI->getOperand(1))) +    return; + +  // Either the instruction must be a PHI, or a chain of operations involving +  // constants that ends in a PHI which we can then collapse into a single value +  // if the PHI value is known. 
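
A standalone sketch of that collapsing step, using plain integers in place of Constant/ConstantExpr. The chain and the MAX bound are hypothetical; the point is that folding the PHI's constant incoming value through the collected chain lets the compare be evaluated up front:

    #include <iostream>
    #include <utility>
    #include <vector>

    int main() {
      // Mirrors the example in the comment above: after the reset block the
      // PHI's incoming value is the constant 0, the collected chain is "n + 1",
      // and the branch condition is "n + 1 >= MAX" (MAX is hypothetical).
      const int MAX = 100;
      int Folded = 0; // constant incoming to the PHI from the reset block

      // Chain of (opcode, constant RHS) pairs collected walking up from the
      // compare; fold them in reverse, as the code below does with ConstantExpr.
      std::vector<std::pair<char, int>> InstChain = {{'+', 1}};
      for (auto It = InstChain.rbegin(); It != InstChain.rend(); ++It)
        Folded = (It->first == '+') ? Folded + It->second : Folded * It->second;

      // Constant-evaluate the compare: false means the edge back into the reset
      // block will not be taken again on the very next iteration.
      std::cout << "folded = " << Folded << ", taken again = " << std::boolalpha
                << (Folded >= MAX) << "\n";
      return 0;
    }
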
+  Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0)); +  PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS); +  Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1)); +  // Collect the instructions until we hit a PHI +  SmallVector<BinaryOperator *, 1> InstChain; +  while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) && +         isa<Constant>(CmpLHS->getOperand(1))) { +    // Stop if the chain extends outside of the loop +    if (!L->contains(CmpLHS)) +      return; +    InstChain.push_back(cast<BinaryOperator>(CmpLHS)); +    CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0)); +    if (CmpLHS) +      CmpPHI = dyn_cast<PHINode>(CmpLHS); +  } +  if (!CmpPHI || !L->contains(CmpPHI)) +    return; + +  // Trace the phi node to find all values that come from successors of BB +  SmallPtrSet<PHINode*, 8> VisitedInsts; +  SmallVector<PHINode*, 8> WorkList; +  WorkList.push_back(CmpPHI); +  VisitedInsts.insert(CmpPHI); +  while (!WorkList.empty()) { +    PHINode *P = WorkList.back(); +    WorkList.pop_back(); +    for (BasicBlock *B : P->blocks()) { +      // Skip blocks that aren't part of the loop +      if (!L->contains(B)) +        continue; +      Value *V = P->getIncomingValueForBlock(B); +      // If the source is a PHI add it to the work list if we haven't +      // already visited it. +      if (PHINode *PN = dyn_cast<PHINode>(V)) { +        if (VisitedInsts.insert(PN).second) +          WorkList.push_back(PN); +        continue; +      } +      // If this incoming value is a constant and B is a successor of BB, then +      // we can constant-evaluate the compare to see if it makes the branch be +      // taken or not. +      Constant *CmpLHSConst = dyn_cast<Constant>(V); +      if (!CmpLHSConst || +          std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB)) +        continue; +      // First collapse InstChain +      for (Instruction *I : llvm::reverse(InstChain)) { +        CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst, +                                        cast<Constant>(I->getOperand(1)), true); +        if (!CmpLHSConst) +          break; +      } +      if (!CmpLHSConst) +        continue; +      // Now constant-evaluate the compare +      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(), +                                                  CmpLHSConst, CmpConst, true); +      // If the result means we don't branch to the block then that block is +      // unlikely. +      if (Result && +          ((Result->isZeroValue() && B == BI->getSuccessor(0)) || +           (Result->isOneValue() && B == BI->getSuccessor(1)))) +        UnlikelyBlocks.insert(B); +    } +  } +} + +// Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges +// as taken, exiting edges as not-taken. +bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB, +                                                     const LoopInfo &LI, +                                                     SccInfo &SccI) { +  int SccNum; +  Loop *L = LI.getLoopFor(BB); +  if (!L) { +    SccNum = getSCCNum(BB, SccI); +    if (SccNum < 0) +      return false; +  } + +  SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks; +  if (L) +    computeUnlikelySuccessors(BB, L, UnlikelyBlocks); + +  SmallVector<unsigned, 8> BackEdges; +  SmallVector<unsigned, 8> ExitingEdges; +  SmallVector<unsigned, 8> InEdges; // Edges from header to the loop. 
+  SmallVector<unsigned, 8> UnlikelyEdges; + +  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { +    // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch +    // irreducible loops. +    if (L) { +      if (UnlikelyBlocks.count(*I) != 0) +        UnlikelyEdges.push_back(I.getSuccessorIndex()); +      else if (!L->contains(*I)) +        ExitingEdges.push_back(I.getSuccessorIndex()); +      else if (L->getHeader() == *I) +        BackEdges.push_back(I.getSuccessorIndex()); +      else +        InEdges.push_back(I.getSuccessorIndex()); +    } else { +      if (getSCCNum(*I, SccI) != SccNum) +        ExitingEdges.push_back(I.getSuccessorIndex()); +      else if (isSCCHeader(*I, SccNum, SccI)) +        BackEdges.push_back(I.getSuccessorIndex()); +      else +        InEdges.push_back(I.getSuccessorIndex()); +    } +  } + +  if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty()) +    return false; + +  // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and +  // normalize them so that they sum up to one. +  unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) + +                   (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) + +                   (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) + +                   (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT); + +  if (uint32_t numBackEdges = BackEdges.size()) { +    BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom); +    auto Prob = TakenProb / numBackEdges; +    for (unsigned SuccIdx : BackEdges) +      setEdgeProbability(BB, SuccIdx, Prob); +  } + +  if (uint32_t numInEdges = InEdges.size()) { +    BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom); +    auto Prob = TakenProb / numInEdges; +    for (unsigned SuccIdx : InEdges) +      setEdgeProbability(BB, SuccIdx, Prob); +  } + +  if (uint32_t numExitingEdges = ExitingEdges.size()) { +    BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT, +                                                       Denom); +    auto Prob = NotTakenProb / numExitingEdges; +    for (unsigned SuccIdx : ExitingEdges) +      setEdgeProbability(BB, SuccIdx, Prob); +  } + +  if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) { +    BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT, +                                                       Denom); +    auto Prob = UnlikelyProb / numUnlikelyEdges; +    for (unsigned SuccIdx : UnlikelyEdges) +      setEdgeProbability(BB, SuccIdx, Prob); +  } + +  return true; +} + +bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB, +                                               const TargetLibraryInfo *TLI) { +  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); +  if (!BI || !BI->isConditional()) +    return false; + +  Value *Cond = BI->getCondition(); +  ICmpInst *CI = dyn_cast<ICmpInst>(Cond); +  if (!CI) +    return false; + +  Value *RHS = CI->getOperand(1); +  ConstantInt *CV = dyn_cast<ConstantInt>(RHS); +  if (!CV) +    return false; + +  // If the LHS is the result of AND'ing a value with a single bit bitmask, +  // we don't have information about probabilities. 
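
The reason for bailing out is that a single-bit test is close to a coin flip for an arbitrary value, so neither branch deserves a bias. A small standalone sketch (hypothetical mask) that counts both outcomes over all 8-bit values:

    #include <cstdint>
    #include <iostream>

    int main() {
      // A power-of-two mask means the compare just tests one bit of X, so for
      // an arbitrary value neither outcome is predicted: it splits 50/50.
      const uint32_t Mask = 1u << 3; // hypothetical single-bit mask
      unsigned Zero = 0, NonZero = 0;
      for (uint32_t X = 0; X < 256; ++X) {
        if ((X & Mask) == 0)
          ++Zero;
        else
          ++NonZero;
      }
      std::cout << "(X & Mask) == 0 for " << Zero << "/256 values, != 0 for "
                << NonZero << "/256 values\n"; // 128 each
      return 0;
    }
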
+  if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0))) +    if (LHS->getOpcode() == Instruction::And) +      if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1))) +        if (AndRHS->getValue().isPowerOf2()) +          return false; + +  // Check if the LHS is the return value of a library function +  LibFunc Func = NumLibFuncs; +  if (TLI) +    if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0))) +      if (Function *CalledFn = Call->getCalledFunction()) +        TLI->getLibFunc(*CalledFn, Func); + +  bool isProb; +  if (Func == LibFunc_strcasecmp || +      Func == LibFunc_strcmp || +      Func == LibFunc_strncasecmp || +      Func == LibFunc_strncmp || +      Func == LibFunc_memcmp) { +    // strcmp and similar functions return zero, negative, or positive, if the +    // first string is equal, less, or greater than the second. We consider it +    // likely that the strings are not equal, so a comparison with zero is +    // probably false, but also a comparison with any other number is also +    // probably false given that what exactly is returned for nonzero values is +    // not specified. Any kind of comparison other than equality we know +    // nothing about. +    switch (CI->getPredicate()) { +    case CmpInst::ICMP_EQ: +      isProb = false; +      break; +    case CmpInst::ICMP_NE: +      isProb = true; +      break; +    default: +      return false; +    } +  } else if (CV->isZero()) { +    switch (CI->getPredicate()) { +    case CmpInst::ICMP_EQ: +      // X == 0   ->  Unlikely +      isProb = false; +      break; +    case CmpInst::ICMP_NE: +      // X != 0   ->  Likely +      isProb = true; +      break; +    case CmpInst::ICMP_SLT: +      // X < 0   ->  Unlikely +      isProb = false; +      break; +    case CmpInst::ICMP_SGT: +      // X > 0   ->  Likely +      isProb = true; +      break; +    default: +      return false; +    } +  } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) { +    // InstCombine canonicalizes X <= 0 into X < 1. +    // X <= 0   ->  Unlikely +    isProb = false; +  } else if (CV->isMinusOne()) { +    switch (CI->getPredicate()) { +    case CmpInst::ICMP_EQ: +      // X == -1  ->  Unlikely +      isProb = false; +      break; +    case CmpInst::ICMP_NE: +      // X != -1  ->  Likely +      isProb = true; +      break; +    case CmpInst::ICMP_SGT: +      // InstCombine canonicalizes X >= 0 into X > -1. 
+      // X >= 0   ->  Likely +      isProb = true; +      break; +    default: +      return false; +    } +  } else { +    return false; +  } + +  unsigned TakenIdx = 0, NonTakenIdx = 1; + +  if (!isProb) +    std::swap(TakenIdx, NonTakenIdx); + +  BranchProbability TakenProb(ZH_TAKEN_WEIGHT, +                              ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT); +  setEdgeProbability(BB, TakenIdx, TakenProb); +  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl()); +  return true; +} + +bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) { +  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); +  if (!BI || !BI->isConditional()) +    return false; + +  Value *Cond = BI->getCondition(); +  FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond); +  if (!FCmp) +    return false; + +  bool isProb; +  if (FCmp->isEquality()) { +    // f1 == f2 -> Unlikely +    // f1 != f2 -> Likely +    isProb = !FCmp->isTrueWhenEqual(); +  } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) { +    // !isnan -> Likely +    isProb = true; +  } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) { +    // isnan -> Unlikely +    isProb = false; +  } else { +    return false; +  } + +  unsigned TakenIdx = 0, NonTakenIdx = 1; + +  if (!isProb) +    std::swap(TakenIdx, NonTakenIdx); + +  BranchProbability TakenProb(FPH_TAKEN_WEIGHT, +                              FPH_TAKEN_WEIGHT + FPH_NONTAKEN_WEIGHT); +  setEdgeProbability(BB, TakenIdx, TakenProb); +  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl()); +  return true; +} + +bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) { +  const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()); +  if (!II) +    return false; + +  BranchProbability TakenProb(IH_TAKEN_WEIGHT, +                              IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT); +  setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb); +  setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl()); +  return true; +} + +void BranchProbabilityInfo::releaseMemory() { +  Probs.clear(); +} + +void BranchProbabilityInfo::print(raw_ostream &OS) const { +  OS << "---- Branch Probabilities ----\n"; +  // We print the probabilities from the last function the analysis ran over, +  // or the function it is currently running over. +  assert(LastF && "Cannot print prior to running over a function"); +  for (const auto &BI : *LastF) { +    for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE; +         ++SI) { +      printEdgeProbability(OS << "  ", &BI, *SI); +    } +  } +} + +bool BranchProbabilityInfo:: +isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const { +  // Hot probability is at least 4/5 = 80% +  // FIXME: Compare against a static "hot" BranchProbability. +  return getEdgeProbability(Src, Dst) > BranchProbability(4, 5); +} + +const BasicBlock * +BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const { +  auto MaxProb = BranchProbability::getZero(); +  const BasicBlock *MaxSucc = nullptr; + +  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { +    const BasicBlock *Succ = *I; +    auto Prob = getEdgeProbability(BB, Succ); +    if (Prob > MaxProb) { +      MaxProb = Prob; +      MaxSucc = Succ; +    } +  } + +  // Hot probability is at least 4/5 = 80% +  if (MaxProb > BranchProbability(4, 5)) +    return MaxSucc; + +  return nullptr; +} + +/// Get the raw edge probability for the edge. 
If can't find it, return a +/// default probability 1/N where N is the number of successors. Here an edge is +/// specified using PredBlock and an +/// index to the successors. +BranchProbability +BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src, +                                          unsigned IndexInSuccessors) const { +  auto I = Probs.find(std::make_pair(Src, IndexInSuccessors)); + +  if (I != Probs.end()) +    return I->second; + +  return {1, static_cast<uint32_t>(succ_size(Src))}; +} + +BranchProbability +BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src, +                                          succ_const_iterator Dst) const { +  return getEdgeProbability(Src, Dst.getSuccessorIndex()); +} + +/// Get the raw edge probability calculated for the block pair. This returns the +/// sum of all raw edge probabilities from Src to Dst. +BranchProbability +BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src, +                                          const BasicBlock *Dst) const { +  auto Prob = BranchProbability::getZero(); +  bool FoundProb = false; +  for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I) +    if (*I == Dst) { +      auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex())); +      if (MapI != Probs.end()) { +        FoundProb = true; +        Prob += MapI->second; +      } +    } +  uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src)); +  return FoundProb ? Prob : BranchProbability(1, succ_num); +} + +/// Set the edge probability for a given edge specified by PredBlock and an +/// index to the successors. +void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src, +                                               unsigned IndexInSuccessors, +                                               BranchProbability Prob) { +  Probs[std::make_pair(Src, IndexInSuccessors)] = Prob; +  Handles.insert(BasicBlockCallbackVH(Src, this)); +  LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> " +                    << IndexInSuccessors << " successor probability to " << Prob +                    << "\n"); +} + +raw_ostream & +BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS, +                                            const BasicBlock *Src, +                                            const BasicBlock *Dst) const { +  const BranchProbability Prob = getEdgeProbability(Src, Dst); +  OS << "edge " << Src->getName() << " -> " << Dst->getName() +     << " probability is " << Prob +     << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n"); + +  return OS; +} + +void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) { +  for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) { +    auto Key = I->first; +    if (Key.first == BB) +      Probs.erase(Key); +  } +} + +void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI, +                                      const TargetLibraryInfo *TLI) { +  LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName() +                    << " ----\n\n"); +  LastF = &F; // Store the last function we ran on for printing. +  assert(PostDominatedByUnreachable.empty()); +  assert(PostDominatedByColdCall.empty()); + +  // Record SCC numbers of blocks in the CFG to identify irreducible loops. +  // FIXME: We could only calculate this if the CFG is known to be irreducible +  // (perhaps cache this info in LoopInfo if we can easily calculate it there?). 
+  int SccNum = 0; +  SccInfo SccI; +  for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd(); +       ++It, ++SccNum) { +    // Ignore single-block SCCs since they either aren't loops or LoopInfo will +    // catch them. +    const std::vector<const BasicBlock *> &Scc = *It; +    if (Scc.size() == 1) +      continue; + +    LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":"); +    for (auto *BB : Scc) { +      LLVM_DEBUG(dbgs() << " " << BB->getName()); +      SccI.SccNums[BB] = SccNum; +    } +    LLVM_DEBUG(dbgs() << "\n"); +  } + +  // Walk the basic blocks in post-order so that we can build up state about +  // the successors of a block iteratively. +  for (auto BB : post_order(&F.getEntryBlock())) { +    LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName() +                      << "\n"); +    updatePostDominatedByUnreachable(BB); +    updatePostDominatedByColdCall(BB); +    // If there is no at least two successors, no sense to set probability. +    if (BB->getTerminator()->getNumSuccessors() < 2) +      continue; +    if (calcMetadataWeights(BB)) +      continue; +    if (calcInvokeHeuristics(BB)) +      continue; +    if (calcUnreachableHeuristics(BB)) +      continue; +    if (calcColdCallHeuristics(BB)) +      continue; +    if (calcLoopBranchHeuristics(BB, LI, SccI)) +      continue; +    if (calcPointerHeuristics(BB)) +      continue; +    if (calcZeroHeuristics(BB, TLI)) +      continue; +    if (calcFloatingPointHeuristics(BB)) +      continue; +  } + +  PostDominatedByUnreachable.clear(); +  PostDominatedByColdCall.clear(); + +  if (PrintBranchProb && +      (PrintBranchProbFuncName.empty() || +       F.getName().equals(PrintBranchProbFuncName))) { +    print(dbgs()); +  } +} + +void BranchProbabilityInfoWrapperPass::getAnalysisUsage( +    AnalysisUsage &AU) const { +  // We require DT so it's available when LI is available. The LI updating code +  // asserts that DT is also present so if we don't make sure that we have DT +  // here, that assert will trigger. 
+  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +  AU.setPreservesAll(); +} + +bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) { +  const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +  BPI.calculate(F, LI, &TLI); +  return false; +} + +void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); } + +void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS, +                                             const Module *) const { +  BPI.print(OS); +} + +AnalysisKey BranchProbabilityAnalysis::Key; +BranchProbabilityInfo +BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) { +  BranchProbabilityInfo BPI; +  BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F)); +  return BPI; +} + +PreservedAnalyses +BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { +  OS << "Printing analysis results of BPI for function " +     << "'" << F.getName() << "':" +     << "\n"; +  AM.getResult<BranchProbabilityAnalysis>(F).print(OS); +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/CFG.cpp b/contrib/llvm/lib/Analysis/CFG.cpp new file mode 100644 index 000000000000..a319be8092f9 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CFG.cpp @@ -0,0 +1,236 @@ +//===-- CFG.cpp - BasicBlock analysis --------------------------------------==// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This family of functions performs analyses on basic blocks, and instructions +// contained within basic blocks. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CFG.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +/// FindFunctionBackedges - Analyze the specified function to find all of the +/// loop backedges in the function and return them.  This is a relatively cheap +/// (compared to computing dominators and loop info) analysis. +/// +/// The output is added to Result, as pairs of <from,to> edge info. +void llvm::FindFunctionBackedges(const Function &F, +     SmallVectorImpl<std::pair<const BasicBlock*,const BasicBlock*> > &Result) { +  const BasicBlock *BB = &F.getEntryBlock(); +  if (succ_empty(BB)) +    return; + +  SmallPtrSet<const BasicBlock*, 8> Visited; +  SmallVector<std::pair<const BasicBlock*, succ_const_iterator>, 8> VisitStack; +  SmallPtrSet<const BasicBlock*, 8> InStack; + +  Visited.insert(BB); +  VisitStack.push_back(std::make_pair(BB, succ_begin(BB))); +  InStack.insert(BB); +  do { +    std::pair<const BasicBlock*, succ_const_iterator> &Top = VisitStack.back(); +    const BasicBlock *ParentBB = Top.first; +    succ_const_iterator &I = Top.second; + +    bool FoundNew = false; +    while (I != succ_end(ParentBB)) { +      BB = *I++; +      if (Visited.insert(BB).second) { +        FoundNew = true; +        break; +      } +      // Successor is in VisitStack, it's a back edge. 
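
A standalone sketch of the same back-edge detection idea on a hypothetical adjacency-list graph: an iterative DFS that keeps an explicit in-stack set and records an edge whenever its target is still on the DFS stack:

    #include <iostream>
    #include <set>
    #include <utility>
    #include <vector>

    int main() {
      // Hypothetical graph: 0 -> 1 -> 2, 2 -> 0 (back edge) and 2 -> 3.
      std::vector<std::vector<int>> Succs = {{1}, {2}, {0, 3}, {}};
      std::vector<std::pair<int, int>> BackEdges;

      std::set<int> Visited = {0}, InStack = {0};
      std::vector<std::pair<int, unsigned>> Stack = {{0, 0}}; // (node, next succ)
      while (!Stack.empty()) {
        int Node = Stack.back().first;
        unsigned &NextIdx = Stack.back().second;
        if (NextIdx < Succs[Node].size()) {
          int S = Succs[Node][NextIdx++];
          if (InStack.count(S)) {
            BackEdges.push_back({Node, S}); // target is still on the stack
          } else if (Visited.insert(S).second) {
            InStack.insert(S);
            Stack.push_back({S, 0});
          }
        } else {
          InStack.erase(Node); // done with this node, pop back up one level
          Stack.pop_back();
        }
      }

      for (const auto &E : BackEdges)
        std::cout << E.first << " -> " << E.second << "\n"; // prints "2 -> 0"
      return 0;
    }
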
+      if (InStack.count(BB)) +        Result.push_back(std::make_pair(ParentBB, BB)); +    } + +    if (FoundNew) { +      // Go down one level if there is a unvisited successor. +      InStack.insert(BB); +      VisitStack.push_back(std::make_pair(BB, succ_begin(BB))); +    } else { +      // Go up one level. +      InStack.erase(VisitStack.pop_back_val().first); +    } +  } while (!VisitStack.empty()); +} + +/// GetSuccessorNumber - Search for the specified successor of basic block BB +/// and return its position in the terminator instruction's list of +/// successors.  It is an error to call this with a block that is not a +/// successor. +unsigned llvm::GetSuccessorNumber(const BasicBlock *BB, +    const BasicBlock *Succ) { +  const TerminatorInst *Term = BB->getTerminator(); +#ifndef NDEBUG +  unsigned e = Term->getNumSuccessors(); +#endif +  for (unsigned i = 0; ; ++i) { +    assert(i != e && "Didn't find edge?"); +    if (Term->getSuccessor(i) == Succ) +      return i; +  } +} + +/// isCriticalEdge - Return true if the specified edge is a critical edge. +/// Critical edges are edges from a block with multiple successors to a block +/// with multiple predecessors. +bool llvm::isCriticalEdge(const TerminatorInst *TI, unsigned SuccNum, +                          bool AllowIdenticalEdges) { +  assert(SuccNum < TI->getNumSuccessors() && "Illegal edge specification!"); +  if (TI->getNumSuccessors() == 1) return false; + +  const BasicBlock *Dest = TI->getSuccessor(SuccNum); +  const_pred_iterator I = pred_begin(Dest), E = pred_end(Dest); + +  // If there is more than one predecessor, this is a critical edge... +  assert(I != E && "No preds, but we have an edge to the block?"); +  const BasicBlock *FirstPred = *I; +  ++I;        // Skip one edge due to the incoming arc from TI. +  if (!AllowIdenticalEdges) +    return I != E; + +  // If AllowIdenticalEdges is true, then we allow this edge to be considered +  // non-critical iff all preds come from TI's block. +  for (; I != E; ++I) +    if (*I != FirstPred) +      return true; +  return false; +} + +// LoopInfo contains a mapping from basic block to the innermost loop. Find +// the outermost loop in the loop nest that contains BB. +static const Loop *getOutermostLoop(const LoopInfo *LI, const BasicBlock *BB) { +  const Loop *L = LI->getLoopFor(BB); +  if (L) { +    while (const Loop *Parent = L->getParentLoop()) +      L = Parent; +  } +  return L; +} + +// True if there is a loop which contains both BB1 and BB2. +static bool loopContainsBoth(const LoopInfo *LI, +                             const BasicBlock *BB1, const BasicBlock *BB2) { +  const Loop *L1 = getOutermostLoop(LI, BB1); +  const Loop *L2 = getOutermostLoop(LI, BB2); +  return L1 != nullptr && L1 == L2; +} + +bool llvm::isPotentiallyReachableFromMany( +    SmallVectorImpl<BasicBlock *> &Worklist, BasicBlock *StopBB, +    const DominatorTree *DT, const LoopInfo *LI) { +  // When the stop block is unreachable, it's dominated from everywhere, +  // regardless of whether there's a path between the two blocks. +  if (DT && !DT->isReachableFromEntry(StopBB)) +    DT = nullptr; + +  // Limit the number of blocks we visit. The goal is to avoid run-away compile +  // times on large CFGs without hampering sensible code. Arbitrarily chosen. 
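
A standalone sketch of such a budgeted walk on a hypothetical adjacency-list graph: once the visit budget runs out, the search gives up and conservatively reports that a path may exist:

    #include <iostream>
    #include <set>
    #include <vector>

    // Explore at most Limit nodes; answer "true" (potentially reachable) when
    // the budget is exhausted, "false" only when every path has been ruled out.
    static bool potentiallyReachable(const std::vector<std::vector<int>> &Succs,
                                     int From, int To, unsigned Limit) {
      std::vector<int> Worklist = {From};
      std::set<int> Visited;
      while (!Worklist.empty()) {
        int N = Worklist.back();
        Worklist.pop_back();
        if (!Visited.insert(N).second)
          continue;
        if (N == To)
          return true;
        if (!--Limit)
          return true; // budget exhausted: give the conservative answer
        for (int S : Succs[N])
          Worklist.push_back(S);
      }
      return false; // every path explored, definitely unreachable
    }

    int main() {
      // Hypothetical 4-node graph: 0 -> 1 -> 2, node 3 is disconnected.
      std::vector<std::vector<int>> Succs = {{1}, {2}, {}, {}};
      std::cout << potentiallyReachable(Succs, 0, 2, 32) << " "
                << potentiallyReachable(Succs, 0, 3, 32) << "\n"; // 1 0
      return 0;
    }
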
+  unsigned Limit = 32; +  SmallPtrSet<const BasicBlock*, 32> Visited; +  do { +    BasicBlock *BB = Worklist.pop_back_val(); +    if (!Visited.insert(BB).second) +      continue; +    if (BB == StopBB) +      return true; +    if (DT && DT->dominates(BB, StopBB)) +      return true; +    if (LI && loopContainsBoth(LI, BB, StopBB)) +      return true; + +    if (!--Limit) { +      // We haven't been able to prove it one way or the other. Conservatively +      // answer true -- that there is potentially a path. +      return true; +    } + +    if (const Loop *Outer = LI ? getOutermostLoop(LI, BB) : nullptr) { +      // All blocks in a single loop are reachable from all other blocks. From +      // any of these blocks, we can skip directly to the exits of the loop, +      // ignoring any other blocks inside the loop body. +      Outer->getExitBlocks(Worklist); +    } else { +      Worklist.append(succ_begin(BB), succ_end(BB)); +    } +  } while (!Worklist.empty()); + +  // We have exhausted all possible paths and are certain that 'To' can not be +  // reached from 'From'. +  return false; +} + +bool llvm::isPotentiallyReachable(const BasicBlock *A, const BasicBlock *B, +                                  const DominatorTree *DT, const LoopInfo *LI) { +  assert(A->getParent() == B->getParent() && +         "This analysis is function-local!"); + +  SmallVector<BasicBlock*, 32> Worklist; +  Worklist.push_back(const_cast<BasicBlock*>(A)); + +  return isPotentiallyReachableFromMany(Worklist, const_cast<BasicBlock *>(B), +                                        DT, LI); +} + +bool llvm::isPotentiallyReachable(const Instruction *A, const Instruction *B, +                                  const DominatorTree *DT, const LoopInfo *LI) { +  assert(A->getParent()->getParent() == B->getParent()->getParent() && +         "This analysis is function-local!"); + +  SmallVector<BasicBlock*, 32> Worklist; + +  if (A->getParent() == B->getParent()) { +    // The same block case is special because it's the only time we're looking +    // within a single block to see which instruction comes first. Once we +    // start looking at multiple blocks, the first instruction of the block is +    // reachable, so we only need to determine reachability between whole +    // blocks. +    BasicBlock *BB = const_cast<BasicBlock *>(A->getParent()); + +    // If the block is in a loop then we can reach any instruction in the block +    // from any other instruction in the block by going around a backedge. +    if (LI && LI->getLoopFor(BB) != nullptr) +      return true; + +    // Linear scan, start at 'A', see whether we hit 'B' or the end first. +    for (BasicBlock::const_iterator I = A->getIterator(), E = BB->end(); I != E; +         ++I) { +      if (&*I == B) +        return true; +    } + +    // Can't be in a loop if it's the entry block -- the entry block may not +    // have predecessors. +    if (BB == &BB->getParent()->getEntryBlock()) +      return false; + +    // Otherwise, continue doing the normal per-BB CFG walk. +    Worklist.append(succ_begin(BB), succ_end(BB)); + +    if (Worklist.empty()) { +      // We've proven that there's no path! 
+      return false; +    } +  } else { +    Worklist.push_back(const_cast<BasicBlock*>(A->getParent())); +  } + +  if (A->getParent() == &A->getParent()->getParent()->getEntryBlock()) +    return true; +  if (B->getParent() == &A->getParent()->getParent()->getEntryBlock()) +    return false; + +  return isPotentiallyReachableFromMany( +      Worklist, const_cast<BasicBlock *>(B->getParent()), DT, LI); +} diff --git a/contrib/llvm/lib/Analysis/CFGPrinter.cpp b/contrib/llvm/lib/Analysis/CFGPrinter.cpp new file mode 100644 index 000000000000..5b170dfa7903 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CFGPrinter.cpp @@ -0,0 +1,195 @@ +//===- CFGPrinter.cpp - DOT printer for the control flow graph ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines a '-dot-cfg' analysis pass, which emits the +// cfg.<fnname>.dot file for each function in the program, with a graph of the +// CFG for that function. +// +// The other main feature of this file is that it implements the +// Function::viewCFG method, which is useful for debugging passes which operate +// on the CFG. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CFGPrinter.h" +#include "llvm/Pass.h" +#include "llvm/Support/FileSystem.h" +using namespace llvm; + +static cl::opt<std::string> CFGFuncName( +    "cfg-func-name", cl::Hidden, +    cl::desc("The name of a function (or its substring)" +             " whose CFG is viewed/printed.")); + +namespace { +  struct CFGViewerLegacyPass : public FunctionPass { +    static char ID; // Pass identifcation, replacement for typeid +    CFGViewerLegacyPass() : FunctionPass(ID) { +      initializeCFGViewerLegacyPassPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override { +      F.viewCFG(); +      return false; +    } + +    void print(raw_ostream &OS, const Module* = nullptr) const override {} + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +  }; +} + +char CFGViewerLegacyPass::ID = 0; +INITIALIZE_PASS(CFGViewerLegacyPass, "view-cfg", "View CFG of function", false, true) + +PreservedAnalyses CFGViewerPass::run(Function &F, +                                     FunctionAnalysisManager &AM) { +  F.viewCFG(); +  return PreservedAnalyses::all(); +} + + +namespace { +  struct CFGOnlyViewerLegacyPass : public FunctionPass { +    static char ID; // Pass identifcation, replacement for typeid +    CFGOnlyViewerLegacyPass() : FunctionPass(ID) { +      initializeCFGOnlyViewerLegacyPassPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override { +      F.viewCFGOnly(); +      return false; +    } + +    void print(raw_ostream &OS, const Module* = nullptr) const override {} + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +  }; +} + +char CFGOnlyViewerLegacyPass::ID = 0; +INITIALIZE_PASS(CFGOnlyViewerLegacyPass, "view-cfg-only", +                "View CFG of function (with no function bodies)", false, true) + +PreservedAnalyses CFGOnlyViewerPass::run(Function &F, +                                         FunctionAnalysisManager &AM) { +  F.viewCFGOnly(); +  return PreservedAnalyses::all(); +} + +static void 
writeCFGToDotFile(Function &F, bool CFGOnly = false) { +  if (!CFGFuncName.empty() && !F.getName().contains(CFGFuncName)) +     return; +  std::string Filename = ("cfg." + F.getName() + ".dot").str(); +  errs() << "Writing '" << Filename << "'..."; + +  std::error_code EC; +  raw_fd_ostream File(Filename, EC, sys::fs::F_Text); + +  if (!EC) +    WriteGraph(File, (const Function*)&F, CFGOnly); +  else +    errs() << "  error opening file for writing!"; +  errs() << "\n"; +} + +namespace { +  struct CFGPrinterLegacyPass : public FunctionPass { +    static char ID; // Pass identification, replacement for typeid +    CFGPrinterLegacyPass() : FunctionPass(ID) { +      initializeCFGPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override { +      writeCFGToDotFile(F); +      return false; +    } + +    void print(raw_ostream &OS, const Module* = nullptr) const override {} + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +  }; +} + +char CFGPrinterLegacyPass::ID = 0; +INITIALIZE_PASS(CFGPrinterLegacyPass, "dot-cfg", "Print CFG of function to 'dot' file", +                false, true) + +PreservedAnalyses CFGPrinterPass::run(Function &F, +                                      FunctionAnalysisManager &AM) { +  writeCFGToDotFile(F); +  return PreservedAnalyses::all(); +} + +namespace { +  struct CFGOnlyPrinterLegacyPass : public FunctionPass { +    static char ID; // Pass identification, replacement for typeid +    CFGOnlyPrinterLegacyPass() : FunctionPass(ID) { +      initializeCFGOnlyPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override { +      writeCFGToDotFile(F, /*CFGOnly=*/true); +      return false; +    } +    void print(raw_ostream &OS, const Module* = nullptr) const override {} + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +  }; +} + +char CFGOnlyPrinterLegacyPass::ID = 0; +INITIALIZE_PASS(CFGOnlyPrinterLegacyPass, "dot-cfg-only", +   "Print CFG of function to 'dot' file (with no function bodies)", +   false, true) + +PreservedAnalyses CFGOnlyPrinterPass::run(Function &F, +                                          FunctionAnalysisManager &AM) { +  writeCFGToDotFile(F, /*CFGOnly=*/true); +  return PreservedAnalyses::all(); +} + +/// viewCFG - This function is meant for use from the debugger.  You can just +/// say 'call F->viewCFG()' and a ghostview window should pop up from the +/// program, displaying the CFG of the current function.  This depends on there +/// being a 'dot' and 'gv' program in your path. +/// +void Function::viewCFG() const { +  if (!CFGFuncName.empty() && !getName().contains(CFGFuncName)) +     return; +  ViewGraph(this, "cfg" + getName()); +} + +/// viewCFGOnly - This function is meant for use from the debugger.  It works +/// just like viewCFG, but it does not include the contents of basic blocks +/// into the nodes, just the label.  If you are only interested in the CFG +/// this can make the graph smaller. 
+/// +void Function::viewCFGOnly() const { +  if (!CFGFuncName.empty() && !getName().contains(CFGFuncName)) +     return; +  ViewGraph(this, "cfg" + getName(), true); +} + +FunctionPass *llvm::createCFGPrinterLegacyPassPass () { +  return new CFGPrinterLegacyPass(); +} + +FunctionPass *llvm::createCFGOnlyPrinterLegacyPassPass () { +  return new CFGOnlyPrinterLegacyPass(); +} + diff --git a/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp new file mode 100644 index 000000000000..194983418b08 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CFLAndersAliasAnalysis.cpp @@ -0,0 +1,922 @@ +//===- CFLAndersAliasAnalysis.cpp - Unification-based Alias Analysis ------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a CFL-based, summary-based alias analysis algorithm. It +// differs from CFLSteensAliasAnalysis in its inclusion-based nature while +// CFLSteensAliasAnalysis is unification-based. This pass has worse performance +// than CFLSteensAliasAnalysis (the worst case complexity of +// CFLAndersAliasAnalysis is cubic, while the worst case complexity of +// CFLSteensAliasAnalysis is almost linear), but it is able to yield more +// precise analysis result. The precision of this analysis is roughly the same +// as that of an one level context-sensitive Andersen's algorithm. +// +// The algorithm used here is based on recursive state machine matching scheme +// proposed in "Demand-driven alias analysis for C" by Xin Zheng and Radu +// Rugina. The general idea is to extend the traditional transitive closure +// algorithm to perform CFL matching along the way: instead of recording +// "whether X is reachable from Y", we keep track of "whether X is reachable +// from Y at state Z", where the "state" field indicates where we are in the CFL +// matching process. To understand the matching better, it is advisable to have +// the state machine shown in Figure 3 of the paper available when reading the +// codes: all we do here is to selectively expand the transitive closure by +// discarding edges that are not recognized by the state machine. +// +// There are two differences between our current implementation and the one +// described in the paper: +// - Our algorithm eagerly computes all alias pairs after the CFLGraph is built, +// while in the paper the authors did the computation in a demand-driven +// fashion. We did not implement the demand-driven algorithm due to the +// additional coding complexity and higher memory profile, but if we found it +// necessary we may switch to it eventually. +// - In the paper the authors use a state machine that does not distinguish +// value reads from value writes. For example, if Y is reachable from X at state +// S3, it may be the case that X is written into Y, or it may be the case that +// there's a third value Z that writes into both X and Y. To make that +// distinction (which is crucial in building function summary as well as +// retrieving mod-ref info), we choose to duplicate some of the states in the +// paper's proposed state machine. The duplication does not change the set the +// machine accepts. Given a pair of reachable values, it only provides more +// detailed information on which value is being written into and which is being +// read from. 
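
A standalone sketch of the "reachable at state" bookkeeping described above, with hypothetical node ids and a four-state automaton (the real implementation splits some of these states and uses seven bits): each (To, From) pair carries one bit per state, and propagation only continues when a previously unset bit is set:

    #include <bitset>
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <utility>

    // One bit per automaton state (S1..S4 here, purely illustrative).
    enum class State : uint8_t { S1, S2, S3, S4 };
    using StateBits = std::bitset<4>;

    int main() {
      // "X is reachable from Y at state Z": each (To, From) pair of hypothetical
      // node ids carries a bitset of states at which the pair has been reached.
      std::map<std::pair<int, int>, StateBits> Reach;

      auto insertEdge = [&Reach](int From, int To, State S) {
        StateBits &Bits = Reach[{To, From}];
        unsigned Idx = static_cast<unsigned>(S);
        if (Bits.test(Idx))
          return false; // already reached at this state, nothing new to propagate
        Bits.set(Idx);
        return true;
      };

      bool A = insertEdge(1, 2, State::S1); // new fact
      bool B = insertEdge(1, 2, State::S1); // duplicate, would not be re-queued
      bool C = insertEdge(1, 2, State::S3); // same pair, new state
      std::cout << A << B << C << "\n";     // prints 101
      return 0;
    }
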
+// +//===----------------------------------------------------------------------===// + +// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and +// CFLAndersAA is interprocedural. This is *technically* A Bad Thing, because +// FunctionPasses are only allowed to inspect the Function that they're being +// run on. Realistically, this likely isn't a problem until we allow +// FunctionPasses to run concurrently. + +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "AliasAnalysisSummary.h" +#include "CFLGraph.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <bitset> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <functional> +#include <utility> +#include <vector> + +using namespace llvm; +using namespace llvm::cflaa; + +#define DEBUG_TYPE "cfl-anders-aa" + +CFLAndersAAResult::CFLAndersAAResult(const TargetLibraryInfo &TLI) : TLI(TLI) {} +CFLAndersAAResult::CFLAndersAAResult(CFLAndersAAResult &&RHS) +    : AAResultBase(std::move(RHS)), TLI(RHS.TLI) {} +CFLAndersAAResult::~CFLAndersAAResult() = default; + +namespace { + +enum class MatchState : uint8_t { +  // The following state represents S1 in the paper. +  FlowFromReadOnly = 0, +  // The following two states together represent S2 in the paper. +  // The 'NoReadWrite' suffix indicates that there exists an alias path that +  // does not contain assignment and reverse assignment edges. +  // The 'ReadOnly' suffix indicates that there exists an alias path that +  // contains reverse assignment edges only. +  FlowFromMemAliasNoReadWrite, +  FlowFromMemAliasReadOnly, +  // The following two states together represent S3 in the paper. +  // The 'WriteOnly' suffix indicates that there exists an alias path that +  // contains assignment edges only. +  // The 'ReadWrite' suffix indicates that there exists an alias path that +  // contains both assignment and reverse assignment edges. Note that if X and Y +  // are reachable at 'ReadWrite' state, it does NOT mean X is both read from +  // and written to Y. Instead, it means that a third value Z is written to both +  // X and Y. +  FlowToWriteOnly, +  FlowToReadWrite, +  // The following two states together represent S4 in the paper. 
+  FlowToMemAliasWriteOnly, +  FlowToMemAliasReadWrite, +}; + +using StateSet = std::bitset<7>; + +const unsigned ReadOnlyStateMask = +    (1U << static_cast<uint8_t>(MatchState::FlowFromReadOnly)) | +    (1U << static_cast<uint8_t>(MatchState::FlowFromMemAliasReadOnly)); +const unsigned WriteOnlyStateMask = +    (1U << static_cast<uint8_t>(MatchState::FlowToWriteOnly)) | +    (1U << static_cast<uint8_t>(MatchState::FlowToMemAliasWriteOnly)); + +// A pair that consists of a value and an offset +struct OffsetValue { +  const Value *Val; +  int64_t Offset; +}; + +bool operator==(OffsetValue LHS, OffsetValue RHS) { +  return LHS.Val == RHS.Val && LHS.Offset == RHS.Offset; +} +bool operator<(OffsetValue LHS, OffsetValue RHS) { +  return std::less<const Value *>()(LHS.Val, RHS.Val) || +         (LHS.Val == RHS.Val && LHS.Offset < RHS.Offset); +} + +// A pair that consists of an InstantiatedValue and an offset +struct OffsetInstantiatedValue { +  InstantiatedValue IVal; +  int64_t Offset; +}; + +bool operator==(OffsetInstantiatedValue LHS, OffsetInstantiatedValue RHS) { +  return LHS.IVal == RHS.IVal && LHS.Offset == RHS.Offset; +} + +// We use ReachabilitySet to keep track of value aliases (The nonterminal "V" in +// the paper) during the analysis. +class ReachabilitySet { +  using ValueStateMap = DenseMap<InstantiatedValue, StateSet>; +  using ValueReachMap = DenseMap<InstantiatedValue, ValueStateMap>; + +  ValueReachMap ReachMap; + +public: +  using const_valuestate_iterator = ValueStateMap::const_iterator; +  using const_value_iterator = ValueReachMap::const_iterator; + +  // Insert edge 'From->To' at state 'State' +  bool insert(InstantiatedValue From, InstantiatedValue To, MatchState State) { +    assert(From != To); +    auto &States = ReachMap[To][From]; +    auto Idx = static_cast<size_t>(State); +    if (!States.test(Idx)) { +      States.set(Idx); +      return true; +    } +    return false; +  } + +  // Return the set of all ('From', 'State') pair for a given node 'To' +  iterator_range<const_valuestate_iterator> +  reachableValueAliases(InstantiatedValue V) const { +    auto Itr = ReachMap.find(V); +    if (Itr == ReachMap.end()) +      return make_range<const_valuestate_iterator>(const_valuestate_iterator(), +                                                   const_valuestate_iterator()); +    return make_range<const_valuestate_iterator>(Itr->second.begin(), +                                                 Itr->second.end()); +  } + +  iterator_range<const_value_iterator> value_mappings() const { +    return make_range<const_value_iterator>(ReachMap.begin(), ReachMap.end()); +  } +}; + +// We use AliasMemSet to keep track of all memory aliases (the nonterminal "M" +// in the paper) during the analysis. +class AliasMemSet { +  using MemSet = DenseSet<InstantiatedValue>; +  using MemMapType = DenseMap<InstantiatedValue, MemSet>; + +  MemMapType MemMap; + +public: +  using const_mem_iterator = MemSet::const_iterator; + +  bool insert(InstantiatedValue LHS, InstantiatedValue RHS) { +    // Top-level values can never be memory aliases because one cannot take the +    // addresses of them +    assert(LHS.DerefLevel > 0 && RHS.DerefLevel > 0); +    return MemMap[LHS].insert(RHS).second; +  } + +  const MemSet *getMemoryAliases(InstantiatedValue V) const { +    auto Itr = MemMap.find(V); +    if (Itr == MemMap.end()) +      return nullptr; +    return &Itr->second; +  } +}; + +// We use AliasAttrMap to keep track of the AliasAttr of each node. 
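The operator< defined on OffsetValue above (ordering by pointer identity first, then by offset) is what later lets FunctionInfo::mayAlias() keep each alias list sorted and pull out every entry for a given value with std::equal_range. A small standalone sketch of that lookup pattern, using a stand-in struct instead of the real types:

    #include <algorithm>
    #include <cstdint>
    #include <functional>
    #include <vector>

    // Stand-in for the (Value*, Offset) pairs kept in each alias list.
    struct OffsetValueLike {
      const int *Val;
      std::int64_t Offset;
    };

    int main() {
      int A = 0, B = 0;
      std::vector<OffsetValueLike> AliasList = {{&B, 8}, {&A, 0}, {&B, -4}};

      // Sort by (Val, Offset), mirroring operator< on OffsetValue.
      std::sort(AliasList.begin(), AliasList.end(),
                [](const OffsetValueLike &L, const OffsetValueLike &R) {
                  return std::less<const int *>()(L.Val, R.Val) ||
                         (L.Val == R.Val && L.Offset < R.Offset);
                });

      // Find every entry whose Val is &B, ignoring offsets, the way
      // mayAlias() does with its Val-only comparator.
      auto ByValOnly = [](const OffsetValueLike &L, const OffsetValueLike &R) {
        return std::less<const int *>()(L.Val, R.Val);
      };
      auto Range = std::equal_range(AliasList.begin(), AliasList.end(),
                                    OffsetValueLike{&B, 0}, ByValOnly);
      return (Range.second - Range.first) == 2 ? 0 : 1; // expect both &B entries
    }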
+class AliasAttrMap { +  using MapType = DenseMap<InstantiatedValue, AliasAttrs>; + +  MapType AttrMap; + +public: +  using const_iterator = MapType::const_iterator; + +  bool add(InstantiatedValue V, AliasAttrs Attr) { +    auto &OldAttr = AttrMap[V]; +    auto NewAttr = OldAttr | Attr; +    if (OldAttr == NewAttr) +      return false; +    OldAttr = NewAttr; +    return true; +  } + +  AliasAttrs getAttrs(InstantiatedValue V) const { +    AliasAttrs Attr; +    auto Itr = AttrMap.find(V); +    if (Itr != AttrMap.end()) +      Attr = Itr->second; +    return Attr; +  } + +  iterator_range<const_iterator> mappings() const { +    return make_range<const_iterator>(AttrMap.begin(), AttrMap.end()); +  } +}; + +struct WorkListItem { +  InstantiatedValue From; +  InstantiatedValue To; +  MatchState State; +}; + +struct ValueSummary { +  struct Record { +    InterfaceValue IValue; +    unsigned DerefLevel; +  }; +  SmallVector<Record, 4> FromRecords, ToRecords; +}; + +} // end anonymous namespace + +namespace llvm { + +// Specialize DenseMapInfo for OffsetValue. +template <> struct DenseMapInfo<OffsetValue> { +  static OffsetValue getEmptyKey() { +    return OffsetValue{DenseMapInfo<const Value *>::getEmptyKey(), +                       DenseMapInfo<int64_t>::getEmptyKey()}; +  } + +  static OffsetValue getTombstoneKey() { +    return OffsetValue{DenseMapInfo<const Value *>::getTombstoneKey(), +                       DenseMapInfo<int64_t>::getEmptyKey()}; +  } + +  static unsigned getHashValue(const OffsetValue &OVal) { +    return DenseMapInfo<std::pair<const Value *, int64_t>>::getHashValue( +        std::make_pair(OVal.Val, OVal.Offset)); +  } + +  static bool isEqual(const OffsetValue &LHS, const OffsetValue &RHS) { +    return LHS == RHS; +  } +}; + +// Specialize DenseMapInfo for OffsetInstantiatedValue. +template <> struct DenseMapInfo<OffsetInstantiatedValue> { +  static OffsetInstantiatedValue getEmptyKey() { +    return OffsetInstantiatedValue{ +        DenseMapInfo<InstantiatedValue>::getEmptyKey(), +        DenseMapInfo<int64_t>::getEmptyKey()}; +  } + +  static OffsetInstantiatedValue getTombstoneKey() { +    return OffsetInstantiatedValue{ +        DenseMapInfo<InstantiatedValue>::getTombstoneKey(), +        DenseMapInfo<int64_t>::getEmptyKey()}; +  } + +  static unsigned getHashValue(const OffsetInstantiatedValue &OVal) { +    return DenseMapInfo<std::pair<InstantiatedValue, int64_t>>::getHashValue( +        std::make_pair(OVal.IVal, OVal.Offset)); +  } + +  static bool isEqual(const OffsetInstantiatedValue &LHS, +                      const OffsetInstantiatedValue &RHS) { +    return LHS == RHS; +  } +}; + +} // end namespace llvm + +class CFLAndersAAResult::FunctionInfo { +  /// Map a value to other values that may alias it +  /// Since the alias relation is symmetric, to save some space we assume values +  /// are properly ordered: if a and b alias each other, and a < b, then b is in +  /// AliasMap[a] but not vice versa. +  DenseMap<const Value *, std::vector<OffsetValue>> AliasMap; + +  /// Map a value to its corresponding AliasAttrs +  DenseMap<const Value *, AliasAttrs> AttrMap; + +  /// Summary of externally visible effects. 
+  AliasSummary Summary; + +  Optional<AliasAttrs> getAttrs(const Value *) const; + +public: +  FunctionInfo(const Function &, const SmallVectorImpl<Value *> &, +               const ReachabilitySet &, const AliasAttrMap &); + +  bool mayAlias(const Value *, LocationSize, const Value *, LocationSize) const; +  const AliasSummary &getAliasSummary() const { return Summary; } +}; + +static bool hasReadOnlyState(StateSet Set) { +  return (Set & StateSet(ReadOnlyStateMask)).any(); +} + +static bool hasWriteOnlyState(StateSet Set) { +  return (Set & StateSet(WriteOnlyStateMask)).any(); +} + +static Optional<InterfaceValue> +getInterfaceValue(InstantiatedValue IValue, +                  const SmallVectorImpl<Value *> &RetVals) { +  auto Val = IValue.Val; + +  Optional<unsigned> Index; +  if (auto Arg = dyn_cast<Argument>(Val)) +    Index = Arg->getArgNo() + 1; +  else if (is_contained(RetVals, Val)) +    Index = 0; + +  if (Index) +    return InterfaceValue{*Index, IValue.DerefLevel}; +  return None; +} + +static void populateAttrMap(DenseMap<const Value *, AliasAttrs> &AttrMap, +                            const AliasAttrMap &AMap) { +  for (const auto &Mapping : AMap.mappings()) { +    auto IVal = Mapping.first; + +    // Insert IVal into the map +    auto &Attr = AttrMap[IVal.Val]; +    // AttrMap only cares about top-level values +    if (IVal.DerefLevel == 0) +      Attr |= Mapping.second; +  } +} + +static void +populateAliasMap(DenseMap<const Value *, std::vector<OffsetValue>> &AliasMap, +                 const ReachabilitySet &ReachSet) { +  for (const auto &OuterMapping : ReachSet.value_mappings()) { +    // AliasMap only cares about top-level values +    if (OuterMapping.first.DerefLevel > 0) +      continue; + +    auto Val = OuterMapping.first.Val; +    auto &AliasList = AliasMap[Val]; +    for (const auto &InnerMapping : OuterMapping.second) { +      // Again, AliasMap only cares about top-level values +      if (InnerMapping.first.DerefLevel == 0) +        AliasList.push_back(OffsetValue{InnerMapping.first.Val, UnknownOffset}); +    } + +    // Sort AliasList for faster lookup +    llvm::sort(AliasList.begin(), AliasList.end()); +  } +} + +static void populateExternalRelations( +    SmallVectorImpl<ExternalRelation> &ExtRelations, const Function &Fn, +    const SmallVectorImpl<Value *> &RetVals, const ReachabilitySet &ReachSet) { +  // If a function only returns one of its argument X, then X will be both an +  // argument and a return value at the same time. This is an edge case that +  // needs special handling here. +  for (const auto &Arg : Fn.args()) { +    if (is_contained(RetVals, &Arg)) { +      auto ArgVal = InterfaceValue{Arg.getArgNo() + 1, 0}; +      auto RetVal = InterfaceValue{0, 0}; +      ExtRelations.push_back(ExternalRelation{ArgVal, RetVal, 0}); +    } +  } + +  // Below is the core summary construction logic. +  // A naive solution of adding only the value aliases that are parameters or +  // return values in ReachSet to the summary won't work: It is possible that a +  // parameter P is written into an intermediate value I, and the function +  // subsequently returns *I. In that case, *I is does not value alias anything +  // in ReachSet, and the naive solution will miss a summary edge from (P, 1) to +  // (I, 1). +  // To account for the aforementioned case, we need to check each non-parameter +  // and non-return value for the possibility of acting as an intermediate. 
+  // 'ValueMap' here records, for each value, which InterfaceValues read from or +  // write into it. If both the read list and the write list of a given value +  // are non-empty, we know that a particular value is an intermidate and we +  // need to add summary edges from the writes to the reads. +  DenseMap<Value *, ValueSummary> ValueMap; +  for (const auto &OuterMapping : ReachSet.value_mappings()) { +    if (auto Dst = getInterfaceValue(OuterMapping.first, RetVals)) { +      for (const auto &InnerMapping : OuterMapping.second) { +        // If Src is a param/return value, we get a same-level assignment. +        if (auto Src = getInterfaceValue(InnerMapping.first, RetVals)) { +          // This may happen if both Dst and Src are return values +          if (*Dst == *Src) +            continue; + +          if (hasReadOnlyState(InnerMapping.second)) +            ExtRelations.push_back(ExternalRelation{*Dst, *Src, UnknownOffset}); +          // No need to check for WriteOnly state, since ReachSet is symmetric +        } else { +          // If Src is not a param/return, add it to ValueMap +          auto SrcIVal = InnerMapping.first; +          if (hasReadOnlyState(InnerMapping.second)) +            ValueMap[SrcIVal.Val].FromRecords.push_back( +                ValueSummary::Record{*Dst, SrcIVal.DerefLevel}); +          if (hasWriteOnlyState(InnerMapping.second)) +            ValueMap[SrcIVal.Val].ToRecords.push_back( +                ValueSummary::Record{*Dst, SrcIVal.DerefLevel}); +        } +      } +    } +  } + +  for (const auto &Mapping : ValueMap) { +    for (const auto &FromRecord : Mapping.second.FromRecords) { +      for (const auto &ToRecord : Mapping.second.ToRecords) { +        auto ToLevel = ToRecord.DerefLevel; +        auto FromLevel = FromRecord.DerefLevel; +        // Same-level assignments should have already been processed by now +        if (ToLevel == FromLevel) +          continue; + +        auto SrcIndex = FromRecord.IValue.Index; +        auto SrcLevel = FromRecord.IValue.DerefLevel; +        auto DstIndex = ToRecord.IValue.Index; +        auto DstLevel = ToRecord.IValue.DerefLevel; +        if (ToLevel > FromLevel) +          SrcLevel += ToLevel - FromLevel; +        else +          DstLevel += FromLevel - ToLevel; + +        ExtRelations.push_back(ExternalRelation{ +            InterfaceValue{SrcIndex, SrcLevel}, +            InterfaceValue{DstIndex, DstLevel}, UnknownOffset}); +      } +    } +  } + +  // Remove duplicates in ExtRelations +  llvm::sort(ExtRelations.begin(), ExtRelations.end()); +  ExtRelations.erase(std::unique(ExtRelations.begin(), ExtRelations.end()), +                     ExtRelations.end()); +} + +static void populateExternalAttributes( +    SmallVectorImpl<ExternalAttribute> &ExtAttributes, const Function &Fn, +    const SmallVectorImpl<Value *> &RetVals, const AliasAttrMap &AMap) { +  for (const auto &Mapping : AMap.mappings()) { +    if (auto IVal = getInterfaceValue(Mapping.first, RetVals)) { +      auto Attr = getExternallyVisibleAttrs(Mapping.second); +      if (Attr.any()) +        ExtAttributes.push_back(ExternalAttribute{*IVal, Attr}); +    } +  } +} + +CFLAndersAAResult::FunctionInfo::FunctionInfo( +    const Function &Fn, const SmallVectorImpl<Value *> &RetVals, +    const ReachabilitySet &ReachSet, const AliasAttrMap &AMap) { +  populateAttrMap(AttrMap, AMap); +  populateExternalAttributes(Summary.RetParamAttributes, Fn, RetVals, AMap); +  populateAliasMap(AliasMap, ReachSet); +  
populateExternalRelations(Summary.RetParamRelations, Fn, RetVals, ReachSet); +} + +Optional<AliasAttrs> +CFLAndersAAResult::FunctionInfo::getAttrs(const Value *V) const { +  assert(V != nullptr); + +  auto Itr = AttrMap.find(V); +  if (Itr != AttrMap.end()) +    return Itr->second; +  return None; +} + +bool CFLAndersAAResult::FunctionInfo::mayAlias(const Value *LHS, +                                               LocationSize LHSSize, +                                               const Value *RHS, +                                               LocationSize RHSSize) const { +  assert(LHS && RHS); + +  // Check if we've seen LHS and RHS before. Sometimes LHS or RHS can be created +  // after the analysis gets executed, and we want to be conservative in those +  // cases. +  auto MaybeAttrsA = getAttrs(LHS); +  auto MaybeAttrsB = getAttrs(RHS); +  if (!MaybeAttrsA || !MaybeAttrsB) +    return true; + +  // Check AliasAttrs before AliasMap lookup since it's cheaper +  auto AttrsA = *MaybeAttrsA; +  auto AttrsB = *MaybeAttrsB; +  if (hasUnknownOrCallerAttr(AttrsA)) +    return AttrsB.any(); +  if (hasUnknownOrCallerAttr(AttrsB)) +    return AttrsA.any(); +  if (isGlobalOrArgAttr(AttrsA)) +    return isGlobalOrArgAttr(AttrsB); +  if (isGlobalOrArgAttr(AttrsB)) +    return isGlobalOrArgAttr(AttrsA); + +  // At this point both LHS and RHS should point to locally allocated objects + +  auto Itr = AliasMap.find(LHS); +  if (Itr != AliasMap.end()) { + +    // Find out all (X, Offset) where X == RHS +    auto Comparator = [](OffsetValue LHS, OffsetValue RHS) { +      return std::less<const Value *>()(LHS.Val, RHS.Val); +    }; +#ifdef EXPENSIVE_CHECKS +    assert(std::is_sorted(Itr->second.begin(), Itr->second.end(), Comparator)); +#endif +    auto RangePair = std::equal_range(Itr->second.begin(), Itr->second.end(), +                                      OffsetValue{RHS, 0}, Comparator); + +    if (RangePair.first != RangePair.second) { +      // Be conservative about UnknownSize +      if (LHSSize == MemoryLocation::UnknownSize || +          RHSSize == MemoryLocation::UnknownSize) +        return true; + +      for (const auto &OVal : make_range(RangePair)) { +        // Be conservative about UnknownOffset +        if (OVal.Offset == UnknownOffset) +          return true; + +        // We know that LHS aliases (RHS + OVal.Offset) if the control flow +        // reaches here. The may-alias query essentially becomes integer +        // range-overlap queries over two ranges [OVal.Offset, OVal.Offset + +        // LHSSize) and [0, RHSSize). + +        // Try to be conservative on super large offsets +        if (LLVM_UNLIKELY(LHSSize > INT64_MAX || RHSSize > INT64_MAX)) +          return true; + +        auto LHSStart = OVal.Offset; +        // FIXME: Do we need to guard against integer overflow? 
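+        // Worked example with made-up numbers: if LHS aliases (RHS + 8) and
+        // LHSSize == 4, RHSSize == 16, the two ranges are [8, 12) and
+        // [0, 16); they overlap, so we conservatively report a possible
+        // alias. With OVal.Offset == 16 instead, [16, 20) and [0, 16) are
+        // disjoint and we keep scanning the remaining entries.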
+        auto LHSEnd = OVal.Offset + static_cast<int64_t>(LHSSize); +        auto RHSStart = 0; +        auto RHSEnd = static_cast<int64_t>(RHSSize); +        if (LHSEnd > RHSStart && LHSStart < RHSEnd) +          return true; +      } +    } +  } + +  return false; +} + +static void propagate(InstantiatedValue From, InstantiatedValue To, +                      MatchState State, ReachabilitySet &ReachSet, +                      std::vector<WorkListItem> &WorkList) { +  if (From == To) +    return; +  if (ReachSet.insert(From, To, State)) +    WorkList.push_back(WorkListItem{From, To, State}); +} + +static void initializeWorkList(std::vector<WorkListItem> &WorkList, +                               ReachabilitySet &ReachSet, +                               const CFLGraph &Graph) { +  for (const auto &Mapping : Graph.value_mappings()) { +    auto Val = Mapping.first; +    auto &ValueInfo = Mapping.second; +    assert(ValueInfo.getNumLevels() > 0); + +    // Insert all immediate assignment neighbors to the worklist +    for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { +      auto Src = InstantiatedValue{Val, I}; +      // If there's an assignment edge from X to Y, it means Y is reachable from +      // X at S2 and X is reachable from Y at S1 +      for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) { +        propagate(Edge.Other, Src, MatchState::FlowFromReadOnly, ReachSet, +                  WorkList); +        propagate(Src, Edge.Other, MatchState::FlowToWriteOnly, ReachSet, +                  WorkList); +      } +    } +  } +} + +static Optional<InstantiatedValue> getNodeBelow(const CFLGraph &Graph, +                                                InstantiatedValue V) { +  auto NodeBelow = InstantiatedValue{V.Val, V.DerefLevel + 1}; +  if (Graph.getNode(NodeBelow)) +    return NodeBelow; +  return None; +} + +static void processWorkListItem(const WorkListItem &Item, const CFLGraph &Graph, +                                ReachabilitySet &ReachSet, AliasMemSet &MemSet, +                                std::vector<WorkListItem> &WorkList) { +  auto FromNode = Item.From; +  auto ToNode = Item.To; + +  auto NodeInfo = Graph.getNode(ToNode); +  assert(NodeInfo != nullptr); + +  // TODO: propagate field offsets + +  // FIXME: Here is a neat trick we can do: since both ReachSet and MemSet holds +  // relations that are symmetric, we could actually cut the storage by half by +  // sorting FromNode and ToNode before insertion happens. + +  // The newly added value alias pair may potentially generate more memory +  // alias pairs. Check for them here. 
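+  // Concretely: when a new value-alias pair (From, To) is found, the nodes one
+  // dereference level below them become memory aliases, i.e. *From and *To may
+  // refer to the same memory.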
+  auto FromNodeBelow = getNodeBelow(Graph, FromNode); +  auto ToNodeBelow = getNodeBelow(Graph, ToNode); +  if (FromNodeBelow && ToNodeBelow && +      MemSet.insert(*FromNodeBelow, *ToNodeBelow)) { +    propagate(*FromNodeBelow, *ToNodeBelow, +              MatchState::FlowFromMemAliasNoReadWrite, ReachSet, WorkList); +    for (const auto &Mapping : ReachSet.reachableValueAliases(*FromNodeBelow)) { +      auto Src = Mapping.first; +      auto MemAliasPropagate = [&](MatchState FromState, MatchState ToState) { +        if (Mapping.second.test(static_cast<size_t>(FromState))) +          propagate(Src, *ToNodeBelow, ToState, ReachSet, WorkList); +      }; + +      MemAliasPropagate(MatchState::FlowFromReadOnly, +                        MatchState::FlowFromMemAliasReadOnly); +      MemAliasPropagate(MatchState::FlowToWriteOnly, +                        MatchState::FlowToMemAliasWriteOnly); +      MemAliasPropagate(MatchState::FlowToReadWrite, +                        MatchState::FlowToMemAliasReadWrite); +    } +  } + +  // This is the core of the state machine walking algorithm. We expand ReachSet +  // based on which state we are at (which in turn dictates what edges we +  // should examine) +  // From a high-level point of view, the state machine here guarantees two +  // properties: +  // - If *X and *Y are memory aliases, then X and Y are value aliases +  // - If Y is an alias of X, then reverse assignment edges (if there is any) +  // should precede any assignment edges on the path from X to Y. +  auto NextAssignState = [&](MatchState State) { +    for (const auto &AssignEdge : NodeInfo->Edges) +      propagate(FromNode, AssignEdge.Other, State, ReachSet, WorkList); +  }; +  auto NextRevAssignState = [&](MatchState State) { +    for (const auto &RevAssignEdge : NodeInfo->ReverseEdges) +      propagate(FromNode, RevAssignEdge.Other, State, ReachSet, WorkList); +  }; +  auto NextMemState = [&](MatchState State) { +    if (auto AliasSet = MemSet.getMemoryAliases(ToNode)) { +      for (const auto &MemAlias : *AliasSet) +        propagate(FromNode, MemAlias, State, ReachSet, WorkList); +    } +  }; + +  switch (Item.State) { +  case MatchState::FlowFromReadOnly: +    NextRevAssignState(MatchState::FlowFromReadOnly); +    NextAssignState(MatchState::FlowToReadWrite); +    NextMemState(MatchState::FlowFromMemAliasReadOnly); +    break; + +  case MatchState::FlowFromMemAliasNoReadWrite: +    NextRevAssignState(MatchState::FlowFromReadOnly); +    NextAssignState(MatchState::FlowToWriteOnly); +    break; + +  case MatchState::FlowFromMemAliasReadOnly: +    NextRevAssignState(MatchState::FlowFromReadOnly); +    NextAssignState(MatchState::FlowToReadWrite); +    break; + +  case MatchState::FlowToWriteOnly: +    NextAssignState(MatchState::FlowToWriteOnly); +    NextMemState(MatchState::FlowToMemAliasWriteOnly); +    break; + +  case MatchState::FlowToReadWrite: +    NextAssignState(MatchState::FlowToReadWrite); +    NextMemState(MatchState::FlowToMemAliasReadWrite); +    break; + +  case MatchState::FlowToMemAliasWriteOnly: +    NextAssignState(MatchState::FlowToWriteOnly); +    break; + +  case MatchState::FlowToMemAliasReadWrite: +    NextAssignState(MatchState::FlowToReadWrite); +    break; +  } +} + +static AliasAttrMap buildAttrMap(const CFLGraph &Graph, +                                 const ReachabilitySet &ReachSet) { +  AliasAttrMap AttrMap; +  std::vector<InstantiatedValue> WorkList, NextList; + +  // Initialize each node with its original AliasAttrs in CFLGraph +  for (const auto 
&Mapping : Graph.value_mappings()) { +    auto Val = Mapping.first; +    auto &ValueInfo = Mapping.second; +    for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { +      auto Node = InstantiatedValue{Val, I}; +      AttrMap.add(Node, ValueInfo.getNodeInfoAtLevel(I).Attr); +      WorkList.push_back(Node); +    } +  } + +  while (!WorkList.empty()) { +    for (const auto &Dst : WorkList) { +      auto DstAttr = AttrMap.getAttrs(Dst); +      if (DstAttr.none()) +        continue; + +      // Propagate attr on the same level +      for (const auto &Mapping : ReachSet.reachableValueAliases(Dst)) { +        auto Src = Mapping.first; +        if (AttrMap.add(Src, DstAttr)) +          NextList.push_back(Src); +      } + +      // Propagate attr to the levels below +      auto DstBelow = getNodeBelow(Graph, Dst); +      while (DstBelow) { +        if (AttrMap.add(*DstBelow, DstAttr)) { +          NextList.push_back(*DstBelow); +          break; +        } +        DstBelow = getNodeBelow(Graph, *DstBelow); +      } +    } +    WorkList.swap(NextList); +    NextList.clear(); +  } + +  return AttrMap; +} + +CFLAndersAAResult::FunctionInfo +CFLAndersAAResult::buildInfoFrom(const Function &Fn) { +  CFLGraphBuilder<CFLAndersAAResult> GraphBuilder( +      *this, TLI, +      // Cast away the constness here due to GraphBuilder's API requirement +      const_cast<Function &>(Fn)); +  auto &Graph = GraphBuilder.getCFLGraph(); + +  ReachabilitySet ReachSet; +  AliasMemSet MemSet; + +  std::vector<WorkListItem> WorkList, NextList; +  initializeWorkList(WorkList, ReachSet, Graph); +  // TODO: make sure we don't stop before the fix point is reached +  while (!WorkList.empty()) { +    for (const auto &Item : WorkList) +      processWorkListItem(Item, Graph, ReachSet, MemSet, NextList); + +    NextList.swap(WorkList); +    NextList.clear(); +  } + +  // Now that we have all the reachability info, propagate AliasAttrs according +  // to it +  auto IValueAttrMap = buildAttrMap(Graph, ReachSet); + +  return FunctionInfo(Fn, GraphBuilder.getReturnValues(), ReachSet, +                      std::move(IValueAttrMap)); +} + +void CFLAndersAAResult::scan(const Function &Fn) { +  auto InsertPair = Cache.insert(std::make_pair(&Fn, Optional<FunctionInfo>())); +  (void)InsertPair; +  assert(InsertPair.second && +         "Trying to scan a function that has already been cached"); + +  // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call +  // may get evaluated after operator[], potentially triggering a DenseMap +  // resize and invalidating the reference returned by operator[] +  auto FunInfo = buildInfoFrom(Fn); +  Cache[&Fn] = std::move(FunInfo); +  Handles.emplace_front(const_cast<Function *>(&Fn), this); +} + +void CFLAndersAAResult::evict(const Function *Fn) { Cache.erase(Fn); } + +const Optional<CFLAndersAAResult::FunctionInfo> & +CFLAndersAAResult::ensureCached(const Function &Fn) { +  auto Iter = Cache.find(&Fn); +  if (Iter == Cache.end()) { +    scan(Fn); +    Iter = Cache.find(&Fn); +    assert(Iter != Cache.end()); +    assert(Iter->second.hasValue()); +  } +  return Iter->second; +} + +const AliasSummary *CFLAndersAAResult::getAliasSummary(const Function &Fn) { +  auto &FunInfo = ensureCached(Fn); +  if (FunInfo.hasValue()) +    return &FunInfo->getAliasSummary(); +  else +    return nullptr; +} + +AliasResult CFLAndersAAResult::query(const MemoryLocation &LocA, +                                     const MemoryLocation &LocB) { +  auto *ValA = LocA.Ptr; +  auto *ValB = 
LocB.Ptr; + +  if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy()) +    return NoAlias; + +  auto *Fn = parentFunctionOfValue(ValA); +  if (!Fn) { +    Fn = parentFunctionOfValue(ValB); +    if (!Fn) { +      // The only times this is known to happen are when globals + InlineAsm are +      // involved +      LLVM_DEBUG( +          dbgs() +          << "CFLAndersAA: could not extract parent function information.\n"); +      return MayAlias; +    } +  } else { +    assert(!parentFunctionOfValue(ValB) || parentFunctionOfValue(ValB) == Fn); +  } + +  assert(Fn != nullptr); +  auto &FunInfo = ensureCached(*Fn); + +  // AliasMap lookup +  if (FunInfo->mayAlias(ValA, LocA.Size, ValB, LocB.Size)) +    return MayAlias; +  return NoAlias; +} + +AliasResult CFLAndersAAResult::alias(const MemoryLocation &LocA, +                                     const MemoryLocation &LocB) { +  if (LocA.Ptr == LocB.Ptr) +    return MustAlias; + +  // Comparisons between global variables and other constants should be +  // handled by BasicAA. +  // CFLAndersAA may report NoAlias when comparing a GlobalValue and +  // ConstantExpr, but every query needs to have at least one Value tied to a +  // Function, and neither GlobalValues nor ConstantExprs are. +  if (isa<Constant>(LocA.Ptr) && isa<Constant>(LocB.Ptr)) +    return AAResultBase::alias(LocA, LocB); + +  AliasResult QueryResult = query(LocA, LocB); +  if (QueryResult == MayAlias) +    return AAResultBase::alias(LocA, LocB); + +  return QueryResult; +} + +AnalysisKey CFLAndersAA::Key; + +CFLAndersAAResult CFLAndersAA::run(Function &F, FunctionAnalysisManager &AM) { +  return CFLAndersAAResult(AM.getResult<TargetLibraryAnalysis>(F)); +} + +char CFLAndersAAWrapperPass::ID = 0; +INITIALIZE_PASS(CFLAndersAAWrapperPass, "cfl-anders-aa", +                "Inclusion-Based CFL Alias Analysis", false, true) + +ImmutablePass *llvm::createCFLAndersAAWrapperPass() { +  return new CFLAndersAAWrapperPass(); +} + +CFLAndersAAWrapperPass::CFLAndersAAWrapperPass() : ImmutablePass(ID) { +  initializeCFLAndersAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void CFLAndersAAWrapperPass::initializePass() { +  auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>(); +  Result.reset(new CFLAndersAAResult(TLIWP.getTLI())); +} + +void CFLAndersAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +} diff --git a/contrib/llvm/lib/Analysis/CFLGraph.h b/contrib/llvm/lib/Analysis/CFLGraph.h new file mode 100644 index 000000000000..86812009da7c --- /dev/null +++ b/contrib/llvm/lib/Analysis/CFLGraph.h @@ -0,0 +1,654 @@ +//===- CFLGraph.h - Abstract stratified sets implementation. -----*- C++-*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file defines CFLGraph, an auxiliary data structure used by CFL-based +/// alias analysis. 
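For a rough picture of the graph this header describes, the standalone sketch below uses plain std containers and string names (not the Node and Edge types defined further down) to record the two assignment edges that a store followed by a load of the same pointer would contribute:

    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    // Toy program expression graph for the IR fragment
    //   store i32* %q, i32** %p
    //   %r = load i32*, i32** %p
    // A node is (value name, dereference level); an edge is a pointer
    // assignment between two such nodes.
    int main() {
      using Node = std::pair<std::string, unsigned>;
      std::map<Node, std::vector<Node>> Edges;

      // The store writes %q into the memory pointed to by %p: (q, 0) -> (p, 1).
      Edges[{"q", 0}].push_back({"p", 1});
      // The load reads that memory back into %r: (p, 1) -> (r, 0).
      Edges[{"p", 1}].push_back({"r", 0});

      // Reference/dereference edges are left implicit: (p, 0) and (p, 1) are
      // related simply by naming the same value at adjacent dereference levels.
      return 0;
    }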
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_ANALYSIS_CFLGRAPH_H +#define LLVM_LIB_ANALYSIS_CFLGRAPH_H + +#include "AliasAnalysisSummary.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include <cassert> +#include <cstdint> +#include <vector> + +namespace llvm { +namespace cflaa { + +/// The Program Expression Graph (PEG) of CFL analysis +/// CFLGraph is auxiliary data structure used by CFL-based alias analysis to +/// describe flow-insensitive pointer-related behaviors. Given an LLVM function, +/// the main purpose of this graph is to abstract away unrelated facts and +/// translate the rest into a form that can be easily digested by CFL analyses. +/// Each Node in the graph is an InstantiatedValue, and each edge represent a +/// pointer assignment between InstantiatedValue. Pointer +/// references/dereferences are not explicitly stored in the graph: we +/// implicitly assume that for each node (X, I) it has a dereference edge to (X, +/// I+1) and a reference edge to (X, I-1). +class CFLGraph { +public: +  using Node = InstantiatedValue; + +  struct Edge { +    Node Other; +    int64_t Offset; +  }; + +  using EdgeList = std::vector<Edge>; + +  struct NodeInfo { +    EdgeList Edges, ReverseEdges; +    AliasAttrs Attr; +  }; + +  class ValueInfo { +    std::vector<NodeInfo> Levels; + +  public: +    bool addNodeToLevel(unsigned Level) { +      auto NumLevels = Levels.size(); +      if (NumLevels > Level) +        return false; +      Levels.resize(Level + 1); +      return true; +    } + +    NodeInfo &getNodeInfoAtLevel(unsigned Level) { +      assert(Level < Levels.size()); +      return Levels[Level]; +    } +    const NodeInfo &getNodeInfoAtLevel(unsigned Level) const { +      assert(Level < Levels.size()); +      return Levels[Level]; +    } + +    unsigned getNumLevels() const { return Levels.size(); } +  }; + +private: +  using ValueMap = DenseMap<Value *, ValueInfo>; + +  ValueMap ValueImpls; + +  NodeInfo *getNode(Node N) { +    auto Itr = ValueImpls.find(N.Val); +    if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel) +      return nullptr; +    return &Itr->second.getNodeInfoAtLevel(N.DerefLevel); +  } + +public: +  using const_value_iterator = ValueMap::const_iterator; + +  bool addNode(Node N, AliasAttrs Attr = AliasAttrs()) { +    assert(N.Val != nullptr); +    auto &ValInfo = ValueImpls[N.Val]; +    auto Changed = ValInfo.addNodeToLevel(N.DerefLevel); +    ValInfo.getNodeInfoAtLevel(N.DerefLevel).Attr |= Attr; +    return Changed; +  } + +  void addAttr(Node N, AliasAttrs Attr) { +    auto *Info = getNode(N); +    assert(Info != nullptr); +    Info->Attr |= Attr; +  } + +  void addEdge(Node From, Node To, int64_t Offset = 0) { +    auto *FromInfo = getNode(From); +    assert(FromInfo != nullptr); +    
auto *ToInfo = getNode(To); +    assert(ToInfo != nullptr); + +    FromInfo->Edges.push_back(Edge{To, Offset}); +    ToInfo->ReverseEdges.push_back(Edge{From, Offset}); +  } + +  const NodeInfo *getNode(Node N) const { +    auto Itr = ValueImpls.find(N.Val); +    if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel) +      return nullptr; +    return &Itr->second.getNodeInfoAtLevel(N.DerefLevel); +  } + +  AliasAttrs attrFor(Node N) const { +    auto *Info = getNode(N); +    assert(Info != nullptr); +    return Info->Attr; +  } + +  iterator_range<const_value_iterator> value_mappings() const { +    return make_range<const_value_iterator>(ValueImpls.begin(), +                                            ValueImpls.end()); +  } +}; + +///A builder class used to create CFLGraph instance from a given function +/// The CFL-AA that uses this builder must provide its own type as a template +/// argument. This is necessary for interprocedural processing: CFLGraphBuilder +/// needs a way of obtaining the summary of other functions when callinsts are +/// encountered. +/// As a result, we expect the said CFL-AA to expose a getAliasSummary() public +/// member function that takes a Function& and returns the corresponding summary +/// as a const AliasSummary*. +template <typename CFLAA> class CFLGraphBuilder { +  // Input of the builder +  CFLAA &Analysis; +  const TargetLibraryInfo &TLI; + +  // Output of the builder +  CFLGraph Graph; +  SmallVector<Value *, 4> ReturnedValues; + +  // Helper class +  /// Gets the edges our graph should have, based on an Instruction* +  class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> { +    CFLAA &AA; +    const DataLayout &DL; +    const TargetLibraryInfo &TLI; + +    CFLGraph &Graph; +    SmallVectorImpl<Value *> &ReturnValues; + +    static bool hasUsefulEdges(ConstantExpr *CE) { +      // ConstantExpr doesn't have terminators, invokes, or fences, so only +      // needs +      // to check for compares. +      return CE->getOpcode() != Instruction::ICmp && +             CE->getOpcode() != Instruction::FCmp; +    } + +    // Returns possible functions called by CS into the given SmallVectorImpl. +    // Returns true if targets found, false otherwise. +    static bool getPossibleTargets(CallSite CS, +                                   SmallVectorImpl<Function *> &Output) { +      if (auto *Fn = CS.getCalledFunction()) { +        Output.push_back(Fn); +        return true; +      } + +      // TODO: If the call is indirect, we might be able to enumerate all +      // potential +      // targets of the call and return them, rather than just failing. 
+      return false; +    } + +    void addNode(Value *Val, AliasAttrs Attr = AliasAttrs()) { +      assert(Val != nullptr && Val->getType()->isPointerTy()); +      if (auto GVal = dyn_cast<GlobalValue>(Val)) { +        if (Graph.addNode(InstantiatedValue{GVal, 0}, +                          getGlobalOrArgAttrFromValue(*GVal))) +          Graph.addNode(InstantiatedValue{GVal, 1}, getAttrUnknown()); +      } else if (auto CExpr = dyn_cast<ConstantExpr>(Val)) { +        if (hasUsefulEdges(CExpr)) { +          if (Graph.addNode(InstantiatedValue{CExpr, 0})) +            visitConstantExpr(CExpr); +        } +      } else +        Graph.addNode(InstantiatedValue{Val, 0}, Attr); +    } + +    void addAssignEdge(Value *From, Value *To, int64_t Offset = 0) { +      assert(From != nullptr && To != nullptr); +      if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy()) +        return; +      addNode(From); +      if (To != From) { +        addNode(To); +        Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 0}, +                      Offset); +      } +    } + +    void addDerefEdge(Value *From, Value *To, bool IsRead) { +      assert(From != nullptr && To != nullptr); +      // FIXME: This is subtly broken, due to how we model some instructions +      // (e.g. extractvalue, extractelement) as loads. Since those take +      // non-pointer operands, we'll entirely skip adding edges for those. +      // +      // addAssignEdge seems to have a similar issue with insertvalue, etc. +      if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy()) +        return; +      addNode(From); +      addNode(To); +      if (IsRead) { +        Graph.addNode(InstantiatedValue{From, 1}); +        Graph.addEdge(InstantiatedValue{From, 1}, InstantiatedValue{To, 0}); +      } else { +        Graph.addNode(InstantiatedValue{To, 1}); +        Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 1}); +      } +    } + +    void addLoadEdge(Value *From, Value *To) { addDerefEdge(From, To, true); } +    void addStoreEdge(Value *From, Value *To) { addDerefEdge(From, To, false); } + +  public: +    GetEdgesVisitor(CFLGraphBuilder &Builder, const DataLayout &DL) +        : AA(Builder.Analysis), DL(DL), TLI(Builder.TLI), Graph(Builder.Graph), +          ReturnValues(Builder.ReturnedValues) {} + +    void visitInstruction(Instruction &) { +      llvm_unreachable("Unsupported instruction encountered"); +    } + +    void visitReturnInst(ReturnInst &Inst) { +      if (auto RetVal = Inst.getReturnValue()) { +        if (RetVal->getType()->isPointerTy()) { +          addNode(RetVal); +          ReturnValues.push_back(RetVal); +        } +      } +    } + +    void visitPtrToIntInst(PtrToIntInst &Inst) { +      auto *Ptr = Inst.getOperand(0); +      addNode(Ptr, getAttrEscaped()); +    } + +    void visitIntToPtrInst(IntToPtrInst &Inst) { +      auto *Ptr = &Inst; +      addNode(Ptr, getAttrUnknown()); +    } + +    void visitCastInst(CastInst &Inst) { +      auto *Src = Inst.getOperand(0); +      addAssignEdge(Src, &Inst); +    } + +    void visitBinaryOperator(BinaryOperator &Inst) { +      auto *Op1 = Inst.getOperand(0); +      auto *Op2 = Inst.getOperand(1); +      addAssignEdge(Op1, &Inst); +      addAssignEdge(Op2, &Inst); +    } + +    void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { +      auto *Ptr = Inst.getPointerOperand(); +      auto *Val = Inst.getNewValOperand(); +      addStoreEdge(Val, Ptr); +    } + +    void visitAtomicRMWInst(AtomicRMWInst &Inst) { +     
 auto *Ptr = Inst.getPointerOperand(); +      auto *Val = Inst.getValOperand(); +      addStoreEdge(Val, Ptr); +    } + +    void visitPHINode(PHINode &Inst) { +      for (Value *Val : Inst.incoming_values()) +        addAssignEdge(Val, &Inst); +    } + +    void visitGEP(GEPOperator &GEPOp) { +      uint64_t Offset = UnknownOffset; +      APInt APOffset(DL.getPointerSizeInBits(GEPOp.getPointerAddressSpace()), +                     0); +      if (GEPOp.accumulateConstantOffset(DL, APOffset)) +        Offset = APOffset.getSExtValue(); + +      auto *Op = GEPOp.getPointerOperand(); +      addAssignEdge(Op, &GEPOp, Offset); +    } + +    void visitGetElementPtrInst(GetElementPtrInst &Inst) { +      auto *GEPOp = cast<GEPOperator>(&Inst); +      visitGEP(*GEPOp); +    } + +    void visitSelectInst(SelectInst &Inst) { +      // Condition is not processed here (The actual statement producing +      // the condition result is processed elsewhere). For select, the +      // condition is evaluated, but not loaded, stored, or assigned +      // simply as a result of being the condition of a select. + +      auto *TrueVal = Inst.getTrueValue(); +      auto *FalseVal = Inst.getFalseValue(); +      addAssignEdge(TrueVal, &Inst); +      addAssignEdge(FalseVal, &Inst); +    } + +    void visitAllocaInst(AllocaInst &Inst) { addNode(&Inst); } + +    void visitLoadInst(LoadInst &Inst) { +      auto *Ptr = Inst.getPointerOperand(); +      auto *Val = &Inst; +      addLoadEdge(Ptr, Val); +    } + +    void visitStoreInst(StoreInst &Inst) { +      auto *Ptr = Inst.getPointerOperand(); +      auto *Val = Inst.getValueOperand(); +      addStoreEdge(Val, Ptr); +    } + +    void visitVAArgInst(VAArgInst &Inst) { +      // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it +      // does +      // two things: +      //  1. Loads a value from *((T*)*Ptr). +      //  2. Increments (stores to) *Ptr by some target-specific amount. +      // For now, we'll handle this like a landingpad instruction (by placing +      // the +      // result in its own group, and having that group alias externals). 
+      if (Inst.getType()->isPointerTy()) +        addNode(&Inst, getAttrUnknown()); +    } + +    static bool isFunctionExternal(Function *Fn) { +      return !Fn->hasExactDefinition(); +    } + +    bool tryInterproceduralAnalysis(CallSite CS, +                                    const SmallVectorImpl<Function *> &Fns) { +      assert(Fns.size() > 0); + +      if (CS.arg_size() > MaxSupportedArgsInSummary) +        return false; + +      // Exit early if we'll fail anyway +      for (auto *Fn : Fns) { +        if (isFunctionExternal(Fn) || Fn->isVarArg()) +          return false; +        // Fail if the caller does not provide enough arguments +        assert(Fn->arg_size() <= CS.arg_size()); +        if (!AA.getAliasSummary(*Fn)) +          return false; +      } + +      for (auto *Fn : Fns) { +        auto Summary = AA.getAliasSummary(*Fn); +        assert(Summary != nullptr); + +        auto &RetParamRelations = Summary->RetParamRelations; +        for (auto &Relation : RetParamRelations) { +          auto IRelation = instantiateExternalRelation(Relation, CS); +          if (IRelation.hasValue()) { +            Graph.addNode(IRelation->From); +            Graph.addNode(IRelation->To); +            Graph.addEdge(IRelation->From, IRelation->To); +          } +        } + +        auto &RetParamAttributes = Summary->RetParamAttributes; +        for (auto &Attribute : RetParamAttributes) { +          auto IAttr = instantiateExternalAttribute(Attribute, CS); +          if (IAttr.hasValue()) +            Graph.addNode(IAttr->IValue, IAttr->Attr); +        } +      } + +      return true; +    } + +    void visitCallSite(CallSite CS) { +      auto Inst = CS.getInstruction(); + +      // Make sure all arguments and return value are added to the graph first +      for (Value *V : CS.args()) +        if (V->getType()->isPointerTy()) +          addNode(V); +      if (Inst->getType()->isPointerTy()) +        addNode(Inst); + +      // Check if Inst is a call to a library function that +      // allocates/deallocates on the heap. Those kinds of functions do not +      // introduce any aliases. +      // TODO: address other common library functions such as realloc(), +      // strdup(), etc. +      if (isMallocOrCallocLikeFn(Inst, &TLI) || isFreeCall(Inst, &TLI)) +        return; + +      // TODO: Add support for noalias args/all the other fun function +      // attributes that we can tack on. +      SmallVector<Function *, 4> Targets; +      if (getPossibleTargets(CS, Targets)) +        if (tryInterproceduralAnalysis(CS, Targets)) +          return; + +      // Because the function is opaque, we need to note that anything +      // could have happened to the arguments (unless the function is marked +      // readonly or readnone), and that the result could alias just about +      // anything, too (unless the result is marked noalias). +      if (!CS.onlyReadsMemory()) +        for (Value *V : CS.args()) { +          if (V->getType()->isPointerTy()) { +            // The argument itself escapes. +            Graph.addAttr(InstantiatedValue{V, 0}, getAttrEscaped()); +            // The fate of argument memory is unknown. Note that since +            // AliasAttrs is transitive with respect to dereference, we only +            // need to specify it for the first-level memory. 
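+            // In other words, marking (V, 1) as unknown is enough to also
+            // cover (V, 2), (V, 3), and so on.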
+            Graph.addNode(InstantiatedValue{V, 1}, getAttrUnknown()); +          } +        } + +      if (Inst->getType()->isPointerTy()) { +        auto *Fn = CS.getCalledFunction(); +        if (Fn == nullptr || !Fn->returnDoesNotAlias()) +          // No need to call addNode() since we've added Inst at the +          // beginning of this function and we know it is not a global. +          Graph.addAttr(InstantiatedValue{Inst, 0}, getAttrUnknown()); +      } +    } + +    /// Because vectors/aggregates are immutable and unaddressable, there's +    /// nothing we can do to coax a value out of them, other than calling +    /// Extract{Element,Value}. We can effectively treat them as pointers to +    /// arbitrary memory locations we can store in and load from. +    void visitExtractElementInst(ExtractElementInst &Inst) { +      auto *Ptr = Inst.getVectorOperand(); +      auto *Val = &Inst; +      addLoadEdge(Ptr, Val); +    } + +    void visitInsertElementInst(InsertElementInst &Inst) { +      auto *Vec = Inst.getOperand(0); +      auto *Val = Inst.getOperand(1); +      addAssignEdge(Vec, &Inst); +      addStoreEdge(Val, &Inst); +    } + +    void visitLandingPadInst(LandingPadInst &Inst) { +      // Exceptions come from "nowhere", from our analysis' perspective. +      // So we place the instruction its own group, noting that said group may +      // alias externals +      if (Inst.getType()->isPointerTy()) +        addNode(&Inst, getAttrUnknown()); +    } + +    void visitInsertValueInst(InsertValueInst &Inst) { +      auto *Agg = Inst.getOperand(0); +      auto *Val = Inst.getOperand(1); +      addAssignEdge(Agg, &Inst); +      addStoreEdge(Val, &Inst); +    } + +    void visitExtractValueInst(ExtractValueInst &Inst) { +      auto *Ptr = Inst.getAggregateOperand(); +      addLoadEdge(Ptr, &Inst); +    } + +    void visitShuffleVectorInst(ShuffleVectorInst &Inst) { +      auto *From1 = Inst.getOperand(0); +      auto *From2 = Inst.getOperand(1); +      addAssignEdge(From1, &Inst); +      addAssignEdge(From2, &Inst); +    } + +    void visitConstantExpr(ConstantExpr *CE) { +      switch (CE->getOpcode()) { +      case Instruction::GetElementPtr: { +        auto GEPOp = cast<GEPOperator>(CE); +        visitGEP(*GEPOp); +        break; +      } + +      case Instruction::PtrToInt: { +        addNode(CE->getOperand(0), getAttrEscaped()); +        break; +      } + +      case Instruction::IntToPtr: { +        addNode(CE, getAttrUnknown()); +        break; +      } + +      case Instruction::BitCast: +      case Instruction::AddrSpaceCast: +      case Instruction::Trunc: +      case Instruction::ZExt: +      case Instruction::SExt: +      case Instruction::FPExt: +      case Instruction::FPTrunc: +      case Instruction::UIToFP: +      case Instruction::SIToFP: +      case Instruction::FPToUI: +      case Instruction::FPToSI: { +        addAssignEdge(CE->getOperand(0), CE); +        break; +      } + +      case Instruction::Select: { +        addAssignEdge(CE->getOperand(1), CE); +        addAssignEdge(CE->getOperand(2), CE); +        break; +      } + +      case Instruction::InsertElement: +      case Instruction::InsertValue: { +        addAssignEdge(CE->getOperand(0), CE); +        addStoreEdge(CE->getOperand(1), CE); +        break; +      } + +      case Instruction::ExtractElement: +      case Instruction::ExtractValue: { +        addLoadEdge(CE->getOperand(0), CE); +        break; +      } + +      case Instruction::Add: +      case Instruction::Sub: +      case Instruction::FSub: 
+      case Instruction::Mul: +      case Instruction::FMul: +      case Instruction::UDiv: +      case Instruction::SDiv: +      case Instruction::FDiv: +      case Instruction::URem: +      case Instruction::SRem: +      case Instruction::FRem: +      case Instruction::And: +      case Instruction::Or: +      case Instruction::Xor: +      case Instruction::Shl: +      case Instruction::LShr: +      case Instruction::AShr: +      case Instruction::ICmp: +      case Instruction::FCmp: +      case Instruction::ShuffleVector: { +        addAssignEdge(CE->getOperand(0), CE); +        addAssignEdge(CE->getOperand(1), CE); +        break; +      } + +      default: +        llvm_unreachable("Unknown instruction type encountered!"); +      } +    } +  }; + +  // Helper functions + +  // Determines whether or not we an instruction is useless to us (e.g. +  // FenceInst) +  static bool hasUsefulEdges(Instruction *Inst) { +    bool IsNonInvokeRetTerminator = isa<TerminatorInst>(Inst) && +                                    !isa<InvokeInst>(Inst) && +                                    !isa<ReturnInst>(Inst); +    return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) && +           !IsNonInvokeRetTerminator; +  } + +  void addArgumentToGraph(Argument &Arg) { +    if (Arg.getType()->isPointerTy()) { +      Graph.addNode(InstantiatedValue{&Arg, 0}, +                    getGlobalOrArgAttrFromValue(Arg)); +      // Pointees of a formal parameter is known to the caller +      Graph.addNode(InstantiatedValue{&Arg, 1}, getAttrCaller()); +    } +  } + +  // Given an Instruction, this will add it to the graph, along with any +  // Instructions that are potentially only available from said Instruction +  // For example, given the following line: +  //   %0 = load i16* getelementptr ([1 x i16]* @a, 0, 0), align 2 +  // addInstructionToGraph would add both the `load` and `getelementptr` +  // instructions to the graph appropriately. +  void addInstructionToGraph(GetEdgesVisitor &Visitor, Instruction &Inst) { +    if (!hasUsefulEdges(&Inst)) +      return; + +    Visitor.visit(Inst); +  } + +  // Builds the graph needed for constructing the StratifiedSets for the given +  // function +  void buildGraphFrom(Function &Fn) { +    GetEdgesVisitor Visitor(*this, Fn.getParent()->getDataLayout()); + +    for (auto &Bb : Fn.getBasicBlockList()) +      for (auto &Inst : Bb.getInstList()) +        addInstructionToGraph(Visitor, Inst); + +    for (auto &Arg : Fn.args()) +      addArgumentToGraph(Arg); +  } + +public: +  CFLGraphBuilder(CFLAA &Analysis, const TargetLibraryInfo &TLI, Function &Fn) +      : Analysis(Analysis), TLI(TLI) { +    buildGraphFrom(Fn); +  } + +  const CFLGraph &getCFLGraph() const { return Graph; } +  const SmallVector<Value *, 4> &getReturnValues() const { +    return ReturnedValues; +  } +}; + +} // end namespace cflaa +} // end namespace llvm + +#endif // LLVM_LIB_ANALYSIS_CFLGRAPH_H diff --git a/contrib/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp new file mode 100644 index 000000000000..30ce13578e54 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CFLSteensAliasAnalysis.cpp @@ -0,0 +1,358 @@ +//===- CFLSteensAliasAnalysis.cpp - Unification-based Alias Analysis ------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// does not depend on types. The algorithm is a mixture of the one described in
+// "Demand-driven alias analysis for C" by Xin Zheng and Radu Rugina, and "Fast
+// algorithms for Dyck-CFL-reachability with applications to Alias Analysis" by
+// Zhang Q, Lyu M R, Yuan H, and Su Z. -- to summarize the papers, we build a
+// graph of the uses of a variable, where each node is a memory location, and
+// each edge is an action that happened on that memory location.  The "actions"
+// can be one of Dereference, Reference, or Assign. The precision of this
+// analysis is roughly the same as that of a one-level context-sensitive
+// Steensgaard's algorithm.
+//
+// Two variables are considered to alias iff you can reach one value's node
+// from the other value's node and the language formed by concatenating all of
+// the edge labels (actions) conforms to a context-free grammar.
+//
+// Because this algorithm requires a graph search on each query, we execute the
+// algorithm outlined in "Fast algorithms..." (mentioned above)
+// in order to transform the graph into sets of variables that may alias in
+// ~n log n time (n = number of variables), which makes queries take constant
+// time.
+//===----------------------------------------------------------------------===//
+
+// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and
+// CFLSteensAA is interprocedural. This is *technically* A Bad Thing, because
+// FunctionPasses are only allowed to inspect the Function that they're being
+// run on. Realistically, this likely isn't a problem until we allow
+// FunctionPasses to run concurrently.
+
+#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
+#include "AliasAnalysisSummary.h"
+#include "CFLGraph.h"
+#include "StratifiedSets.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <limits>
+#include <memory>
+#include <utility>
+
+using namespace llvm;
+using namespace llvm::cflaa;
+
+#define DEBUG_TYPE "cfl-steens-aa"
+
+CFLSteensAAResult::CFLSteensAAResult(const TargetLibraryInfo &TLI)
+    : AAResultBase(), TLI(TLI) {}
+CFLSteensAAResult::CFLSteensAAResult(CFLSteensAAResult &&Arg)
+    : AAResultBase(std::move(Arg)), TLI(Arg.TLI) {}
+CFLSteensAAResult::~CFLSteensAAResult() = default;
+
+/// Information we have about a function and would like to keep around.
+class CFLSteensAAResult::FunctionInfo { +  StratifiedSets<InstantiatedValue> Sets; +  AliasSummary Summary; + +public: +  FunctionInfo(Function &Fn, const SmallVectorImpl<Value *> &RetVals, +               StratifiedSets<InstantiatedValue> S); + +  const StratifiedSets<InstantiatedValue> &getStratifiedSets() const { +    return Sets; +  } + +  const AliasSummary &getAliasSummary() const { return Summary; } +}; + +const StratifiedIndex StratifiedLink::SetSentinel = +    std::numeric_limits<StratifiedIndex>::max(); + +//===----------------------------------------------------------------------===// +// Function declarations that require types defined in the namespace above +//===----------------------------------------------------------------------===// + +/// Determines whether it would be pointless to add the given Value to our sets. +static bool canSkipAddingToSets(Value *Val) { +  // Constants can share instances, which may falsely unify multiple +  // sets, e.g. in +  // store i32* null, i32** %ptr1 +  // store i32* null, i32** %ptr2 +  // clearly ptr1 and ptr2 should not be unified into the same set, so +  // we should filter out the (potentially shared) instance to +  // i32* null. +  if (isa<Constant>(Val)) { +    // TODO: Because all of these things are constant, we can determine whether +    // the data is *actually* mutable at graph building time. This will probably +    // come for free/cheap with offset awareness. +    bool CanStoreMutableData = isa<GlobalValue>(Val) || +                               isa<ConstantExpr>(Val) || +                               isa<ConstantAggregate>(Val); +    return !CanStoreMutableData; +  } + +  return false; +} + +CFLSteensAAResult::FunctionInfo::FunctionInfo( +    Function &Fn, const SmallVectorImpl<Value *> &RetVals, +    StratifiedSets<InstantiatedValue> S) +    : Sets(std::move(S)) { +  // Historically, an arbitrary upper-bound of 50 args was selected. We may want +  // to remove this if it doesn't really matter in practice. +  if (Fn.arg_size() > MaxSupportedArgsInSummary) +    return; + +  DenseMap<StratifiedIndex, InterfaceValue> InterfaceMap; + +  // Our intention here is to record all InterfaceValues that share the same +  // StratifiedIndex in RetParamRelations. For each valid InterfaceValue, we +  // have its StratifiedIndex scanned here and check if the index is presented +  // in InterfaceMap: if it is not, we add the correspondence to the map; +  // otherwise, an aliasing relation is found and we add it to +  // RetParamRelations. 
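+  // For example, if the return value and the first parameter land in the same
+  // stratified set, we end up recording a relation between interface values
+  // (0, 0) and (1, 0).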
+ +  auto AddToRetParamRelations = [&](unsigned InterfaceIndex, +                                    StratifiedIndex SetIndex) { +    unsigned Level = 0; +    while (true) { +      InterfaceValue CurrValue{InterfaceIndex, Level}; + +      auto Itr = InterfaceMap.find(SetIndex); +      if (Itr != InterfaceMap.end()) { +        if (CurrValue != Itr->second) +          Summary.RetParamRelations.push_back( +              ExternalRelation{CurrValue, Itr->second, UnknownOffset}); +        break; +      } + +      auto &Link = Sets.getLink(SetIndex); +      InterfaceMap.insert(std::make_pair(SetIndex, CurrValue)); +      auto ExternalAttrs = getExternallyVisibleAttrs(Link.Attrs); +      if (ExternalAttrs.any()) +        Summary.RetParamAttributes.push_back( +            ExternalAttribute{CurrValue, ExternalAttrs}); + +      if (!Link.hasBelow()) +        break; + +      ++Level; +      SetIndex = Link.Below; +    } +  }; + +  // Populate RetParamRelations for return values +  for (auto *RetVal : RetVals) { +    assert(RetVal != nullptr); +    assert(RetVal->getType()->isPointerTy()); +    auto RetInfo = Sets.find(InstantiatedValue{RetVal, 0}); +    if (RetInfo.hasValue()) +      AddToRetParamRelations(0, RetInfo->Index); +  } + +  // Populate RetParamRelations for parameters +  unsigned I = 0; +  for (auto &Param : Fn.args()) { +    if (Param.getType()->isPointerTy()) { +      auto ParamInfo = Sets.find(InstantiatedValue{&Param, 0}); +      if (ParamInfo.hasValue()) +        AddToRetParamRelations(I + 1, ParamInfo->Index); +    } +    ++I; +  } +} + +// Builds the graph + StratifiedSets for a function. +CFLSteensAAResult::FunctionInfo CFLSteensAAResult::buildSetsFrom(Function *Fn) { +  CFLGraphBuilder<CFLSteensAAResult> GraphBuilder(*this, TLI, *Fn); +  StratifiedSetsBuilder<InstantiatedValue> SetBuilder; + +  // Add all CFLGraph nodes and all Dereference edges to StratifiedSets +  auto &Graph = GraphBuilder.getCFLGraph(); +  for (const auto &Mapping : Graph.value_mappings()) { +    auto Val = Mapping.first; +    if (canSkipAddingToSets(Val)) +      continue; +    auto &ValueInfo = Mapping.second; + +    assert(ValueInfo.getNumLevels() > 0); +    SetBuilder.add(InstantiatedValue{Val, 0}); +    SetBuilder.noteAttributes(InstantiatedValue{Val, 0}, +                              ValueInfo.getNodeInfoAtLevel(0).Attr); +    for (unsigned I = 0, E = ValueInfo.getNumLevels() - 1; I < E; ++I) { +      SetBuilder.add(InstantiatedValue{Val, I + 1}); +      SetBuilder.noteAttributes(InstantiatedValue{Val, I + 1}, +                                ValueInfo.getNodeInfoAtLevel(I + 1).Attr); +      SetBuilder.addBelow(InstantiatedValue{Val, I}, +                          InstantiatedValue{Val, I + 1}); +    } +  } + +  // Add all assign edges to StratifiedSets +  for (const auto &Mapping : Graph.value_mappings()) { +    auto Val = Mapping.first; +    if (canSkipAddingToSets(Val)) +      continue; +    auto &ValueInfo = Mapping.second; + +    for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { +      auto Src = InstantiatedValue{Val, I}; +      for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) +        SetBuilder.addWith(Src, Edge.Other); +    } +  } + +  return FunctionInfo(*Fn, GraphBuilder.getReturnValues(), SetBuilder.build()); +} + +void CFLSteensAAResult::scan(Function *Fn) { +  auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>())); +  (void)InsertPair; +  assert(InsertPair.second && +         "Trying to scan a function that has already been cached"); + +  // Note 
that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call +  // may get evaluated after operator[], potentially triggering a DenseMap +  // resize and invalidating the reference returned by operator[] +  auto FunInfo = buildSetsFrom(Fn); +  Cache[Fn] = std::move(FunInfo); + +  Handles.emplace_front(Fn, this); +} + +void CFLSteensAAResult::evict(Function *Fn) { Cache.erase(Fn); } + +/// Ensures that the given function is available in the cache, and returns the +/// entry. +const Optional<CFLSteensAAResult::FunctionInfo> & +CFLSteensAAResult::ensureCached(Function *Fn) { +  auto Iter = Cache.find(Fn); +  if (Iter == Cache.end()) { +    scan(Fn); +    Iter = Cache.find(Fn); +    assert(Iter != Cache.end()); +    assert(Iter->second.hasValue()); +  } +  return Iter->second; +} + +const AliasSummary *CFLSteensAAResult::getAliasSummary(Function &Fn) { +  auto &FunInfo = ensureCached(&Fn); +  if (FunInfo.hasValue()) +    return &FunInfo->getAliasSummary(); +  else +    return nullptr; +} + +AliasResult CFLSteensAAResult::query(const MemoryLocation &LocA, +                                     const MemoryLocation &LocB) { +  auto *ValA = const_cast<Value *>(LocA.Ptr); +  auto *ValB = const_cast<Value *>(LocB.Ptr); + +  if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy()) +    return NoAlias; + +  Function *Fn = nullptr; +  Function *MaybeFnA = const_cast<Function *>(parentFunctionOfValue(ValA)); +  Function *MaybeFnB = const_cast<Function *>(parentFunctionOfValue(ValB)); +  if (!MaybeFnA && !MaybeFnB) { +    // The only times this is known to happen are when globals + InlineAsm are +    // involved +    LLVM_DEBUG( +        dbgs() +        << "CFLSteensAA: could not extract parent function information.\n"); +    return MayAlias; +  } + +  if (MaybeFnA) { +    Fn = MaybeFnA; +    assert((!MaybeFnB || MaybeFnB == MaybeFnA) && +           "Interprocedural queries not supported"); +  } else { +    Fn = MaybeFnB; +  } + +  assert(Fn != nullptr); +  auto &MaybeInfo = ensureCached(Fn); +  assert(MaybeInfo.hasValue()); + +  auto &Sets = MaybeInfo->getStratifiedSets(); +  auto MaybeA = Sets.find(InstantiatedValue{ValA, 0}); +  if (!MaybeA.hasValue()) +    return MayAlias; + +  auto MaybeB = Sets.find(InstantiatedValue{ValB, 0}); +  if (!MaybeB.hasValue()) +    return MayAlias; + +  auto SetA = *MaybeA; +  auto SetB = *MaybeB; +  auto AttrsA = Sets.getLink(SetA.Index).Attrs; +  auto AttrsB = Sets.getLink(SetB.Index).Attrs; + +  // If both values are local (meaning the corresponding set has attribute +  // AttrNone or AttrEscaped), then we know that CFLSteensAA fully models them: +  // they may-alias each other if and only if they are in the same set. +  // If at least one value is non-local (meaning it either is global/argument or +  // it comes from unknown sources like integer cast), the situation becomes a +  // bit more interesting. 
We follow three general rules described below: +  // - Non-local values may alias each other +  // - AttrNone values do not alias any non-local values +  // - AttrEscaped do not alias globals/arguments, but they may alias +  // AttrUnknown values +  if (SetA.Index == SetB.Index) +    return MayAlias; +  if (AttrsA.none() || AttrsB.none()) +    return NoAlias; +  if (hasUnknownOrCallerAttr(AttrsA) || hasUnknownOrCallerAttr(AttrsB)) +    return MayAlias; +  if (isGlobalOrArgAttr(AttrsA) && isGlobalOrArgAttr(AttrsB)) +    return MayAlias; +  return NoAlias; +} + +AnalysisKey CFLSteensAA::Key; + +CFLSteensAAResult CFLSteensAA::run(Function &F, FunctionAnalysisManager &AM) { +  return CFLSteensAAResult(AM.getResult<TargetLibraryAnalysis>(F)); +} + +char CFLSteensAAWrapperPass::ID = 0; +INITIALIZE_PASS(CFLSteensAAWrapperPass, "cfl-steens-aa", +                "Unification-Based CFL Alias Analysis", false, true) + +ImmutablePass *llvm::createCFLSteensAAWrapperPass() { +  return new CFLSteensAAWrapperPass(); +} + +CFLSteensAAWrapperPass::CFLSteensAAWrapperPass() : ImmutablePass(ID) { +  initializeCFLSteensAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void CFLSteensAAWrapperPass::initializePass() { +  auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>(); +  Result.reset(new CFLSteensAAResult(TLIWP.getTLI())); +} + +void CFLSteensAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +} diff --git a/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp b/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp new file mode 100644 index 000000000000..b325afb8e7c5 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CGSCCPassManager.cpp @@ -0,0 +1,689 @@ +//===- CGSCCPassManager.cpp - Managing & running CGSCC passes -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CGSCCPassManager.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <iterator> + +#define DEBUG_TYPE "cgscc" + +using namespace llvm; + +// Explicit template instantiations and specialization definitions for core +// template typedefs. +namespace llvm { + +// Explicit instantiations for the core proxy templates. 
+template class AllAnalysesOn<LazyCallGraph::SCC>; +template class AnalysisManager<LazyCallGraph::SCC, LazyCallGraph &>; +template class PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, +                           LazyCallGraph &, CGSCCUpdateResult &>; +template class InnerAnalysisManagerProxy<CGSCCAnalysisManager, Module>; +template class OuterAnalysisManagerProxy<ModuleAnalysisManager, +                                         LazyCallGraph::SCC, LazyCallGraph &>; +template class OuterAnalysisManagerProxy<CGSCCAnalysisManager, Function>; + +/// Explicitly specialize the pass manager run method to handle call graph +/// updates. +template <> +PreservedAnalyses +PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, +            CGSCCUpdateResult &>::run(LazyCallGraph::SCC &InitialC, +                                      CGSCCAnalysisManager &AM, +                                      LazyCallGraph &G, CGSCCUpdateResult &UR) { +  PreservedAnalyses PA = PreservedAnalyses::all(); + +  if (DebugLogging) +    dbgs() << "Starting CGSCC pass manager run.\n"; + +  // The SCC may be refined while we are running passes over it, so set up +  // a pointer that we can update. +  LazyCallGraph::SCC *C = &InitialC; + +  for (auto &Pass : Passes) { +    if (DebugLogging) +      dbgs() << "Running pass: " << Pass->name() << " on " << *C << "\n"; + +    PreservedAnalyses PassPA = Pass->run(*C, AM, G, UR); + +    // Update the SCC if necessary. +    C = UR.UpdatedC ? UR.UpdatedC : C; + +    // If the CGSCC pass wasn't able to provide a valid updated SCC, the +    // current SCC may simply need to be skipped if invalid. +    if (UR.InvalidatedSCCs.count(C)) { +      LLVM_DEBUG(dbgs() << "Skipping invalidated root or island SCC!\n"); +      break; +    } +    // Check that we didn't miss any update scenario. +    assert(C->begin() != C->end() && "Cannot have an empty SCC!"); + +    // Update the analysis manager as each pass runs and potentially +    // invalidates analyses. +    AM.invalidate(*C, PassPA); + +    // Finally, we intersect the final preserved analyses to compute the +    // aggregate preserved set for this pass manager. +    PA.intersect(std::move(PassPA)); + +    // FIXME: Historically, the pass managers all called the LLVM context's +    // yield function here. We don't have a generic way to acquire the +    // context and it isn't yet clear what the right pattern is for yielding +    // in the new pass manager so it is currently omitted. +    // ...getContext().yield(); +  } + +  // Invalidation was handled after each pass in the above loop for the current +  // SCC. Therefore, the remaining analysis results in the AnalysisManager are +  // preserved. We mark this with a set so that we don't need to inspect each +  // one individually. +  PA.preserveSet<AllAnalysesOn<LazyCallGraph::SCC>>(); + +  if (DebugLogging) +    dbgs() << "Finished CGSCC pass manager run.\n"; + +  return PA; +} + +bool CGSCCAnalysisManagerModuleProxy::Result::invalidate( +    Module &M, const PreservedAnalyses &PA, +    ModuleAnalysisManager::Invalidator &Inv) { +  // If literally everything is preserved, we're done. +  if (PA.areAllPreserved()) +    return false; // This is still a valid proxy. + +  // If this proxy or the call graph is going to be invalidated, we also need +  // to clear all the keys coming from that analysis. +  // +  // We also directly invalidate the FAM's module proxy if necessary, and if +  // that proxy isn't preserved we can't preserve this proxy either. 
We rely on +  // it to handle module -> function analysis invalidation in the face of +  // structural changes and so if it's unavailable we conservatively clear the +  // entire SCC layer as well rather than trying to do invalidation ourselves. +  auto PAC = PA.getChecker<CGSCCAnalysisManagerModuleProxy>(); +  if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>()) || +      Inv.invalidate<LazyCallGraphAnalysis>(M, PA) || +      Inv.invalidate<FunctionAnalysisManagerModuleProxy>(M, PA)) { +    InnerAM->clear(); + +    // And the proxy itself should be marked as invalid so that we can observe +    // the new call graph. This isn't strictly necessary because we cheat +    // above, but is still useful. +    return true; +  } + +  // Directly check if the relevant set is preserved so we can short circuit +  // invalidating SCCs below. +  bool AreSCCAnalysesPreserved = +      PA.allAnalysesInSetPreserved<AllAnalysesOn<LazyCallGraph::SCC>>(); + +  // Ok, we have a graph, so we can propagate the invalidation down into it. +  G->buildRefSCCs(); +  for (auto &RC : G->postorder_ref_sccs()) +    for (auto &C : RC) { +      Optional<PreservedAnalyses> InnerPA; + +      // Check to see whether the preserved set needs to be adjusted based on +      // module-level analysis invalidation triggering deferred invalidation +      // for this SCC. +      if (auto *OuterProxy = +              InnerAM->getCachedResult<ModuleAnalysisManagerCGSCCProxy>(C)) +        for (const auto &OuterInvalidationPair : +             OuterProxy->getOuterInvalidations()) { +          AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first; +          const auto &InnerAnalysisIDs = OuterInvalidationPair.second; +          if (Inv.invalidate(OuterAnalysisID, M, PA)) { +            if (!InnerPA) +              InnerPA = PA; +            for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs) +              InnerPA->abandon(InnerAnalysisID); +          } +        } + +      // Check if we needed a custom PA set. If so we'll need to run the inner +      // invalidation. +      if (InnerPA) { +        InnerAM->invalidate(C, *InnerPA); +        continue; +      } + +      // Otherwise we only need to do invalidation if the original PA set didn't +      // preserve all SCC analyses. +      if (!AreSCCAnalysesPreserved) +        InnerAM->invalidate(C, PA); +    } + +  // Return false to indicate that this result is still a valid proxy. +  return false; +} + +template <> +CGSCCAnalysisManagerModuleProxy::Result +CGSCCAnalysisManagerModuleProxy::run(Module &M, ModuleAnalysisManager &AM) { +  // Force the Function analysis manager to also be available so that it can +  // be accessed in an SCC analysis and proxied onward to function passes. +  // FIXME: It is pretty awkward to just drop the result here and assert that +  // we can find it again later. +  (void)AM.getResult<FunctionAnalysisManagerModuleProxy>(M); + +  return Result(*InnerAM, AM.getResult<LazyCallGraphAnalysis>(M)); +} + +AnalysisKey FunctionAnalysisManagerCGSCCProxy::Key; + +FunctionAnalysisManagerCGSCCProxy::Result +FunctionAnalysisManagerCGSCCProxy::run(LazyCallGraph::SCC &C, +                                       CGSCCAnalysisManager &AM, +                                       LazyCallGraph &CG) { +  // Collect the FunctionAnalysisManager from the Module layer and use that to +  // build the proxy result. +  // +  // This allows us to rely on the FunctionAnalysisMangaerModuleProxy to +  // invalidate the function analyses. 
+  auto &MAM = AM.getResult<ModuleAnalysisManagerCGSCCProxy>(C, CG).getManager(); +  Module &M = *C.begin()->getFunction().getParent(); +  auto *FAMProxy = MAM.getCachedResult<FunctionAnalysisManagerModuleProxy>(M); +  assert(FAMProxy && "The CGSCC pass manager requires that the FAM module " +                     "proxy is run on the module prior to entering the CGSCC " +                     "walk."); + +  // Note that we special-case invalidation handling of this proxy in the CGSCC +  // analysis manager's Module proxy. This avoids the need to do anything +  // special here to recompute all of this if ever the FAM's module proxy goes +  // away. +  return Result(FAMProxy->getManager()); +} + +bool FunctionAnalysisManagerCGSCCProxy::Result::invalidate( +    LazyCallGraph::SCC &C, const PreservedAnalyses &PA, +    CGSCCAnalysisManager::Invalidator &Inv) { +  // If literally everything is preserved, we're done. +  if (PA.areAllPreserved()) +    return false; // This is still a valid proxy. + +  // If this proxy isn't marked as preserved, then even if the result remains +  // valid, the key itself may no longer be valid, so we clear everything. +  // +  // Note that in order to preserve this proxy, a module pass must ensure that +  // the FAM has been completely updated to handle the deletion of functions. +  // Specifically, any FAM-cached results for those functions need to have been +  // forcibly cleared. When preserved, this proxy will only invalidate results +  // cached on functions *still in the module* at the end of the module pass. +  auto PAC = PA.getChecker<FunctionAnalysisManagerCGSCCProxy>(); +  if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<LazyCallGraph::SCC>>()) { +    for (LazyCallGraph::Node &N : C) +      FAM->clear(N.getFunction(), N.getFunction().getName()); + +    return true; +  } + +  // Directly check if the relevant set is preserved. +  bool AreFunctionAnalysesPreserved = +      PA.allAnalysesInSetPreserved<AllAnalysesOn<Function>>(); + +  // Now walk all the functions to see if any inner analysis invalidation is +  // necessary. +  for (LazyCallGraph::Node &N : C) { +    Function &F = N.getFunction(); +    Optional<PreservedAnalyses> FunctionPA; + +    // Check to see whether the preserved set needs to be pruned based on +    // SCC-level analysis invalidation that triggers deferred invalidation +    // registered with the outer analysis manager proxy for this function. +    if (auto *OuterProxy = +            FAM->getCachedResult<CGSCCAnalysisManagerFunctionProxy>(F)) +      for (const auto &OuterInvalidationPair : +           OuterProxy->getOuterInvalidations()) { +        AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first; +        const auto &InnerAnalysisIDs = OuterInvalidationPair.second; +        if (Inv.invalidate(OuterAnalysisID, C, PA)) { +          if (!FunctionPA) +            FunctionPA = PA; +          for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs) +            FunctionPA->abandon(InnerAnalysisID); +        } +      } + +    // Check if we needed a custom PA set, and if so we'll need to run the +    // inner invalidation. +    if (FunctionPA) { +      FAM->invalidate(F, *FunctionPA); +      continue; +    } + +    // Otherwise we only need to do invalidation if the original PA set didn't +    // preserve all function analyses. +    if (!AreFunctionAnalysesPreserved) +      FAM->invalidate(F, PA); +  } + +  // Return false to indicate that this result is still a valid proxy. 
+  return false; +} + +} // end namespace llvm + +/// When a new SCC is created for the graph and there might be function +/// analysis results cached for the functions now in that SCC two forms of +/// updates are required. +/// +/// First, a proxy from the SCC to the FunctionAnalysisManager needs to be +/// created so that any subsequent invalidation events to the SCC are +/// propagated to the function analysis results cached for functions within it. +/// +/// Second, if any of the functions within the SCC have analysis results with +/// outer analysis dependencies, then those dependencies would point to the +/// *wrong* SCC's analysis result. We forcibly invalidate the necessary +/// function analyses so that they don't retain stale handles. +static void updateNewSCCFunctionAnalyses(LazyCallGraph::SCC &C, +                                         LazyCallGraph &G, +                                         CGSCCAnalysisManager &AM) { +  // Get the relevant function analysis manager. +  auto &FAM = +      AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, G).getManager(); + +  // Now walk the functions in this SCC and invalidate any function analysis +  // results that might have outer dependencies on an SCC analysis. +  for (LazyCallGraph::Node &N : C) { +    Function &F = N.getFunction(); + +    auto *OuterProxy = +        FAM.getCachedResult<CGSCCAnalysisManagerFunctionProxy>(F); +    if (!OuterProxy) +      // No outer analyses were queried, nothing to do. +      continue; + +    // Forcibly abandon all the inner analyses with dependencies, but +    // invalidate nothing else. +    auto PA = PreservedAnalyses::all(); +    for (const auto &OuterInvalidationPair : +         OuterProxy->getOuterInvalidations()) { +      const auto &InnerAnalysisIDs = OuterInvalidationPair.second; +      for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs) +        PA.abandon(InnerAnalysisID); +    } + +    // Now invalidate anything we found. +    FAM.invalidate(F, PA); +  } +} + +/// Helper function to update both the \c CGSCCAnalysisManager \p AM and the \c +/// CGSCCPassManager's \c CGSCCUpdateResult \p UR based on a range of newly +/// added SCCs. +/// +/// The range of new SCCs must be in postorder already. The SCC they were split +/// out of must be provided as \p C. The current node being mutated and +/// triggering updates must be passed as \p N. +/// +/// This function returns the SCC containing \p N. This will be either \p C if +/// no new SCCs have been split out, or it will be the new SCC containing \p N. +template <typename SCCRangeT> +static LazyCallGraph::SCC * +incorporateNewSCCRange(const SCCRangeT &NewSCCRange, LazyCallGraph &G, +                       LazyCallGraph::Node &N, LazyCallGraph::SCC *C, +                       CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) { +  using SCC = LazyCallGraph::SCC; + +  if (NewSCCRange.begin() == NewSCCRange.end()) +    return C; + +  // Add the current SCC to the worklist as its shape has changed. +  UR.CWorklist.insert(C); +  LLVM_DEBUG(dbgs() << "Enqueuing the existing SCC in the worklist:" << *C +                    << "\n"); + +  SCC *OldC = C; + +  // Update the current SCC. Note that if we have new SCCs, this must actually +  // change the SCC. 
+  assert(C != &*NewSCCRange.begin() && +         "Cannot insert new SCCs without changing current SCC!"); +  C = &*NewSCCRange.begin(); +  assert(G.lookupSCC(N) == C && "Failed to update current SCC!"); + +  // If we had a cached FAM proxy originally, we will want to create more of +  // them for each SCC that was split off. +  bool NeedFAMProxy = +      AM.getCachedResult<FunctionAnalysisManagerCGSCCProxy>(*OldC) != nullptr; + +  // We need to propagate an invalidation call to all but the newly current SCC +  // because the outer pass manager won't do that for us after splitting them. +  // FIXME: We should accept a PreservedAnalysis from the CG updater so that if +  // there are preserved analysis we can avoid invalidating them here for +  // split-off SCCs. +  // We know however that this will preserve any FAM proxy so go ahead and mark +  // that. +  PreservedAnalyses PA; +  PA.preserve<FunctionAnalysisManagerCGSCCProxy>(); +  AM.invalidate(*OldC, PA); + +  // Ensure the now-current SCC's function analyses are updated. +  if (NeedFAMProxy) +    updateNewSCCFunctionAnalyses(*C, G, AM); + +  for (SCC &NewC : llvm::reverse(make_range(std::next(NewSCCRange.begin()), +                                            NewSCCRange.end()))) { +    assert(C != &NewC && "No need to re-visit the current SCC!"); +    assert(OldC != &NewC && "Already handled the original SCC!"); +    UR.CWorklist.insert(&NewC); +    LLVM_DEBUG(dbgs() << "Enqueuing a newly formed SCC:" << NewC << "\n"); + +    // Ensure new SCCs' function analyses are updated. +    if (NeedFAMProxy) +      updateNewSCCFunctionAnalyses(NewC, G, AM); + +    // Also propagate a normal invalidation to the new SCC as only the current +    // will get one from the pass manager infrastructure. +    AM.invalidate(NewC, PA); +  } +  return C; +} + +LazyCallGraph::SCC &llvm::updateCGAndAnalysisManagerForFunctionPass( +    LazyCallGraph &G, LazyCallGraph::SCC &InitialC, LazyCallGraph::Node &N, +    CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR) { +  using Node = LazyCallGraph::Node; +  using Edge = LazyCallGraph::Edge; +  using SCC = LazyCallGraph::SCC; +  using RefSCC = LazyCallGraph::RefSCC; + +  RefSCC &InitialRC = InitialC.getOuterRefSCC(); +  SCC *C = &InitialC; +  RefSCC *RC = &InitialRC; +  Function &F = N.getFunction(); + +  // Walk the function body and build up the set of retained, promoted, and +  // demoted edges. +  SmallVector<Constant *, 16> Worklist; +  SmallPtrSet<Constant *, 16> Visited; +  SmallPtrSet<Node *, 16> RetainedEdges; +  SmallSetVector<Node *, 4> PromotedRefTargets; +  SmallSetVector<Node *, 4> DemotedCallTargets; + +  // First walk the function and handle all called functions. We do this first +  // because if there is a single call edge, whether there are ref edges is +  // irrelevant. +  for (Instruction &I : instructions(F)) +    if (auto CS = CallSite(&I)) +      if (Function *Callee = CS.getCalledFunction()) +        if (Visited.insert(Callee).second && !Callee->isDeclaration()) { +          Node &CalleeN = *G.lookup(*Callee); +          Edge *E = N->lookup(CalleeN); +          // FIXME: We should really handle adding new calls. While it will +          // make downstream usage more complex, there is no fundamental +          // limitation and it will allow passes within the CGSCC to be a bit +          // more flexible in what transforms they can do. Until then, we +          // verify that new calls haven't been introduced. 
+          assert(E && "No function transformations should introduce *new* " +                      "call edges! Any new calls should be modeled as " +                      "promoted existing ref edges!"); +          bool Inserted = RetainedEdges.insert(&CalleeN).second; +          (void)Inserted; +          assert(Inserted && "We should never visit a function twice."); +          if (!E->isCall()) +            PromotedRefTargets.insert(&CalleeN); +        } + +  // Now walk all references. +  for (Instruction &I : instructions(F)) +    for (Value *Op : I.operand_values()) +      if (auto *C = dyn_cast<Constant>(Op)) +        if (Visited.insert(C).second) +          Worklist.push_back(C); + +  auto VisitRef = [&](Function &Referee) { +    Node &RefereeN = *G.lookup(Referee); +    Edge *E = N->lookup(RefereeN); +    // FIXME: Similarly to new calls, we also currently preclude +    // introducing new references. See above for details. +    assert(E && "No function transformations should introduce *new* ref " +                "edges! Any new ref edges would require IPO which " +                "function passes aren't allowed to do!"); +    bool Inserted = RetainedEdges.insert(&RefereeN).second; +    (void)Inserted; +    assert(Inserted && "We should never visit a function twice."); +    if (E->isCall()) +      DemotedCallTargets.insert(&RefereeN); +  }; +  LazyCallGraph::visitReferences(Worklist, Visited, VisitRef); + +  // Include synthetic reference edges to known, defined lib functions. +  for (auto *F : G.getLibFunctions()) +    // While the list of lib functions doesn't have repeats, don't re-visit +    // anything handled above. +    if (!Visited.count(F)) +      VisitRef(*F); + +  // First remove all of the edges that are no longer present in this function. +  // The first step makes these edges uniformly ref edges and accumulates them +  // into a separate data structure so removal doesn't invalidate anything. +  SmallVector<Node *, 4> DeadTargets; +  for (Edge &E : *N) { +    if (RetainedEdges.count(&E.getNode())) +      continue; + +    SCC &TargetC = *G.lookupSCC(E.getNode()); +    RefSCC &TargetRC = TargetC.getOuterRefSCC(); +    if (&TargetRC == RC && E.isCall()) { +      if (C != &TargetC) { +        // For separate SCCs this is trivial. +        RC->switchTrivialInternalEdgeToRef(N, E.getNode()); +      } else { +        // Now update the call graph. +        C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, E.getNode()), +                                   G, N, C, AM, UR); +      } +    } + +    // Now that this is ready for actual removal, put it into our list. +    DeadTargets.push_back(&E.getNode()); +  } +  // Remove the easy cases quickly and actually pull them out of our list. +  DeadTargets.erase( +      llvm::remove_if(DeadTargets, +                      [&](Node *TargetN) { +                        SCC &TargetC = *G.lookupSCC(*TargetN); +                        RefSCC &TargetRC = TargetC.getOuterRefSCC(); + +                        // We can't trivially remove internal targets, so skip +                        // those. 
+                        if (&TargetRC == RC) +                          return false; + +                        RC->removeOutgoingEdge(N, *TargetN); +                        LLVM_DEBUG(dbgs() << "Deleting outgoing edge from '" +                                          << N << "' to '" << TargetN << "'\n"); +                        return true; +                      }), +      DeadTargets.end()); + +  // Now do a batch removal of the internal ref edges left. +  auto NewRefSCCs = RC->removeInternalRefEdge(N, DeadTargets); +  if (!NewRefSCCs.empty()) { +    // The old RefSCC is dead, mark it as such. +    UR.InvalidatedRefSCCs.insert(RC); + +    // Note that we don't bother to invalidate analyses as ref-edge +    // connectivity is not really observable in any way and is intended +    // exclusively to be used for ordering of transforms rather than for +    // analysis conclusions. + +    // Update RC to the "bottom". +    assert(G.lookupSCC(N) == C && "Changed the SCC when splitting RefSCCs!"); +    RC = &C->getOuterRefSCC(); +    assert(G.lookupRefSCC(N) == RC && "Failed to update current RefSCC!"); + +    // The RC worklist is in reverse postorder, so we enqueue the new ones in +    // RPO except for the one which contains the source node as that is the +    // "bottom" we will continue processing in the bottom-up walk. +    assert(NewRefSCCs.front() == RC && +           "New current RefSCC not first in the returned list!"); +    for (RefSCC *NewRC : llvm::reverse(make_range(std::next(NewRefSCCs.begin()), +                                                  NewRefSCCs.end()))) { +      assert(NewRC != RC && "Should not encounter the current RefSCC further " +                            "in the postorder list of new RefSCCs."); +      UR.RCWorklist.insert(NewRC); +      LLVM_DEBUG(dbgs() << "Enqueuing a new RefSCC in the update worklist: " +                        << *NewRC << "\n"); +    } +  } + +  // Next demote all the call edges that are now ref edges. This helps make +  // the SCCs small which should minimize the work below as we don't want to +  // form cycles that this would break. +  for (Node *RefTarget : DemotedCallTargets) { +    SCC &TargetC = *G.lookupSCC(*RefTarget); +    RefSCC &TargetRC = TargetC.getOuterRefSCC(); + +    // The easy case is when the target RefSCC is not this RefSCC. This is +    // only supported when the target RefSCC is a child of this RefSCC. +    if (&TargetRC != RC) { +      assert(RC->isAncestorOf(TargetRC) && +             "Cannot potentially form RefSCC cycles here!"); +      RC->switchOutgoingEdgeToRef(N, *RefTarget); +      LLVM_DEBUG(dbgs() << "Switch outgoing call edge to a ref edge from '" << N +                        << "' to '" << *RefTarget << "'\n"); +      continue; +    } + +    // We are switching an internal call edge to a ref edge. This may split up +    // some SCCs. +    if (C != &TargetC) { +      // For separate SCCs this is trivial. +      RC->switchTrivialInternalEdgeToRef(N, *RefTarget); +      continue; +    } + +    // Now update the call graph. +    C = incorporateNewSCCRange(RC->switchInternalEdgeToRef(N, *RefTarget), G, N, +                               C, AM, UR); +  } + +  // Now promote ref edges into call edges. +  for (Node *CallTarget : PromotedRefTargets) { +    SCC &TargetC = *G.lookupSCC(*CallTarget); +    RefSCC &TargetRC = TargetC.getOuterRefSCC(); + +    // The easy case is when the target RefSCC is not this RefSCC. This is +    // only supported when the target RefSCC is a child of this RefSCC. 
+    if (&TargetRC != RC) {
+      assert(RC->isAncestorOf(TargetRC) &&
+             "Cannot potentially form RefSCC cycles here!");
+      RC->switchOutgoingEdgeToCall(N, *CallTarget);
+      LLVM_DEBUG(dbgs() << "Switch outgoing ref edge to a call edge from '" << N
+                        << "' to '" << *CallTarget << "'\n");
+      continue;
+    }
+    LLVM_DEBUG(dbgs() << "Switch an internal ref edge to a call edge from '"
+                      << N << "' to '" << *CallTarget << "'\n");
+
+    // Otherwise we are switching an internal ref edge to a call edge. This
+    // may merge away some SCCs, and we add those to the UpdateResult. We also
+    // need to make sure to update the worklist in the event SCCs have moved
+    // before the current one in the post-order sequence.
+    bool HasFunctionAnalysisProxy = false;
+    auto InitialSCCIndex = RC->find(*C) - RC->begin();
+    bool FormedCycle = RC->switchInternalEdgeToCall(
+        N, *CallTarget, [&](ArrayRef<SCC *> MergedSCCs) {
+          for (SCC *MergedC : MergedSCCs) {
+            assert(MergedC != &TargetC && "Cannot merge away the target SCC!");
+
+            HasFunctionAnalysisProxy |=
+                AM.getCachedResult<FunctionAnalysisManagerCGSCCProxy>(
+                    *MergedC) != nullptr;
+
+            // Mark that this SCC will no longer be valid.
+            UR.InvalidatedSCCs.insert(MergedC);
+
+            // FIXME: We should really do a 'clear' here to forcibly release
+            // memory, but we don't have a good way of doing that and
+            // preserving the function analyses.
+            auto PA = PreservedAnalyses::allInSet<AllAnalysesOn<Function>>();
+            PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+            AM.invalidate(*MergedC, PA);
+          }
+        });
+
+    // If we formed a cycle by creating this call, we need to update more data
+    // structures.
+    if (FormedCycle) {
+      C = &TargetC;
+      assert(G.lookupSCC(N) == C && "Failed to update current SCC!");
+
+      // If one of the invalidated SCCs had a cached proxy to a function
+      // analysis manager, we need to create a proxy in the new current SCC as
+      // the invalidated SCCs had their functions moved.
+      if (HasFunctionAnalysisProxy)
+        AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, G);
+
+      // Any analyses cached for this SCC are no longer precise as the shape
+      // has changed by introducing this cycle. However, we have taken care to
+      // update the proxies so they remain valid.
+      auto PA = PreservedAnalyses::allInSet<AllAnalysesOn<Function>>();
+      PA.preserve<FunctionAnalysisManagerCGSCCProxy>();
+      AM.invalidate(*C, PA);
+    }
+    auto NewSCCIndex = RC->find(*C) - RC->begin();
+    // If we have actually moved an SCC to be topologically "below" the current
+    // one due to merging, we will need to revisit the current SCC after
+    // visiting those moved SCCs.
+    //
+    // It is critical that we *do not* revisit the current SCC unless we
+    // actually move SCCs in the process of merging because otherwise we may
+    // form a cycle where an SCC is split apart, merged, split, merged and so
+    // on infinitely.
+    if (InitialSCCIndex < NewSCCIndex) {
+      // Put our current SCC back onto the worklist as we'll visit other SCCs
+      // that are now definitively ordered prior to the current one in the
+      // post-order sequence, and may end up observing more precise context to
+      // optimize the current SCC.
+      UR.CWorklist.insert(C); +      LLVM_DEBUG(dbgs() << "Enqueuing the existing SCC in the worklist: " << *C +                        << "\n"); +      // Enqueue in reverse order as we pop off the back of the worklist. +      for (SCC &MovedC : llvm::reverse(make_range(RC->begin() + InitialSCCIndex, +                                                  RC->begin() + NewSCCIndex))) { +        UR.CWorklist.insert(&MovedC); +        LLVM_DEBUG(dbgs() << "Enqueuing a newly earlier in post-order SCC: " +                          << MovedC << "\n"); +      } +    } +  } + +  assert(!UR.InvalidatedSCCs.count(C) && "Invalidated the current SCC!"); +  assert(!UR.InvalidatedRefSCCs.count(RC) && "Invalidated the current RefSCC!"); +  assert(&C->getOuterRefSCC() == RC && "Current SCC not in current RefSCC!"); + +  // Record the current RefSCC and SCC for higher layers of the CGSCC pass +  // manager now that all the updates have been applied. +  if (RC != &InitialRC) +    UR.UpdatedRC = RC; +  if (C != &InitialC) +    UR.UpdatedC = C; + +  return *C; +} diff --git a/contrib/llvm/lib/Analysis/CallGraph.cpp b/contrib/llvm/lib/Analysis/CallGraph.cpp new file mode 100644 index 000000000000..cbdf5f63c557 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CallGraph.cpp @@ -0,0 +1,329 @@ +//===- CallGraph.cpp - Build a Module's call graph ------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CallGraph.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> + +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Implementations of the CallGraph class methods. +// + +CallGraph::CallGraph(Module &M) +    : M(M), ExternalCallingNode(getOrInsertFunction(nullptr)), +      CallsExternalNode(llvm::make_unique<CallGraphNode>(nullptr)) { +  // Add every function to the call graph. +  for (Function &F : M) +    addToCallGraph(&F); +} + +CallGraph::CallGraph(CallGraph &&Arg) +    : M(Arg.M), FunctionMap(std::move(Arg.FunctionMap)), +      ExternalCallingNode(Arg.ExternalCallingNode), +      CallsExternalNode(std::move(Arg.CallsExternalNode)) { +  Arg.FunctionMap.clear(); +  Arg.ExternalCallingNode = nullptr; +} + +CallGraph::~CallGraph() { +  // CallsExternalNode is not in the function map, delete it explicitly. +  if (CallsExternalNode) +    CallsExternalNode->allReferencesDropped(); + +// Reset all node's use counts to zero before deleting them to prevent an +// assertion from firing. +#ifndef NDEBUG +  for (auto &I : FunctionMap) +    I.second->allReferencesDropped(); +#endif +} + +void CallGraph::addToCallGraph(Function *F) { +  CallGraphNode *Node = getOrInsertFunction(F); + +  // If this function has external linkage or has its address taken, anything +  // could call it. 
+  if (!F->hasLocalLinkage() || F->hasAddressTaken()) +    ExternalCallingNode->addCalledFunction(CallSite(), Node); + +  // If this function is not defined in this translation unit, it could call +  // anything. +  if (F->isDeclaration() && !F->isIntrinsic()) +    Node->addCalledFunction(CallSite(), CallsExternalNode.get()); + +  // Look for calls by this function. +  for (BasicBlock &BB : *F) +    for (Instruction &I : BB) { +      if (auto CS = CallSite(&I)) { +        const Function *Callee = CS.getCalledFunction(); +        if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) +          // Indirect calls of intrinsics are not allowed so no need to check. +          // We can be more precise here by using TargetArg returned by +          // Intrinsic::isLeaf. +          Node->addCalledFunction(CS, CallsExternalNode.get()); +        else if (!Callee->isIntrinsic()) +          Node->addCalledFunction(CS, getOrInsertFunction(Callee)); +      } +    } +} + +void CallGraph::print(raw_ostream &OS) const { +  // Print in a deterministic order by sorting CallGraphNodes by name.  We do +  // this here to avoid slowing down the non-printing fast path. + +  SmallVector<CallGraphNode *, 16> Nodes; +  Nodes.reserve(FunctionMap.size()); + +  for (const auto &I : *this) +    Nodes.push_back(I.second.get()); + +  llvm::sort(Nodes.begin(), Nodes.end(), +             [](CallGraphNode *LHS, CallGraphNode *RHS) { +    if (Function *LF = LHS->getFunction()) +      if (Function *RF = RHS->getFunction()) +        return LF->getName() < RF->getName(); + +    return RHS->getFunction() != nullptr; +  }); + +  for (CallGraphNode *CN : Nodes) +    CN->print(OS); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void CallGraph::dump() const { print(dbgs()); } +#endif + +// removeFunctionFromModule - Unlink the function from this module, returning +// it.  Because this removes the function from the module, the call graph node +// is destroyed.  This is only valid if the function does not call any other +// functions (ie, there are no edges in it's CGN).  The easiest way to do this +// is to dropAllReferences before calling this. +// +Function *CallGraph::removeFunctionFromModule(CallGraphNode *CGN) { +  assert(CGN->empty() && "Cannot remove function from call " +         "graph if it references other functions!"); +  Function *F = CGN->getFunction(); // Get the function for the call graph node +  FunctionMap.erase(F);             // Remove the call graph node from the map + +  M.getFunctionList().remove(F); +  return F; +} + +/// spliceFunction - Replace the function represented by this node by another. +/// This does not rescan the body of the function, so it is suitable when +/// splicing the body of the old function to the new while also updating all +/// callers from old to new. +void CallGraph::spliceFunction(const Function *From, const Function *To) { +  assert(FunctionMap.count(From) && "No CallGraphNode for function!"); +  assert(!FunctionMap.count(To) && +         "Pointing CallGraphNode at a function that already exists"); +  FunctionMapTy::iterator I = FunctionMap.find(From); +  I->second->F = const_cast<Function*>(To); +  FunctionMap[To] = std::move(I->second); +  FunctionMap.erase(I); +} + +// getOrInsertFunction - This method is identical to calling operator[], but +// it will insert a new CallGraphNode for the specified function if one does +// not already exist. 
+CallGraphNode *CallGraph::getOrInsertFunction(const Function *F) { +  auto &CGN = FunctionMap[F]; +  if (CGN) +    return CGN.get(); + +  assert((!F || F->getParent() == &M) && "Function not in current module!"); +  CGN = llvm::make_unique<CallGraphNode>(const_cast<Function *>(F)); +  return CGN.get(); +} + +//===----------------------------------------------------------------------===// +// Implementations of the CallGraphNode class methods. +// + +void CallGraphNode::print(raw_ostream &OS) const { +  if (Function *F = getFunction()) +    OS << "Call graph node for function: '" << F->getName() << "'"; +  else +    OS << "Call graph node <<null function>>"; + +  OS << "<<" << this << ">>  #uses=" << getNumReferences() << '\n'; + +  for (const auto &I : *this) { +    OS << "  CS<" << I.first << "> calls "; +    if (Function *FI = I.second->getFunction()) +      OS << "function '" << FI->getName() <<"'\n"; +    else +      OS << "external node\n"; +  } +  OS << '\n'; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void CallGraphNode::dump() const { print(dbgs()); } +#endif + +/// removeCallEdgeFor - This method removes the edge in the node for the +/// specified call site.  Note that this method takes linear time, so it +/// should be used sparingly. +void CallGraphNode::removeCallEdgeFor(CallSite CS) { +  for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { +    assert(I != CalledFunctions.end() && "Cannot find callsite to remove!"); +    if (I->first == CS.getInstruction()) { +      I->second->DropRef(); +      *I = CalledFunctions.back(); +      CalledFunctions.pop_back(); +      return; +    } +  } +} + +// removeAnyCallEdgeTo - This method removes any call edges from this node to +// the specified callee function.  This takes more time to execute than +// removeCallEdgeTo, so it should not be used unless necessary. +void CallGraphNode::removeAnyCallEdgeTo(CallGraphNode *Callee) { +  for (unsigned i = 0, e = CalledFunctions.size(); i != e; ++i) +    if (CalledFunctions[i].second == Callee) { +      Callee->DropRef(); +      CalledFunctions[i] = CalledFunctions.back(); +      CalledFunctions.pop_back(); +      --i; --e; +    } +} + +/// removeOneAbstractEdgeTo - Remove one edge associated with a null callsite +/// from this node to the specified callee function. +void CallGraphNode::removeOneAbstractEdgeTo(CallGraphNode *Callee) { +  for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { +    assert(I != CalledFunctions.end() && "Cannot find callee to remove!"); +    CallRecord &CR = *I; +    if (CR.second == Callee && CR.first == nullptr) { +      Callee->DropRef(); +      *I = CalledFunctions.back(); +      CalledFunctions.pop_back(); +      return; +    } +  } +} + +/// replaceCallEdge - This method replaces the edge in the node for the +/// specified call site with a new one.  Note that this method takes linear +/// time, so it should be used sparingly. +void CallGraphNode::replaceCallEdge(CallSite CS, +                                    CallSite NewCS, CallGraphNode *NewNode){ +  for (CalledFunctionsVector::iterator I = CalledFunctions.begin(); ; ++I) { +    assert(I != CalledFunctions.end() && "Cannot find callsite to remove!"); +    if (I->first == CS.getInstruction()) { +      I->second->DropRef(); +      I->first = NewCS.getInstruction(); +      I->second = NewNode; +      NewNode->AddRef(); +      return; +    } +  } +} + +// Provide an explicit template instantiation for the static ID. 
+AnalysisKey CallGraphAnalysis::Key; + +PreservedAnalyses CallGraphPrinterPass::run(Module &M, +                                            ModuleAnalysisManager &AM) { +  AM.getResult<CallGraphAnalysis>(M).print(OS); +  return PreservedAnalyses::all(); +} + +//===----------------------------------------------------------------------===// +// Out-of-line definitions of CallGraphAnalysis class members. +// + +//===----------------------------------------------------------------------===// +// Implementations of the CallGraphWrapperPass class methods. +// + +CallGraphWrapperPass::CallGraphWrapperPass() : ModulePass(ID) { +  initializeCallGraphWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +CallGraphWrapperPass::~CallGraphWrapperPass() = default; + +void CallGraphWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +} + +bool CallGraphWrapperPass::runOnModule(Module &M) { +  // All the real work is done in the constructor for the CallGraph. +  G.reset(new CallGraph(M)); +  return false; +} + +INITIALIZE_PASS(CallGraphWrapperPass, "basiccg", "CallGraph Construction", +                false, true) + +char CallGraphWrapperPass::ID = 0; + +void CallGraphWrapperPass::releaseMemory() { G.reset(); } + +void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { +  if (!G) { +    OS << "No call graph has been built!\n"; +    return; +  } + +  // Just delegate. +  G->print(OS); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD +void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); } +#endif + +namespace { + +struct CallGraphPrinterLegacyPass : public ModulePass { +  static char ID; // Pass ID, replacement for typeid + +  CallGraphPrinterLegacyPass() : ModulePass(ID) { +    initializeCallGraphPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +  } + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesAll(); +    AU.addRequiredTransitive<CallGraphWrapperPass>(); +  } + +  bool runOnModule(Module &M) override { +    getAnalysis<CallGraphWrapperPass>().print(errs(), &M); +    return false; +  } +}; + +} // end anonymous namespace + +char CallGraphPrinterLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(CallGraphPrinterLegacyPass, "print-callgraph", +                      "Print a call graph", true, true) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_END(CallGraphPrinterLegacyPass, "print-callgraph", +                    "Print a call graph", true, true) diff --git a/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp b/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp new file mode 100644 index 000000000000..4c33c420b65d --- /dev/null +++ b/contrib/llvm/lib/Analysis/CallGraphSCCPass.cpp @@ -0,0 +1,666 @@ +//===- CallGraphSCCPass.cpp - Pass that operates BU on call graph ---------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the CallGraphSCCPass class, which is used for passes +// which are implemented as bottom-up traversals on the call graph.  Because +// there may be cycles in the call graph, passes of this type operate on the +// call-graph in SCC order: that is, they process function bottom-up, except for +// recursive functions, which they process all at once. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManagers.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/OptBisect.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <string> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "cgscc-passmgr" + +static cl::opt<unsigned> +MaxIterations("max-cg-scc-iterations", cl::ReallyHidden, cl::init(4)); + +STATISTIC(MaxSCCIterations, "Maximum CGSCCPassMgr iterations on one SCC"); + +//===----------------------------------------------------------------------===// +// CGPassManager +// +/// CGPassManager manages FPPassManagers and CallGraphSCCPasses. + +namespace { + +class CGPassManager : public ModulePass, public PMDataManager { +public: +  static char ID; + +  explicit CGPassManager() : ModulePass(ID), PMDataManager() {} + +  /// Execute all of the passes scheduled for execution.  Keep track of +  /// whether any of the passes modifies the module, and if so, return true. +  bool runOnModule(Module &M) override; + +  using ModulePass::doInitialization; +  using ModulePass::doFinalization; + +  bool doInitialization(CallGraph &CG); +  bool doFinalization(CallGraph &CG); + +  /// Pass Manager itself does not invalidate any analysis info. +  void getAnalysisUsage(AnalysisUsage &Info) const override { +    // CGPassManager walks SCC and it needs CallGraph. +    Info.addRequired<CallGraphWrapperPass>(); +    Info.setPreservesAll(); +  } + +  StringRef getPassName() const override { return "CallGraph Pass Manager"; } + +  PMDataManager *getAsPMDataManager() override { return this; } +  Pass *getAsPass() override { return this; } + +  // Print passes managed by this manager +  void dumpPassStructure(unsigned Offset) override { +    errs().indent(Offset*2) << "Call Graph SCC Pass Manager\n"; +    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +      Pass *P = getContainedPass(Index); +      P->dumpPassStructure(Offset + 1); +      dumpLastUses(P, Offset+1); +    } +  } + +  Pass *getContainedPass(unsigned N) { +    assert(N < PassVector.size() && "Pass number out of range!"); +    return static_cast<Pass *>(PassVector[N]); +  } + +  PassManagerType getPassManagerType() const override { +    return PMT_CallGraphPassManager; +  } + +private: +  bool RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG, +                         bool &DevirtualizedCall); + +  bool RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC, +                    CallGraph &CG, bool &CallGraphUpToDate, +                    bool &DevirtualizedCall); +  bool RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG, +                        bool IsCheckingMode); +}; + +} // end anonymous namespace. 
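The pass manager defined above drives CallGraphSCCPasses over the module in exactly the bottom-up SCC order the file header describes. As a minimal, illustrative sketch (not part of this diff) of that traversal, assuming only the CallGraph and scc_iterator facilities already included in this file, a client could walk a module's call graph like so; the real CGPassManager additionally tracks whether passes changed anything and re-runs an SCC up to the -max-cg-scc-iterations limit declared above.

    // Editorial sketch only; assumes an LLVM build environment.
    #include "llvm/ADT/SCCIterator.h"
    #include "llvm/Analysis/CallGraph.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/Module.h"
    #include "llvm/Support/raw_ostream.h"
    #include <vector>

    using namespace llvm;

    // Visit the SCCs of M's call graph in post-order: every callee SCC is seen
    // before the SCCs that call into it, and mutually recursive functions
    // arrive together as a single multi-node SCC.
    static void visitSCCsBottomUp(Module &M) {
      CallGraph CG(M);
      for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) {
        const std::vector<CallGraphNode *> &SCC = *I;
        errs() << "SCC with " << SCC.size() << " node(s):";
        for (CallGraphNode *CGN : SCC) {
          if (Function *F = CGN->getFunction())
            errs() << ' ' << F->getName();
          else
            errs() << " <external node>"; // external-calling / calls-external node
        }
        errs() << '\n';
      }
    }

Visiting callees before callers is what lets a bottom-up pass assume that whatever it computed or simplified for a callee is already in place when the caller's SCC is processed; mutually recursive functions cannot be ordered this way, so they are handed to the pass together as one SCC.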
+ +char CGPassManager::ID = 0; + +bool CGPassManager::RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC, +                                 CallGraph &CG, bool &CallGraphUpToDate, +                                 bool &DevirtualizedCall) { +  bool Changed = false; +  PMDataManager *PM = P->getAsPMDataManager(); +  Module &M = CG.getModule(); + +  if (!PM) { +    CallGraphSCCPass *CGSP = (CallGraphSCCPass*)P; +    if (!CallGraphUpToDate) { +      DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false); +      CallGraphUpToDate = true; +    } + +    { +      unsigned InstrCount = 0; +      bool EmitICRemark = M.shouldEmitInstrCountChangedRemark(); +      TimeRegion PassTimer(getPassTimer(CGSP)); +      if (EmitICRemark) +        InstrCount = initSizeRemarkInfo(M); +      Changed = CGSP->runOnSCC(CurSCC); + +      // If the pass modified the module, it may have modified the instruction +      // count of the module. Try emitting a remark. +      if (EmitICRemark) +        emitInstrCountChangedRemark(P, M, InstrCount); +    } + +    // After the CGSCCPass is done, when assertions are enabled, use +    // RefreshCallGraph to verify that the callgraph was correctly updated. +#ifndef NDEBUG +    if (Changed) +      RefreshCallGraph(CurSCC, CG, true); +#endif + +    return Changed; +  } + +  assert(PM->getPassManagerType() == PMT_FunctionPassManager && +         "Invalid CGPassManager member"); +  FPPassManager *FPP = (FPPassManager*)P; + +  // Run pass P on all functions in the current SCC. +  for (CallGraphNode *CGN : CurSCC) { +    if (Function *F = CGN->getFunction()) { +      dumpPassInfo(P, EXECUTION_MSG, ON_FUNCTION_MSG, F->getName()); +      { +        TimeRegion PassTimer(getPassTimer(FPP)); +        Changed |= FPP->runOnFunction(*F); +      } +      F->getContext().yield(); +    } +  } + +  // The function pass(es) modified the IR, they may have clobbered the +  // callgraph. +  if (Changed && CallGraphUpToDate) { +    LLVM_DEBUG(dbgs() << "CGSCCPASSMGR: Pass Dirtied SCC: " << P->getPassName() +                      << '\n'); +    CallGraphUpToDate = false; +  } +  return Changed; +} + +/// Scan the functions in the specified CFG and resync the +/// callgraph with the call sites found in it.  This is used after +/// FunctionPasses have potentially munged the callgraph, and can be used after +/// CallGraphSCC passes to verify that they correctly updated the callgraph. +/// +/// This function returns true if it devirtualized an existing function call, +/// meaning it turned an indirect call into a direct call.  This happens when +/// a function pass like GVN optimizes away stuff feeding the indirect call. +/// This never happens in checking mode. +bool CGPassManager::RefreshCallGraph(const CallGraphSCC &CurSCC, CallGraph &CG, +                                     bool CheckingMode) { +  DenseMap<Value*, CallGraphNode*> CallSites; + +  LLVM_DEBUG(dbgs() << "CGSCCPASSMGR: Refreshing SCC with " << CurSCC.size() +                    << " nodes:\n"; +             for (CallGraphNode *CGN +                  : CurSCC) CGN->dump();); + +  bool MadeChange = false; +  bool DevirtualizedCall = false; + +  // Scan all functions in the SCC. +  unsigned FunctionNo = 0; +  for (CallGraphSCC::iterator SCCIdx = CurSCC.begin(), E = CurSCC.end(); +       SCCIdx != E; ++SCCIdx, ++FunctionNo) { +    CallGraphNode *CGN = *SCCIdx; +    Function *F = CGN->getFunction(); +    if (!F || F->isDeclaration()) continue; + +    // Walk the function body looking for call sites.  
Sync up the call sites in
+    // CGN with those actually in the function.
+
+    // Keep track of the number of direct and indirect calls that were
+    // invalidated and removed.
+    unsigned NumDirectRemoved = 0, NumIndirectRemoved = 0;
+
+    // Get the set of call sites currently in the function.
+    for (CallGraphNode::iterator I = CGN->begin(), E = CGN->end(); I != E; ) {
+      // If this call site is null, then the function pass deleted the call
+      // entirely and the WeakTrackingVH nulled it out.
+      if (!I->first ||
+          // If we've already seen this call site, then the FunctionPass RAUW'd
+          // one call with another, which resulted in two "uses" in the edge
+          // list of the same call.
+          CallSites.count(I->first) ||
+
+          // If the call edge is not from a call or invoke, or it is an
+          // intrinsic call, then the function pass RAUW'd a call with
+          // another value. This can happen when, for example, a call to a
+          // well-known function is constant folded.
+          !CallSite(I->first) ||
+          (CallSite(I->first).getCalledFunction() &&
+           CallSite(I->first).getCalledFunction()->isIntrinsic() &&
+           Intrinsic::isLeaf(
+               CallSite(I->first).getCalledFunction()->getIntrinsicID()))) {
+        assert(!CheckingMode &&
+               "CallGraphSCCPass did not update the CallGraph correctly!");
+
+        // If this was an indirect call site, count it.
+        if (!I->second->getFunction())
+          ++NumIndirectRemoved;
+        else
+          ++NumDirectRemoved;
+
+        // Just remove the edge from the set of callees, keeping track of
+        // whether I points to the last element of the vector.
+        bool WasLast = I + 1 == E;
+        CGN->removeCallEdge(I);
+
+        // If I pointed to the last element of the vector, we have to bail out:
+        // iterator checking rejects comparisons of the resultant pointer with
+        // end.
+        if (WasLast)
+          break;
+        E = CGN->end();
+        continue;
+      }
+
+      assert(!CallSites.count(I->first) &&
+             "Call site occurs in node multiple times");
+
+      CallSite CS(I->first);
+      if (CS) {
+        Function *Callee = CS.getCalledFunction();
+        // Ignore intrinsics because they're not really function calls.
+        if (!Callee || !(Callee->isIntrinsic()))
+          CallSites.insert(std::make_pair(I->first, I->second));
+      }
+      ++I;
+    }
+
+    // Loop over all of the instructions in the function, getting the callsites.
+    // Keep track of the number of direct/indirect calls added.
+    unsigned NumDirectAdded = 0, NumIndirectAdded = 0;
+
+    for (BasicBlock &BB : *F)
+      for (Instruction &I : BB) {
+        CallSite CS(&I);
+        if (!CS) continue;
+        Function *Callee = CS.getCalledFunction();
+        if (Callee && Callee->isIntrinsic()) continue;
+
+        // If this call site already existed in the callgraph, just verify it
+        // matches up to expectations and remove it from CallSites.
+        DenseMap<Value*, CallGraphNode*>::iterator ExistingIt =
+          CallSites.find(CS.getInstruction());
+        if (ExistingIt != CallSites.end()) {
+          CallGraphNode *ExistingNode = ExistingIt->second;
+
+          // Remove from CallSites since we have now seen it.
+          CallSites.erase(ExistingIt);
+
+          // Verify that the callee is right.
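+          // (That is, check that the node this edge points at still matches
+          // the call's current static callee; if it does not, the edge is
+          // retargeted below - outside of checking mode - and MadeChange is
+          // set.)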
+          if (ExistingNode->getFunction() == CS.getCalledFunction()) +            continue; + +          // If we are in checking mode, we are not allowed to actually mutate +          // the callgraph.  If this is a case where we can infer that the +          // callgraph is less precise than it could be (e.g. an indirect call +          // site could be turned direct), don't reject it in checking mode, and +          // don't tweak it to be more precise. +          if (CheckingMode && CS.getCalledFunction() && +              ExistingNode->getFunction() == nullptr) +            continue; + +          assert(!CheckingMode && +                 "CallGraphSCCPass did not update the CallGraph correctly!"); + +          // If not, we either went from a direct call to indirect, indirect to +          // direct, or direct to different direct. +          CallGraphNode *CalleeNode; +          if (Function *Callee = CS.getCalledFunction()) { +            CalleeNode = CG.getOrInsertFunction(Callee); +            // Keep track of whether we turned an indirect call into a direct +            // one. +            if (!ExistingNode->getFunction()) { +              DevirtualizedCall = true; +              LLVM_DEBUG(dbgs() << "  CGSCCPASSMGR: Devirtualized call to '" +                                << Callee->getName() << "'\n"); +            } +          } else { +            CalleeNode = CG.getCallsExternalNode(); +          } + +          // Update the edge target in CGN. +          CGN->replaceCallEdge(CS, CS, CalleeNode); +          MadeChange = true; +          continue; +        } + +        assert(!CheckingMode && +               "CallGraphSCCPass did not update the CallGraph correctly!"); + +        // If the call site didn't exist in the CGN yet, add it. +        CallGraphNode *CalleeNode; +        if (Function *Callee = CS.getCalledFunction()) { +          CalleeNode = CG.getOrInsertFunction(Callee); +          ++NumDirectAdded; +        } else { +          CalleeNode = CG.getCallsExternalNode(); +          ++NumIndirectAdded; +        } + +        CGN->addCalledFunction(CS, CalleeNode); +        MadeChange = true; +      } + +    // We scanned the old callgraph node, removing invalidated call sites and +    // then added back newly found call sites.  One thing that can happen is +    // that an old indirect call site was deleted and replaced with a new direct +    // call.  In this case, we have devirtualized a call, and CGSCCPM would like +    // to iteratively optimize the new code.  Unfortunately, we don't really +    // have a great way to detect when this happens.  As an approximation, we +    // just look at whether the number of indirect calls is reduced and the +    // number of direct calls is increased.  There are tons of ways to fool this +    // (e.g. DCE'ing an indirect call and duplicating an unrelated block with a +    // direct call) but this is close enough. +    if (NumIndirectRemoved > NumIndirectAdded && +        NumDirectRemoved < NumDirectAdded) +      DevirtualizedCall = true; + +    // After scanning this function, if we still have entries in callsites, then +    // they are dangling pointers.  WeakTrackingVH should save us for this, so +    // abort if +    // this happens. +    assert(CallSites.empty() && "Dangling pointers found in call sites map"); + +    // Periodically do an explicit clear to remove tombstones when processing +    // large scc's. 
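+    // (DenseMap::erase leaves tombstone buckets behind; on very large SCCs
+    // they accumulate and slow down probing. The map is expected to be empty
+    // here anyway - see the assert above - so the clear below mainly flushes
+    // those tombstones.)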
+    if ((FunctionNo & 15) == 15) +      CallSites.clear(); +  } + +  LLVM_DEBUG(if (MadeChange) { +    dbgs() << "CGSCCPASSMGR: Refreshed SCC is now:\n"; +    for (CallGraphNode *CGN : CurSCC) +      CGN->dump(); +    if (DevirtualizedCall) +      dbgs() << "CGSCCPASSMGR: Refresh devirtualized a call!\n"; +  } else { +    dbgs() << "CGSCCPASSMGR: SCC Refresh didn't change call graph.\n"; +  }); +  (void)MadeChange; + +  return DevirtualizedCall; +} + +/// Execute the body of the entire pass manager on the specified SCC. +/// This keeps track of whether a function pass devirtualizes +/// any calls and returns it in DevirtualizedCall. +bool CGPassManager::RunAllPassesOnSCC(CallGraphSCC &CurSCC, CallGraph &CG, +                                      bool &DevirtualizedCall) { +  bool Changed = false; + +  // Keep track of whether the callgraph is known to be up-to-date or not. +  // The CGSSC pass manager runs two types of passes: +  // CallGraphSCC Passes and other random function passes.  Because other +  // random function passes are not CallGraph aware, they may clobber the +  // call graph by introducing new calls or deleting other ones.  This flag +  // is set to false when we run a function pass so that we know to clean up +  // the callgraph when we need to run a CGSCCPass again. +  bool CallGraphUpToDate = true; + +  // Run all passes on current SCC. +  for (unsigned PassNo = 0, e = getNumContainedPasses(); +       PassNo != e; ++PassNo) { +    Pass *P = getContainedPass(PassNo); + +    // If we're in -debug-pass=Executions mode, construct the SCC node list, +    // otherwise avoid constructing this string as it is expensive. +    if (isPassDebuggingExecutionsOrMore()) { +      std::string Functions; +  #ifndef NDEBUG +      raw_string_ostream OS(Functions); +      for (CallGraphSCC::iterator I = CurSCC.begin(), E = CurSCC.end(); +           I != E; ++I) { +        if (I != CurSCC.begin()) OS << ", "; +        (*I)->print(OS); +      } +      OS.flush(); +  #endif +      dumpPassInfo(P, EXECUTION_MSG, ON_CG_MSG, Functions); +    } +    dumpRequiredSet(P); + +    initializeAnalysisImpl(P); + +    // Actually run this pass on the current SCC. +    Changed |= RunPassOnSCC(P, CurSCC, CG, +                            CallGraphUpToDate, DevirtualizedCall); + +    if (Changed) +      dumpPassInfo(P, MODIFICATION_MSG, ON_CG_MSG, ""); +    dumpPreservedSet(P); + +    verifyPreservedAnalysis(P); +    removeNotPreservedAnalysis(P); +    recordAvailableAnalysis(P); +    removeDeadPasses(P, "", ON_CG_MSG); +  } + +  // If the callgraph was left out of date (because the last pass run was a +  // functionpass), refresh it before we move on to the next SCC. +  if (!CallGraphUpToDate) +    DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false); +  return Changed; +} + +/// Execute all of the passes scheduled for execution.  Keep track of +/// whether any of the passes modifies the module, and if so, return true. +bool CGPassManager::runOnModule(Module &M) { +  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); +  bool Changed = doInitialization(CG); + +  // Walk the callgraph in bottom-up SCC order. +  scc_iterator<CallGraph*> CGI = scc_begin(&CG); + +  CallGraphSCC CurSCC(CG, &CGI); +  while (!CGI.isAtEnd()) { +    // Copy the current SCC and increment past it so that the pass can hack +    // on the SCC if it wants to without invalidating our iterator. 
+    const std::vector<CallGraphNode *> &NodeVec = *CGI; +    CurSCC.initialize(NodeVec); +    ++CGI; + +    // At the top level, we run all the passes in this pass manager on the +    // functions in this SCC.  However, we support iterative compilation in the +    // case where a function pass devirtualizes a call to a function.  For +    // example, it is very common for a function pass (often GVN or instcombine) +    // to eliminate the addressing that feeds into a call.  With that improved +    // information, we would like the call to be an inline candidate, infer +    // mod-ref information etc. +    // +    // Because of this, we allow iteration up to a specified iteration count. +    // This only happens in the case of a devirtualized call, so we only burn +    // compile time in the case that we're making progress.  We also have a hard +    // iteration count limit in case there is crazy code. +    unsigned Iteration = 0; +    bool DevirtualizedCall = false; +    do { +      LLVM_DEBUG(if (Iteration) dbgs() +                 << "  SCCPASSMGR: Re-visiting SCC, iteration #" << Iteration +                 << '\n'); +      DevirtualizedCall = false; +      Changed |= RunAllPassesOnSCC(CurSCC, CG, DevirtualizedCall); +    } while (Iteration++ < MaxIterations && DevirtualizedCall); + +    if (DevirtualizedCall) +      LLVM_DEBUG(dbgs() << "  CGSCCPASSMGR: Stopped iteration after " +                        << Iteration +                        << " times, due to -max-cg-scc-iterations\n"); + +    MaxSCCIterations.updateMax(Iteration); +  } +  Changed |= doFinalization(CG); +  return Changed; +} + +/// Initialize CG +bool CGPassManager::doInitialization(CallGraph &CG) { +  bool Changed = false; +  for (unsigned i = 0, e = getNumContainedPasses(); i != e; ++i) { +    if (PMDataManager *PM = getContainedPass(i)->getAsPMDataManager()) { +      assert(PM->getPassManagerType() == PMT_FunctionPassManager && +             "Invalid CGPassManager member"); +      Changed |= ((FPPassManager*)PM)->doInitialization(CG.getModule()); +    } else { +      Changed |= ((CallGraphSCCPass*)getContainedPass(i))->doInitialization(CG); +    } +  } +  return Changed; +} + +/// Finalize CG +bool CGPassManager::doFinalization(CallGraph &CG) { +  bool Changed = false; +  for (unsigned i = 0, e = getNumContainedPasses(); i != e; ++i) { +    if (PMDataManager *PM = getContainedPass(i)->getAsPMDataManager()) { +      assert(PM->getPassManagerType() == PMT_FunctionPassManager && +             "Invalid CGPassManager member"); +      Changed |= ((FPPassManager*)PM)->doFinalization(CG.getModule()); +    } else { +      Changed |= ((CallGraphSCCPass*)getContainedPass(i))->doFinalization(CG); +    } +  } +  return Changed; +} + +//===----------------------------------------------------------------------===// +// CallGraphSCC Implementation +//===----------------------------------------------------------------------===// + +/// This informs the SCC and the pass manager that the specified +/// Old node has been deleted, and New is to be used in its place. +void CallGraphSCC::ReplaceNode(CallGraphNode *Old, CallGraphNode *New) { +  assert(Old != New && "Should not replace node with self"); +  for (unsigned i = 0; ; ++i) { +    assert(i != Nodes.size() && "Node not in SCC"); +    if (Nodes[i] != Old) continue; +    Nodes[i] = New; +    break; +  } + +  // Update the active scc_iterator so that it doesn't contain dangling +  // pointers to the old CallGraphNode. 
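+  // (Context stashes the scc_iterator that produced this SCC - see
+  // CGPassManager::runOnModule, which constructs CallGraphSCC with &CGI - so
+  // the replacement can be mirrored into the iterator's internal node list.)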
+  scc_iterator<CallGraph*> *CGI = (scc_iterator<CallGraph*>*)Context; +  CGI->ReplaceNode(Old, New); +} + +//===----------------------------------------------------------------------===// +// CallGraphSCCPass Implementation +//===----------------------------------------------------------------------===// + +/// Assign pass manager to manage this pass. +void CallGraphSCCPass::assignPassManager(PMStack &PMS, +                                         PassManagerType PreferredType) { +  // Find CGPassManager +  while (!PMS.empty() && +         PMS.top()->getPassManagerType() > PMT_CallGraphPassManager) +    PMS.pop(); + +  assert(!PMS.empty() && "Unable to handle Call Graph Pass"); +  CGPassManager *CGP; + +  if (PMS.top()->getPassManagerType() == PMT_CallGraphPassManager) +    CGP = (CGPassManager*)PMS.top(); +  else { +    // Create new Call Graph SCC Pass Manager if it does not exist. +    assert(!PMS.empty() && "Unable to create Call Graph Pass Manager"); +    PMDataManager *PMD = PMS.top(); + +    // [1] Create new Call Graph Pass Manager +    CGP = new CGPassManager(); + +    // [2] Set up new manager's top level manager +    PMTopLevelManager *TPM = PMD->getTopLevelManager(); +    TPM->addIndirectPassManager(CGP); + +    // [3] Assign manager to manage this new manager. This may create +    // and push new managers into PMS +    Pass *P = CGP; +    TPM->schedulePass(P); + +    // [4] Push new manager into PMS +    PMS.push(CGP); +  } + +  CGP->add(this); +} + +/// For this class, we declare that we require and preserve the call graph. +/// If the derived class implements this method, it should +/// always explicitly call the implementation here. +void CallGraphSCCPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.addRequired<CallGraphWrapperPass>(); +  AU.addPreserved<CallGraphWrapperPass>(); +} + +//===----------------------------------------------------------------------===// +// PrintCallGraphPass Implementation +//===----------------------------------------------------------------------===// + +namespace { + +  /// PrintCallGraphPass - Print a Module corresponding to a call graph. +  /// +  class PrintCallGraphPass : public CallGraphSCCPass { +    std::string Banner; +    raw_ostream &OS;       // raw_ostream to print on. + +  public: +    static char ID; + +    PrintCallGraphPass(const std::string &B, raw_ostream &OS) +      : CallGraphSCCPass(ID), Banner(B), OS(OS) {} + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } + +    bool runOnSCC(CallGraphSCC &SCC) override { +      bool BannerPrinted = false; +      auto PrintBannerOnce = [&] () { +        if (BannerPrinted) +          return; +        OS << Banner; +        BannerPrinted = true; +        }; +      for (CallGraphNode *CGN : SCC) { +        if (Function *F = CGN->getFunction()) { +          if (!F->isDeclaration() && isFunctionInPrintList(F->getName())) { +            PrintBannerOnce(); +            F->print(OS); +          } +        } else if (isFunctionInPrintList("*")) { +          PrintBannerOnce(); +          OS << "\nPrinting <null> Function\n"; +        } +      } +      return false; +    } + +    StringRef getPassName() const override { return "Print CallGraph IR"; } +  }; + +} // end anonymous namespace. 
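+
+// A minimal usage sketch (illustrative only; "M" is assumed to be an existing
+// Module and llvm/IR/LegacyPassManager.h would be needed): the printer pass
+// above is an ordinary CallGraphSCCPass, so a client schedules it through the
+// legacy pass manager like any other transform.
+//
+//   legacy::PassManager PM;
+//   PM.add(new PrintCallGraphPass("; call graph IR:\n", errs()));
+//   PM.run(M);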
+ +char PrintCallGraphPass::ID = 0; + +Pass *CallGraphSCCPass::createPrinterPass(raw_ostream &OS, +                                          const std::string &Banner) const { +  return new PrintCallGraphPass(Banner, OS); +} + +bool CallGraphSCCPass::skipSCC(CallGraphSCC &SCC) const { +  return !SCC.getCallGraph().getModule() +              .getContext() +              .getOptPassGate() +              .shouldRunPass(this, SCC); +} + +char DummyCGSCCPass::ID = 0; + +INITIALIZE_PASS(DummyCGSCCPass, "DummyCGSCCPass", "DummyCGSCCPass", false, +                false) diff --git a/contrib/llvm/lib/Analysis/CallPrinter.cpp b/contrib/llvm/lib/Analysis/CallPrinter.cpp new file mode 100644 index 000000000000..e7017e77652a --- /dev/null +++ b/contrib/llvm/lib/Analysis/CallPrinter.cpp @@ -0,0 +1,92 @@ +//===- CallPrinter.cpp - DOT printer for call graph -----------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines '-dot-callgraph', which emit a callgraph.<fnname>.dot +// containing the call graph of a module. +// +// There is also a pass available to directly call dotty ('-view-callgraph'). +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CallPrinter.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/DOTGraphTraitsPass.h" + +using namespace llvm; + +namespace llvm { + +template <> struct DOTGraphTraits<CallGraph *> : public DefaultDOTGraphTraits { +  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} + +  static std::string getGraphName(CallGraph *Graph) { return "Call graph"; } + +  std::string getNodeLabel(CallGraphNode *Node, CallGraph *Graph) { +    if (Function *Func = Node->getFunction()) +      return Func->getName(); + +    return "external node"; +  } +}; + +struct AnalysisCallGraphWrapperPassTraits { +  static CallGraph *getGraph(CallGraphWrapperPass *P) { +    return &P->getCallGraph(); +  } +}; + +} // end llvm namespace + +namespace { + +struct CallGraphViewer +    : public DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *, +                                        AnalysisCallGraphWrapperPassTraits> { +  static char ID; + +  CallGraphViewer() +      : DOTGraphTraitsModuleViewer<CallGraphWrapperPass, true, CallGraph *, +                                   AnalysisCallGraphWrapperPassTraits>( +            "callgraph", ID) { +    initializeCallGraphViewerPass(*PassRegistry::getPassRegistry()); +  } +}; + +struct CallGraphDOTPrinter : public DOTGraphTraitsModulePrinter< +                              CallGraphWrapperPass, true, CallGraph *, +                              AnalysisCallGraphWrapperPassTraits> { +  static char ID; + +  CallGraphDOTPrinter() +      : DOTGraphTraitsModulePrinter<CallGraphWrapperPass, true, CallGraph *, +                                    AnalysisCallGraphWrapperPassTraits>( +            "callgraph", ID) { +    initializeCallGraphDOTPrinterPass(*PassRegistry::getPassRegistry()); +  } +}; + +} // end anonymous namespace + +char CallGraphViewer::ID = 0; +INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false, +                false) + +char CallGraphDOTPrinter::ID = 0; +INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph", +                "Print call graph to 'dot' file", 
false, false) + +// Create methods available outside of this file, to use them +// "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by +// the link time optimization. + +ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); } + +ModulePass *llvm::createCallGraphDOTPrinterPass() { +  return new CallGraphDOTPrinter(); +} diff --git a/contrib/llvm/lib/Analysis/CaptureTracking.cpp b/contrib/llvm/lib/Analysis/CaptureTracking.cpp new file mode 100644 index 000000000000..d4f73bdb4361 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CaptureTracking.cpp @@ -0,0 +1,364 @@ +//===--- CaptureTracking.cpp - Determine whether a pointer is captured ----===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains routines that help determine which pointers are captured. +// A pointer value is captured if the function makes a copy of any part of the +// pointer that outlives the call.  Not being captured means, more or less, that +// the pointer is only dereferenced and not stored in a global.  Returning part +// of the pointer as the function return value may or may not count as capturing +// the pointer, depending on the context. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" + +using namespace llvm; + +CaptureTracker::~CaptureTracker() {} + +bool CaptureTracker::shouldExplore(const Use *U) { return true; } + +namespace { +  struct SimpleCaptureTracker : public CaptureTracker { +    explicit SimpleCaptureTracker(bool ReturnCaptures) +      : ReturnCaptures(ReturnCaptures), Captured(false) {} + +    void tooManyUses() override { Captured = true; } + +    bool captured(const Use *U) override { +      if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures) +        return false; + +      Captured = true; +      return true; +    } + +    bool ReturnCaptures; + +    bool Captured; +  }; + +  /// Only find pointer captures which happen before the given instruction. Uses +  /// the dominator tree to determine whether one instruction is before another. +  /// Only support the case where the Value is defined in the same basic block +  /// as the given instruction and the use. +  struct CapturesBefore : public CaptureTracker { + +    CapturesBefore(bool ReturnCaptures, const Instruction *I, const DominatorTree *DT, +                   bool IncludeI, OrderedBasicBlock *IC) +      : OrderedBB(IC), BeforeHere(I), DT(DT), +        ReturnCaptures(ReturnCaptures), IncludeI(IncludeI), Captured(false) {} + +    void tooManyUses() override { Captured = true; } + +    bool isSafeToPrune(Instruction *I) { +      BasicBlock *BB = I->getParent(); +      // We explore this usage only if the usage can reach "BeforeHere". +      // If use is not reachable from entry, there is no need to explore. 
+      if (BeforeHere != I && !DT->isReachableFromEntry(BB))
+        return true;
+
+      // Handle the case where both instructions are inside the same basic
+      // block. Since instructions in the same BB as BeforeHere are numbered in
+      // 'OrderedBB', avoid using 'dominates' and 'isPotentiallyReachable'
+      // which are very expensive for large basic blocks.
+      if (BB == BeforeHere->getParent()) {
+        // 'I' dominates 'BeforeHere' => not safe to prune.
+        //
+        // The value defined by an invoke dominates an instruction only
+        // if it dominates every instruction in UseBB. A PHI is dominated only
+        // if the instruction dominates every possible use in the UseBB. Since
+        // UseBB == BB, avoid pruning.
+        if (isa<InvokeInst>(BeforeHere) || isa<PHINode>(I) || I == BeforeHere)
+          return false;
+        if (!OrderedBB->dominates(BeforeHere, I))
+          return false;
+
+        // 'BeforeHere' comes before 'I', so it's safe to prune if we also
+        // guarantee that 'I' never reaches 'BeforeHere' through a back-edge or
+        // by its successors, i.e., prune if:
+        //
+        //  (1) BB is the entry block or has no successors.
+        //  (2) There's no path coming back through BB's successors.
+        if (BB == &BB->getParent()->getEntryBlock() ||
+            !BB->getTerminator()->getNumSuccessors())
+          return true;
+
+        SmallVector<BasicBlock*, 32> Worklist;
+        Worklist.append(succ_begin(BB), succ_end(BB));
+        return !isPotentiallyReachableFromMany(Worklist, BB, DT);
+      }
+
+      // If the value is defined in the same basic block as the use and
+      // BeforeHere, there is no need to explore the use if BeforeHere
+      // dominates the use.
+      // Check whether there is a path from I to BeforeHere.
+      if (BeforeHere != I && DT->dominates(BeforeHere, I) &&
+          !isPotentiallyReachable(I, BeforeHere, DT))
+        return true;
+
+      return false;
+    }
+
+    bool shouldExplore(const Use *U) override {
+      Instruction *I = cast<Instruction>(U->getUser());
+
+      if (BeforeHere == I && !IncludeI)
+        return false;
+
+      if (isSafeToPrune(I))
+        return false;
+
+      return true;
+    }
+
+    bool captured(const Use *U) override {
+      if (isa<ReturnInst>(U->getUser()) && !ReturnCaptures)
+        return false;
+
+      if (!shouldExplore(U))
+        return false;
+
+      Captured = true;
+      return true;
+    }
+
+    OrderedBasicBlock *OrderedBB;
+    const Instruction *BeforeHere;
+    const DominatorTree *DT;
+
+    bool ReturnCaptures;
+    bool IncludeI;
+
+    bool Captured;
+  };
+}
+
+/// PointerMayBeCaptured - Return true if this pointer value may be captured
+/// by the enclosing function (which is required to exist).  This routine can
+/// be expensive, so consider caching the results.  The boolean ReturnCaptures
+/// specifies whether returning the value (or part of it) from the function
+/// counts as capturing it or not.  The boolean StoreCaptures specifies whether
+/// storing the value (or part of it) into memory anywhere automatically
+/// counts as capturing it or not.
+bool llvm::PointerMayBeCaptured(const Value *V,
+                                bool ReturnCaptures, bool StoreCaptures) {
+  assert(!isa<GlobalValue>(V) &&
+         "It doesn't make sense to ask whether a global is captured.");
+
+  // TODO: If StoreCaptures is not true, we could do fancier analysis
+  // to prove that certain stores are not actually escape points.
+  // In that case, BasicAliasAnalysis should be updated as well to
+  // take advantage of this.
+  (void)StoreCaptures;
+
+  SimpleCaptureTracker SCT(ReturnCaptures);
+  PointerMayBeCaptured(V, &SCT);
+  return SCT.Captured;
+}
+
+/// PointerMayBeCapturedBefore - Return true if this pointer value may be
+/// captured by the enclosing function (which is required to exist). If a
+/// DominatorTree is provided, only captures which happen before the given
+/// instruction are considered. This routine can be expensive, so consider
+/// caching the results.  The boolean ReturnCaptures specifies whether
+/// returning the value (or part of it) from the function counts as capturing
+/// it or not.  The boolean StoreCaptures specifies whether storing the value
+/// (or part of it) into memory anywhere automatically counts as capturing it
+/// or not. An ordered basic block \p OBB can be used to speed up
+/// queries about relative order among instructions in the same basic block.
+bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures,
+                                      bool StoreCaptures, const Instruction *I,
+                                      const DominatorTree *DT, bool IncludeI,
+                                      OrderedBasicBlock *OBB) {
+  assert(!isa<GlobalValue>(V) &&
+         "It doesn't make sense to ask whether a global is captured.");
+  bool UseNewOBB = OBB == nullptr;
+
+  if (!DT)
+    return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures);
+  if (UseNewOBB)
+    OBB = new OrderedBasicBlock(I->getParent());
+
+  // TODO: See comment in PointerMayBeCaptured regarding what could be done
+  // with StoreCaptures.
+
+  CapturesBefore CB(ReturnCaptures, I, DT, IncludeI, OBB);
+  PointerMayBeCaptured(V, &CB);
+
+  if (UseNewOBB)
+    delete OBB;
+  return CB.Captured;
+}
+
+/// TODO: Write a new FunctionPass AliasAnalysis so that it can keep
+/// a cache. Then we can move the code from BasicAliasAnalysis into
+/// that path, and remove this threshold.
+static int const Threshold = 20;
+
+void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) {
+  assert(V->getType()->isPointerTy() && "Capture is for pointers only!");
+  SmallVector<const Use *, Threshold> Worklist;
+  SmallSet<const Use *, Threshold> Visited;
+
+  auto AddUses = [&](const Value *V) {
+    int Count = 0;
+    for (const Use &U : V->uses()) {
+      // If there are lots of uses, conservatively say that the value
+      // is captured to avoid taking too much compile time.
+      if (Count++ >= Threshold) +        return Tracker->tooManyUses(); +      if (!Visited.insert(&U).second) +        continue; +      if (!Tracker->shouldExplore(&U)) +        continue; +      Worklist.push_back(&U); +    } +  }; +  AddUses(V); + +  while (!Worklist.empty()) { +    const Use *U = Worklist.pop_back_val(); +    Instruction *I = cast<Instruction>(U->getUser()); +    V = U->get(); + +    switch (I->getOpcode()) { +    case Instruction::Call: +    case Instruction::Invoke: { +      CallSite CS(I); +      // Not captured if the callee is readonly, doesn't return a copy through +      // its return value and doesn't unwind (a readonly function can leak bits +      // by throwing an exception or not depending on the input value). +      if (CS.onlyReadsMemory() && CS.doesNotThrow() && I->getType()->isVoidTy()) +        break; + +      // The pointer is not captured if returned pointer is not captured. +      // NOTE: CaptureTracking users should not assume that only functions +      // marked with nocapture do not capture. This means that places like +      // GetUnderlyingObject in ValueTracking or DecomposeGEPExpression +      // in BasicAA also need to know about this property. +      if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CS)) { +        AddUses(I); +        break; +      } + +      // Volatile operations effectively capture the memory location that they +      // load and store to. +      if (auto *MI = dyn_cast<MemIntrinsic>(I)) +        if (MI->isVolatile()) +          if (Tracker->captured(U)) +            return; + +      // Not captured if only passed via 'nocapture' arguments.  Note that +      // calling a function pointer does not in itself cause the pointer to +      // be captured.  This is a subtle point considering that (for example) +      // the callee might return its own address.  It is analogous to saying +      // that loading a value from a pointer does not cause the pointer to be +      // captured, even though the loaded value might be the pointer itself +      // (think of self-referential objects). +      CallSite::data_operand_iterator B = +        CS.data_operands_begin(), E = CS.data_operands_end(); +      for (CallSite::data_operand_iterator A = B; A != E; ++A) +        if (A->get() == V && !CS.doesNotCapture(A - B)) +          // The parameter is not marked 'nocapture' - captured. +          if (Tracker->captured(U)) +            return; +      break; +    } +    case Instruction::Load: +      // Volatile loads make the address observable. +      if (cast<LoadInst>(I)->isVolatile()) +        if (Tracker->captured(U)) +          return; +      break; +    case Instruction::VAArg: +      // "va-arg" from a pointer does not cause it to be captured. +      break; +    case Instruction::Store: +        // Stored the pointer - conservatively assume it may be captured. +        // Volatile stores make the address observable. +      if (V == I->getOperand(0) || cast<StoreInst>(I)->isVolatile()) +        if (Tracker->captured(U)) +          return; +      break; +    case Instruction::AtomicRMW: { +      // atomicrmw conceptually includes both a load and store from +      // the same location. +      // As with a store, the location being accessed is not captured, +      // but the value being stored is. +      // Volatile stores make the address observable. 
+      auto *ARMWI = cast<AtomicRMWInst>(I); +      if (ARMWI->getValOperand() == V || ARMWI->isVolatile()) +        if (Tracker->captured(U)) +          return; +      break; +    } +    case Instruction::AtomicCmpXchg: { +      // cmpxchg conceptually includes both a load and store from +      // the same location. +      // As with a store, the location being accessed is not captured, +      // but the value being stored is. +      // Volatile stores make the address observable. +      auto *ACXI = cast<AtomicCmpXchgInst>(I); +      if (ACXI->getCompareOperand() == V || ACXI->getNewValOperand() == V || +          ACXI->isVolatile()) +        if (Tracker->captured(U)) +          return; +      break; +    } +    case Instruction::BitCast: +    case Instruction::GetElementPtr: +    case Instruction::PHI: +    case Instruction::Select: +    case Instruction::AddrSpaceCast: +      // The original value is not captured via this if the new value isn't. +      AddUses(I); +      break; +    case Instruction::ICmp: { +      // Don't count comparisons of a no-alias return value against null as +      // captures. This allows us to ignore comparisons of malloc results +      // with null, for example. +      if (ConstantPointerNull *CPN = +          dyn_cast<ConstantPointerNull>(I->getOperand(1))) +        if (CPN->getType()->getAddressSpace() == 0) +          if (isNoAliasCall(V->stripPointerCasts())) +            break; +      // Comparison against value stored in global variable. Given the pointer +      // does not escape, its value cannot be guessed and stored separately in a +      // global variable. +      unsigned OtherIndex = (I->getOperand(0) == V) ? 1 : 0; +      auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIndex)); +      if (LI && isa<GlobalVariable>(LI->getPointerOperand())) +        break; +      // Otherwise, be conservative. There are crazy ways to capture pointers +      // using comparisons. +      if (Tracker->captured(U)) +        return; +      break; +    } +    default: +      // Something else - be conservative and say it is captured. +      if (Tracker->captured(U)) +        return; +      break; +    } +  } + +  // All uses examined. +} diff --git a/contrib/llvm/lib/Analysis/CmpInstAnalysis.cpp b/contrib/llvm/lib/Analysis/CmpInstAnalysis.cpp new file mode 100644 index 000000000000..159c1a2d135a --- /dev/null +++ b/contrib/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -0,0 +1,144 @@ +//===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file holds routines to help analyse compare instructions +// and fold them into constants or other compare instructions +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CmpInstAnalysis.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/PatternMatch.h" + +using namespace llvm; + +unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) { +  ICmpInst::Predicate Pred = InvertPred ? 
ICI->getInversePredicate() +                                        : ICI->getPredicate(); +  switch (Pred) { +      // False -> 0 +    case ICmpInst::ICMP_UGT: return 1;  // 001 +    case ICmpInst::ICMP_SGT: return 1;  // 001 +    case ICmpInst::ICMP_EQ:  return 2;  // 010 +    case ICmpInst::ICMP_UGE: return 3;  // 011 +    case ICmpInst::ICMP_SGE: return 3;  // 011 +    case ICmpInst::ICMP_ULT: return 4;  // 100 +    case ICmpInst::ICMP_SLT: return 4;  // 100 +    case ICmpInst::ICMP_NE:  return 5;  // 101 +    case ICmpInst::ICMP_ULE: return 6;  // 110 +    case ICmpInst::ICMP_SLE: return 6;  // 110 +      // True -> 7 +    default: +      llvm_unreachable("Invalid ICmp predicate!"); +  } +} + +Value *llvm::getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, +                          CmpInst::Predicate &NewICmpPred) { +  switch (Code) { +    default: llvm_unreachable("Illegal ICmp code!"); +    case 0: // False. +      return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); +    case 1: NewICmpPred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; +    case 2: NewICmpPred = ICmpInst::ICMP_EQ; break; +    case 3: NewICmpPred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; +    case 4: NewICmpPred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; +    case 5: NewICmpPred = ICmpInst::ICMP_NE; break; +    case 6: NewICmpPred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; +    case 7: // True. +      return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); +  } +  return nullptr; +} + +bool llvm::PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { +  return (CmpInst::isSigned(p1) == CmpInst::isSigned(p2)) || +         (CmpInst::isSigned(p1) && ICmpInst::isEquality(p2)) || +         (CmpInst::isSigned(p2) && ICmpInst::isEquality(p1)); +} + +bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, +                                CmpInst::Predicate &Pred, +                                Value *&X, APInt &Mask, bool LookThruTrunc) { +  using namespace PatternMatch; + +  const APInt *C; +  if (!match(RHS, m_APInt(C))) +    return false; + +  switch (Pred) { +  default: +    return false; +  case ICmpInst::ICMP_SLT: +    // X < 0 is equivalent to (X & SignMask) != 0. +    if (!C->isNullValue()) +      return false; +    Mask = APInt::getSignMask(C->getBitWidth()); +    Pred = ICmpInst::ICMP_NE; +    break; +  case ICmpInst::ICMP_SLE: +    // X <= -1 is equivalent to (X & SignMask) != 0. +    if (!C->isAllOnesValue()) +      return false; +    Mask = APInt::getSignMask(C->getBitWidth()); +    Pred = ICmpInst::ICMP_NE; +    break; +  case ICmpInst::ICMP_SGT: +    // X > -1 is equivalent to (X & SignMask) == 0. +    if (!C->isAllOnesValue()) +      return false; +    Mask = APInt::getSignMask(C->getBitWidth()); +    Pred = ICmpInst::ICMP_EQ; +    break; +  case ICmpInst::ICMP_SGE: +    // X >= 0 is equivalent to (X & SignMask) == 0. +    if (!C->isNullValue()) +      return false; +    Mask = APInt::getSignMask(C->getBitWidth()); +    Pred = ICmpInst::ICMP_EQ; +    break; +  case ICmpInst::ICMP_ULT: +    // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. +    if (!C->isPowerOf2()) +      return false; +    Mask = -*C; +    Pred = ICmpInst::ICMP_EQ; +    break; +  case ICmpInst::ICMP_ULE: +    // X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0. 
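+    // For example (illustrative values): "X <=u 7" has C = 7 and C + 1 = 8,
+    // a power of two, so it becomes "(X & ~7) == 0", i.e. no bit above bit 2
+    // may be set.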
+    if (!(*C + 1).isPowerOf2()) +      return false; +    Mask = ~*C; +    Pred = ICmpInst::ICMP_EQ; +    break; +  case ICmpInst::ICMP_UGT: +    // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0. +    if (!(*C + 1).isPowerOf2()) +      return false; +    Mask = ~*C; +    Pred = ICmpInst::ICMP_NE; +    break; +  case ICmpInst::ICMP_UGE: +    // X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0. +    if (!C->isPowerOf2()) +      return false; +    Mask = -*C; +    Pred = ICmpInst::ICMP_NE; +    break; +  } + +  if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) { +    Mask = Mask.zext(X->getType()->getScalarSizeInBits()); +  } else { +    X = LHS; +  } + +  return true; +} diff --git a/contrib/llvm/lib/Analysis/CodeMetrics.cpp b/contrib/llvm/lib/Analysis/CodeMetrics.cpp new file mode 100644 index 000000000000..46cc87d2b178 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CodeMetrics.cpp @@ -0,0 +1,199 @@ +//===- CodeMetrics.cpp - Code cost measurements ---------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements code cost measurement utilities. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +#define DEBUG_TYPE "code-metrics" + +using namespace llvm; + +static void +appendSpeculatableOperands(const Value *V, +                           SmallPtrSetImpl<const Value *> &Visited, +                           SmallVectorImpl<const Value *> &Worklist) { +  const User *U = dyn_cast<User>(V); +  if (!U) +    return; + +  for (const Value *Operand : U->operands()) +    if (Visited.insert(Operand).second) +      if (isSafeToSpeculativelyExecute(Operand)) +        Worklist.push_back(Operand); +} + +static void completeEphemeralValues(SmallPtrSetImpl<const Value *> &Visited, +                                    SmallVectorImpl<const Value *> &Worklist, +                                    SmallPtrSetImpl<const Value *> &EphValues) { +  // Note: We don't speculate PHIs here, so we'll miss instruction chains kept +  // alive only by ephemeral values. + +  // Walk the worklist using an index but without caching the size so we can +  // append more entries as we process the worklist. This forms a queue without +  // quadratic behavior by just leaving processed nodes at the head of the +  // worklist forever. +  for (int i = 0; i < (int)Worklist.size(); ++i) { +    const Value *V = Worklist[i]; + +    assert(Visited.count(V) && +           "Failed to add a worklist entry to our visited set!"); + +    // If all uses of this value are ephemeral, then so is this value. +    if (!all_of(V->users(), [&](const User *U) { return EphValues.count(U); })) +      continue; + +    EphValues.insert(V); +    LLVM_DEBUG(dbgs() << "Ephemeral Value: " << *V << "\n"); + +    // Append any more operands to consider. +    appendSpeculatableOperands(V, Visited, Worklist); +  } +} + +// Find all ephemeral values. 
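+// (An "ephemeral" value is one that is used only by @llvm.assume calls,
+// directly or through other ephemeral values; it would disappear if the
+// assumes were dropped, so the cost metrics below do not charge for it.)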
+void CodeMetrics::collectEphemeralValues( +    const Loop *L, AssumptionCache *AC, +    SmallPtrSetImpl<const Value *> &EphValues) { +  SmallPtrSet<const Value *, 32> Visited; +  SmallVector<const Value *, 16> Worklist; + +  for (auto &AssumeVH : AC->assumptions()) { +    if (!AssumeVH) +      continue; +    Instruction *I = cast<Instruction>(AssumeVH); + +    // Filter out call sites outside of the loop so we don't do a function's +    // worth of work for each of its loops (and, in the common case, ephemeral +    // values in the loop are likely due to @llvm.assume calls in the loop). +    if (!L->contains(I->getParent())) +      continue; + +    if (EphValues.insert(I).second) +      appendSpeculatableOperands(I, Visited, Worklist); +  } + +  completeEphemeralValues(Visited, Worklist, EphValues); +} + +void CodeMetrics::collectEphemeralValues( +    const Function *F, AssumptionCache *AC, +    SmallPtrSetImpl<const Value *> &EphValues) { +  SmallPtrSet<const Value *, 32> Visited; +  SmallVector<const Value *, 16> Worklist; + +  for (auto &AssumeVH : AC->assumptions()) { +    if (!AssumeVH) +      continue; +    Instruction *I = cast<Instruction>(AssumeVH); +    assert(I->getParent()->getParent() == F && +           "Found assumption for the wrong function!"); + +    if (EphValues.insert(I).second) +      appendSpeculatableOperands(I, Visited, Worklist); +  } + +  completeEphemeralValues(Visited, Worklist, EphValues); +} + +/// Fill in the current structure with information gleaned from the specified +/// block. +void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB, +                                    const TargetTransformInfo &TTI, +                                    const SmallPtrSetImpl<const Value*> &EphValues) { +  ++NumBlocks; +  unsigned NumInstsBeforeThisBB = NumInsts; +  for (const Instruction &I : *BB) { +    // Skip ephemeral values. +    if (EphValues.count(&I)) +      continue; + +    // Special handling for calls. +    if (isa<CallInst>(I) || isa<InvokeInst>(I)) { +      ImmutableCallSite CS(&I); + +      if (const Function *F = CS.getCalledFunction()) { +        // If a function is both internal and has a single use, then it is +        // extremely likely to get inlined in the future (it was probably +        // exposed by an interleaved devirtualization pass). +        if (!CS.isNoInline() && F->hasInternalLinkage() && F->hasOneUse()) +          ++NumInlineCandidates; + +        // If this call is to function itself, then the function is recursive. +        // Inlining it into other functions is a bad idea, because this is +        // basically just a form of loop peeling, and our metrics aren't useful +        // for that case. +        if (F == BB->getParent()) +          isRecursive = true; + +        if (TTI.isLoweredToCall(F)) +          ++NumCalls; +      } else { +        // We don't want inline asm to count as a call - that would prevent loop +        // unrolling. The argument setup cost is still real, though. 
+        if (!isa<InlineAsm>(CS.getCalledValue())) +          ++NumCalls; +      } +    } + +    if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { +      if (!AI->isStaticAlloca()) +        this->usesDynamicAlloca = true; +    } + +    if (isa<ExtractElementInst>(I) || I.getType()->isVectorTy()) +      ++NumVectorInsts; + +    if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) +      notDuplicatable = true; + +    if (const CallInst *CI = dyn_cast<CallInst>(&I)) { +      if (CI->cannotDuplicate()) +        notDuplicatable = true; +      if (CI->isConvergent()) +        convergent = true; +    } + +    if (const InvokeInst *InvI = dyn_cast<InvokeInst>(&I)) +      if (InvI->cannotDuplicate()) +        notDuplicatable = true; + +    NumInsts += TTI.getUserCost(&I); +  } + +  if (isa<ReturnInst>(BB->getTerminator())) +    ++NumRets; + +  // We never want to inline functions that contain an indirectbr.  This is +  // incorrect because all the blockaddress's (in static global initializers +  // for example) would be referring to the original function, and this indirect +  // jump would jump from the inlined copy of the function into the original +  // function which is extremely undefined behavior. +  // FIXME: This logic isn't really right; we can safely inline functions +  // with indirectbr's as long as no other function or global references the +  // blockaddress of a block within the current function.  And as a QOI issue, +  // if someone is using a blockaddress without an indirectbr, and that +  // reference somehow ends up in another function or global, we probably +  // don't want to inline this function. +  notDuplicatable |= isa<IndirectBrInst>(BB->getTerminator()); + +  // Remember NumInsts for this BB. +  NumBBInsts[BB] = NumInsts - NumInstsBeforeThisBB; +} diff --git a/contrib/llvm/lib/Analysis/ConstantFolding.cpp b/contrib/llvm/lib/Analysis/ConstantFolding.cpp new file mode 100644 index 000000000000..c5281c57bc19 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ConstantFolding.cpp @@ -0,0 +1,2262 @@ +//===-- ConstantFolding.cpp - Fold instructions into constants ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines routines for folding instructions into constants. +// +// Also, to supplement the basic IR ConstantExpr simplifications, +// this file defines some additional folding routines that can make use of +// DataLayout information. These functions cannot go in IR due to library +// dependency issues. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/config.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include <cassert> +#include <cerrno> +#include <cfenv> +#include <cmath> +#include <cstddef> +#include <cstdint> + +using namespace llvm; + +namespace { + +//===----------------------------------------------------------------------===// +// Constant Folding internal helper functions +//===----------------------------------------------------------------------===// + +static Constant *foldConstVectorToAPInt(APInt &Result, Type *DestTy, +                                        Constant *C, Type *SrcEltTy, +                                        unsigned NumSrcElts, +                                        const DataLayout &DL) { +  // Now that we know that the input value is a vector of integers, just shift +  // and insert them into our result. +  unsigned BitShift = DL.getTypeSizeInBits(SrcEltTy); +  for (unsigned i = 0; i != NumSrcElts; ++i) { +    Constant *Element; +    if (DL.isLittleEndian()) +      Element = C->getAggregateElement(NumSrcElts - i - 1); +    else +      Element = C->getAggregateElement(i); + +    if (Element && isa<UndefValue>(Element)) { +      Result <<= BitShift; +      continue; +    } + +    auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); +    if (!ElementCI) +      return ConstantExpr::getBitCast(C, DestTy); + +    Result <<= BitShift; +    Result |= ElementCI->getValue().zextOrSelf(Result.getBitWidth()); +  } + +  return nullptr; +} + +/// Constant fold bitcast, symbolically evaluating it with DataLayout. +/// This always returns a non-null constant, but it may be a +/// ConstantExpr if unfoldable. +Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { +  // Catch the obvious splat cases. +  if (C->isNullValue() && !DestTy->isX86_MMXTy()) +    return Constant::getNullValue(DestTy); +  if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && +      !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types! +    return Constant::getAllOnesValue(DestTy); + +  if (auto *VTy = dyn_cast<VectorType>(C->getType())) { +    // Handle a vector->scalar integer/fp cast. +    if (isa<IntegerType>(DestTy) || DestTy->isFloatingPointTy()) { +      unsigned NumSrcElts = VTy->getNumElements(); +      Type *SrcEltTy = VTy->getElementType(); + +      // If the vector is a vector of floating point, convert it to vector of int +      // to simplify things. 
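+      // (For example, a <4 x float> source is first reinterpreted as
+      // <4 x i32>; the integer path below then extracts and combines the
+      // lanes.)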
+      if (SrcEltTy->isFloatingPointTy()) { +        unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits(); +        Type *SrcIVTy = +          VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumSrcElts); +        // Ask IR to do the conversion now that #elts line up. +        C = ConstantExpr::getBitCast(C, SrcIVTy); +      } + +      APInt Result(DL.getTypeSizeInBits(DestTy), 0); +      if (Constant *CE = foldConstVectorToAPInt(Result, DestTy, C, +                                                SrcEltTy, NumSrcElts, DL)) +        return CE; + +      if (isa<IntegerType>(DestTy)) +        return ConstantInt::get(DestTy, Result); + +      APFloat FP(DestTy->getFltSemantics(), Result); +      return ConstantFP::get(DestTy->getContext(), FP); +    } +  } + +  // The code below only handles casts to vectors currently. +  auto *DestVTy = dyn_cast<VectorType>(DestTy); +  if (!DestVTy) +    return ConstantExpr::getBitCast(C, DestTy); + +  // If this is a scalar -> vector cast, convert the input into a <1 x scalar> +  // vector so the code below can handle it uniformly. +  if (isa<ConstantFP>(C) || isa<ConstantInt>(C)) { +    Constant *Ops = C; // don't take the address of C! +    return FoldBitCast(ConstantVector::get(Ops), DestTy, DL); +  } + +  // If this is a bitcast from constant vector -> vector, fold it. +  if (!isa<ConstantDataVector>(C) && !isa<ConstantVector>(C)) +    return ConstantExpr::getBitCast(C, DestTy); + +  // If the element types match, IR can fold it. +  unsigned NumDstElt = DestVTy->getNumElements(); +  unsigned NumSrcElt = C->getType()->getVectorNumElements(); +  if (NumDstElt == NumSrcElt) +    return ConstantExpr::getBitCast(C, DestTy); + +  Type *SrcEltTy = C->getType()->getVectorElementType(); +  Type *DstEltTy = DestVTy->getElementType(); + +  // Otherwise, we're changing the number of elements in a vector, which +  // requires endianness information to do the right thing.  For example, +  //    bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>) +  // folds to (little endian): +  //    <4 x i32> <i32 0, i32 0, i32 1, i32 0> +  // and to (big endian): +  //    <4 x i32> <i32 0, i32 0, i32 0, i32 1> + +  // First thing is first.  We only want to think about integer here, so if +  // we have something in FP form, recast it as integer. +  if (DstEltTy->isFloatingPointTy()) { +    // Fold to an vector of integers with same size as our FP type. +    unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits(); +    Type *DestIVTy = +      VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumDstElt); +    // Recursively handle this integer conversion, if possible. +    C = FoldBitCast(C, DestIVTy, DL); + +    // Finally, IR can handle this now that #elts line up. +    return ConstantExpr::getBitCast(C, DestTy); +  } + +  // Okay, we know the destination is integer, if the input is FP, convert +  // it to integer first. +  if (SrcEltTy->isFloatingPointTy()) { +    unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits(); +    Type *SrcIVTy = +      VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumSrcElt); +    // Ask IR to do the conversion now that #elts line up. +    C = ConstantExpr::getBitCast(C, SrcIVTy); +    // If IR wasn't able to fold it, bail out. +    if (!isa<ConstantVector>(C) &&  // FIXME: Remove ConstantVector. +        !isa<ConstantDataVector>(C)) +      return C; +  } + +  // Now we know that the input and output vectors are both integer vectors +  // of the same size, and that their #elements is not the same.  
Do the +  // conversion here, which depends on whether the input or output has +  // more elements. +  bool isLittleEndian = DL.isLittleEndian(); + +  SmallVector<Constant*, 32> Result; +  if (NumDstElt < NumSrcElt) { +    // Handle: bitcast (<4 x i32> <i32 0, i32 1, i32 2, i32 3> to <2 x i64>) +    Constant *Zero = Constant::getNullValue(DstEltTy); +    unsigned Ratio = NumSrcElt/NumDstElt; +    unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits(); +    unsigned SrcElt = 0; +    for (unsigned i = 0; i != NumDstElt; ++i) { +      // Build each element of the result. +      Constant *Elt = Zero; +      unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1); +      for (unsigned j = 0; j != Ratio; ++j) { +        Constant *Src = C->getAggregateElement(SrcElt++); +        if (Src && isa<UndefValue>(Src)) +          Src = Constant::getNullValue(C->getType()->getVectorElementType()); +        else +          Src = dyn_cast_or_null<ConstantInt>(Src); +        if (!Src)  // Reject constantexpr elements. +          return ConstantExpr::getBitCast(C, DestTy); + +        // Zero extend the element to the right size. +        Src = ConstantExpr::getZExt(Src, Elt->getType()); + +        // Shift it to the right place, depending on endianness. +        Src = ConstantExpr::getShl(Src, +                                   ConstantInt::get(Src->getType(), ShiftAmt)); +        ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize; + +        // Mix it in. +        Elt = ConstantExpr::getOr(Elt, Src); +      } +      Result.push_back(Elt); +    } +    return ConstantVector::get(Result); +  } + +  // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>) +  unsigned Ratio = NumDstElt/NumSrcElt; +  unsigned DstBitSize = DL.getTypeSizeInBits(DstEltTy); + +  // Loop over each source value, expanding into multiple results. +  for (unsigned i = 0; i != NumSrcElt; ++i) { +    auto *Element = C->getAggregateElement(i); + +    if (!Element) // Reject constantexpr elements. +      return ConstantExpr::getBitCast(C, DestTy); + +    if (isa<UndefValue>(Element)) { +      // Correctly Propagate undef values. +      Result.append(Ratio, UndefValue::get(DstEltTy)); +      continue; +    } + +    auto *Src = dyn_cast<ConstantInt>(Element); +    if (!Src) +      return ConstantExpr::getBitCast(C, DestTy); + +    unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1); +    for (unsigned j = 0; j != Ratio; ++j) { +      // Shift the piece of the value into the right place, depending on +      // endianness. +      Constant *Elt = ConstantExpr::getLShr(Src, +                                  ConstantInt::get(Src->getType(), ShiftAmt)); +      ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize; + +      // Truncate the element to an integer with the same pointer size and +      // convert the element back to a pointer using a inttoptr. +      if (DstEltTy->isPointerTy()) { +        IntegerType *DstIntTy = Type::getIntNTy(C->getContext(), DstBitSize); +        Constant *CE = ConstantExpr::getTrunc(Elt, DstIntTy); +        Result.push_back(ConstantExpr::getIntToPtr(CE, DstEltTy)); +        continue; +      } + +      // Truncate and remember this piece. +      Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy)); +    } +  } + +  return ConstantVector::get(Result); +} + +} // end anonymous namespace + +/// If this constant is a constant offset from a global, return the global and +/// the constant. Because of constantexprs, this function is recursive. 
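+///
+/// For example (illustrative IR), given the constant expression
+///   getelementptr (i8, i8* bitcast (i32* @g to i8*), i64 4)
+/// this returns @g in GV and sets Offset to 4.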
+bool llvm::IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
+                                      APInt &Offset, const DataLayout &DL) {
+  // Trivial case, constant is the global.
+  if ((GV = dyn_cast<GlobalValue>(C))) {
+    unsigned BitWidth = DL.getIndexTypeSizeInBits(GV->getType());
+    Offset = APInt(BitWidth, 0);
+    return true;
+  }
+
+  // Otherwise, if this isn't a constant expr, bail out.
+  auto *CE = dyn_cast<ConstantExpr>(C);
+  if (!CE) return false;
+
+  // Look through ptr->int and ptr->ptr casts.
+  if (CE->getOpcode() == Instruction::PtrToInt ||
+      CE->getOpcode() == Instruction::BitCast)
+    return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, DL);
+
+  // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5)
+  auto *GEP = dyn_cast<GEPOperator>(CE);
+  if (!GEP)
+    return false;
+
+  unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP->getType());
+  APInt TmpOffset(BitWidth, 0);
+
+  // If the base isn't a global+constant, we aren't either.
+  if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, TmpOffset, DL))
+    return false;
+
+  // Otherwise, add any offset that our operands provide.
+  if (!GEP->accumulateConstantOffset(DL, TmpOffset))
+    return false;
+
+  Offset = TmpOffset;
+  return true;
+}
+
+Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy,
+                                         const DataLayout &DL) {
+  do {
+    Type *SrcTy = C->getType();
+
+    // If the type sizes are the same and a cast is legal, just directly
+    // cast the constant.
+    if (DL.getTypeSizeInBits(DestTy) == DL.getTypeSizeInBits(SrcTy)) {
+      Instruction::CastOps Cast = Instruction::BitCast;
+      // If we are going from a pointer to int or vice versa, we spell the cast
+      // differently.
+      if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
+        Cast = Instruction::IntToPtr;
+      else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
+        Cast = Instruction::PtrToInt;
+
+      if (CastInst::castIsValid(Cast, C, DestTy))
+        return ConstantExpr::getCast(Cast, C, DestTy);
+    }
+
+    // If this isn't an aggregate type, there is nothing we can do to drill down
+    // and find a bitcastable constant.
+    if (!SrcTy->isAggregateType())
+      return nullptr;
+
+    // We're simulating a load through a pointer that was bitcast to point to
+    // a different type, so we can try to walk down through the initial
+    // elements of an aggregate to see if some part of the aggregate is
+    // castable to implement the "load" semantic model.
+    C = C->getAggregateElement(0u);
+  } while (C);
+
+  return nullptr;
+}
+
+namespace {
+
+/// Recursive helper to read bits out of global. C is the constant being copied
+/// out of. ByteOffset is an offset into C. CurPtr is the pointer to copy
+/// results into and BytesLeft is the number of bytes left in
+/// the CurPtr buffer. DL is the DataLayout.
+bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr,
+                        unsigned BytesLeft, const DataLayout &DL) {
+  assert(ByteOffset <= DL.getTypeAllocSize(C->getType()) &&
+         "Out of range access");
+
+  // If this element is zero or undefined, we can just return since *CurPtr is
+  // zero initialized.
+  if (isa<ConstantAggregateZero>(C) || isa<UndefValue>(C)) +    return true; + +  if (auto *CI = dyn_cast<ConstantInt>(C)) { +    if (CI->getBitWidth() > 64 || +        (CI->getBitWidth() & 7) != 0) +      return false; + +    uint64_t Val = CI->getZExtValue(); +    unsigned IntBytes = unsigned(CI->getBitWidth()/8); + +    for (unsigned i = 0; i != BytesLeft && ByteOffset != IntBytes; ++i) { +      int n = ByteOffset; +      if (!DL.isLittleEndian()) +        n = IntBytes - n - 1; +      CurPtr[i] = (unsigned char)(Val >> (n * 8)); +      ++ByteOffset; +    } +    return true; +  } + +  if (auto *CFP = dyn_cast<ConstantFP>(C)) { +    if (CFP->getType()->isDoubleTy()) { +      C = FoldBitCast(C, Type::getInt64Ty(C->getContext()), DL); +      return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); +    } +    if (CFP->getType()->isFloatTy()){ +      C = FoldBitCast(C, Type::getInt32Ty(C->getContext()), DL); +      return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); +    } +    if (CFP->getType()->isHalfTy()){ +      C = FoldBitCast(C, Type::getInt16Ty(C->getContext()), DL); +      return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); +    } +    return false; +  } + +  if (auto *CS = dyn_cast<ConstantStruct>(C)) { +    const StructLayout *SL = DL.getStructLayout(CS->getType()); +    unsigned Index = SL->getElementContainingOffset(ByteOffset); +    uint64_t CurEltOffset = SL->getElementOffset(Index); +    ByteOffset -= CurEltOffset; + +    while (true) { +      // If the element access is to the element itself and not to tail padding, +      // read the bytes from the element. +      uint64_t EltSize = DL.getTypeAllocSize(CS->getOperand(Index)->getType()); + +      if (ByteOffset < EltSize && +          !ReadDataFromGlobal(CS->getOperand(Index), ByteOffset, CurPtr, +                              BytesLeft, DL)) +        return false; + +      ++Index; + +      // Check to see if we read from the last struct element, if so we're done. +      if (Index == CS->getType()->getNumElements()) +        return true; + +      // If we read all of the bytes we needed from this element we're done. +      uint64_t NextEltOffset = SL->getElementOffset(Index); + +      if (BytesLeft <= NextEltOffset - CurEltOffset - ByteOffset) +        return true; + +      // Move to the next element of the struct. +      CurPtr += NextEltOffset - CurEltOffset - ByteOffset; +      BytesLeft -= NextEltOffset - CurEltOffset - ByteOffset; +      ByteOffset = 0; +      CurEltOffset = NextEltOffset; +    } +    // not reached. 
+  } + +  if (isa<ConstantArray>(C) || isa<ConstantVector>(C) || +      isa<ConstantDataSequential>(C)) { +    Type *EltTy = C->getType()->getSequentialElementType(); +    uint64_t EltSize = DL.getTypeAllocSize(EltTy); +    uint64_t Index = ByteOffset / EltSize; +    uint64_t Offset = ByteOffset - Index * EltSize; +    uint64_t NumElts; +    if (auto *AT = dyn_cast<ArrayType>(C->getType())) +      NumElts = AT->getNumElements(); +    else +      NumElts = C->getType()->getVectorNumElements(); + +    for (; Index != NumElts; ++Index) { +      if (!ReadDataFromGlobal(C->getAggregateElement(Index), Offset, CurPtr, +                              BytesLeft, DL)) +        return false; + +      uint64_t BytesWritten = EltSize - Offset; +      assert(BytesWritten <= EltSize && "Not indexing into this element?"); +      if (BytesWritten >= BytesLeft) +        return true; + +      Offset = 0; +      BytesLeft -= BytesWritten; +      CurPtr += BytesWritten; +    } +    return true; +  } + +  if (auto *CE = dyn_cast<ConstantExpr>(C)) { +    if (CE->getOpcode() == Instruction::IntToPtr && +        CE->getOperand(0)->getType() == DL.getIntPtrType(CE->getType())) { +      return ReadDataFromGlobal(CE->getOperand(0), ByteOffset, CurPtr, +                                BytesLeft, DL); +    } +  } + +  // Otherwise, unknown initializer type. +  return false; +} + +Constant *FoldReinterpretLoadFromConstPtr(Constant *C, Type *LoadTy, +                                          const DataLayout &DL) { +  auto *PTy = cast<PointerType>(C->getType()); +  auto *IntType = dyn_cast<IntegerType>(LoadTy); + +  // If this isn't an integer load we can't fold it directly. +  if (!IntType) { +    unsigned AS = PTy->getAddressSpace(); + +    // If this is a float/double load, we can try folding it as an int32/64 load +    // and then bitcast the result.  This can be useful for union cases.  Note +    // that address spaces don't matter here since we're not going to result in +    // an actual new load. +    Type *MapTy; +    if (LoadTy->isHalfTy()) +      MapTy = Type::getInt16Ty(C->getContext()); +    else if (LoadTy->isFloatTy()) +      MapTy = Type::getInt32Ty(C->getContext()); +    else if (LoadTy->isDoubleTy()) +      MapTy = Type::getInt64Ty(C->getContext()); +    else if (LoadTy->isVectorTy()) { +      MapTy = PointerType::getIntNTy(C->getContext(), +                                     DL.getTypeAllocSizeInBits(LoadTy)); +    } else +      return nullptr; + +    C = FoldBitCast(C, MapTy->getPointerTo(AS), DL); +    if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) +      return FoldBitCast(Res, LoadTy, DL); +    return nullptr; +  } + +  unsigned BytesLoaded = (IntType->getBitWidth() + 7) / 8; +  if (BytesLoaded > 32 || BytesLoaded == 0) +    return nullptr; + +  GlobalValue *GVal; +  APInt OffsetAI; +  if (!IsConstantOffsetFromGlobal(C, GVal, OffsetAI, DL)) +    return nullptr; + +  auto *GV = dyn_cast<GlobalVariable>(GVal); +  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || +      !GV->getInitializer()->getType()->isSized()) +    return nullptr; + +  int64_t Offset = OffsetAI.getSExtValue(); +  int64_t InitializerSize = DL.getTypeAllocSize(GV->getInitializer()->getType()); + +  // If we're not accessing anything in this constant, the result is undefined. +  if (Offset + BytesLoaded <= 0) +    return UndefValue::get(IntType); + +  // If we're not accessing anything in this constant, the result is undefined. 
+  if (Offset >= InitializerSize) +    return UndefValue::get(IntType); + +  unsigned char RawBytes[32] = {0}; +  unsigned char *CurPtr = RawBytes; +  unsigned BytesLeft = BytesLoaded; + +  // If we're loading off the beginning of the global, some bytes may be valid. +  if (Offset < 0) { +    CurPtr += -Offset; +    BytesLeft += Offset; +    Offset = 0; +  } + +  if (!ReadDataFromGlobal(GV->getInitializer(), Offset, CurPtr, BytesLeft, DL)) +    return nullptr; + +  APInt ResultVal = APInt(IntType->getBitWidth(), 0); +  if (DL.isLittleEndian()) { +    ResultVal = RawBytes[BytesLoaded - 1]; +    for (unsigned i = 1; i != BytesLoaded; ++i) { +      ResultVal <<= 8; +      ResultVal |= RawBytes[BytesLoaded - 1 - i]; +    } +  } else { +    ResultVal = RawBytes[0]; +    for (unsigned i = 1; i != BytesLoaded; ++i) { +      ResultVal <<= 8; +      ResultVal |= RawBytes[i]; +    } +  } + +  return ConstantInt::get(IntType->getContext(), ResultVal); +} + +Constant *ConstantFoldLoadThroughBitcastExpr(ConstantExpr *CE, Type *DestTy, +                                             const DataLayout &DL) { +  auto *SrcPtr = CE->getOperand(0); +  auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType()); +  if (!SrcPtrTy) +    return nullptr; +  Type *SrcTy = SrcPtrTy->getPointerElementType(); + +  Constant *C = ConstantFoldLoadFromConstPtr(SrcPtr, SrcTy, DL); +  if (!C) +    return nullptr; + +  return llvm::ConstantFoldLoadThroughBitcast(C, DestTy, DL); +} + +} // end anonymous namespace + +Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, +                                             const DataLayout &DL) { +  // First, try the easy cases: +  if (auto *GV = dyn_cast<GlobalVariable>(C)) +    if (GV->isConstant() && GV->hasDefinitiveInitializer()) +      return GV->getInitializer(); + +  if (auto *GA = dyn_cast<GlobalAlias>(C)) +    if (GA->getAliasee() && !GA->isInterposable()) +      return ConstantFoldLoadFromConstPtr(GA->getAliasee(), Ty, DL); + +  // If the loaded value isn't a constant expr, we can't handle it. +  auto *CE = dyn_cast<ConstantExpr>(C); +  if (!CE) +    return nullptr; + +  if (CE->getOpcode() == Instruction::GetElementPtr) { +    if (auto *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) { +      if (GV->isConstant() && GV->hasDefinitiveInitializer()) { +        if (Constant *V = +             ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE)) +          return V; +      } +    } +  } + +  if (CE->getOpcode() == Instruction::BitCast) +    if (Constant *LoadedC = ConstantFoldLoadThroughBitcastExpr(CE, Ty, DL)) +      return LoadedC; + +  // Instead of loading constant c string, use corresponding integer value +  // directly if string length is small enough. +  StringRef Str; +  if (getConstantStringInfo(CE, Str) && !Str.empty()) { +    size_t StrLen = Str.size(); +    unsigned NumBits = Ty->getPrimitiveSizeInBits(); +    // Replace load with immediate integer if the result is an integer or fp +    // value. 
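+    // Illustrative example (editor's note): for the string "abc" (memory
+    // bytes 'a','b','c',0) loaded as an i32, the folded value is 0x00636261
+    // on a little-endian target and 0x61626300 on a big-endian one.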
+    if ((NumBits >> 3) == StrLen + 1 && (NumBits & 7) == 0 && +        (isa<IntegerType>(Ty) || Ty->isFloatingPointTy())) { +      APInt StrVal(NumBits, 0); +      APInt SingleChar(NumBits, 0); +      if (DL.isLittleEndian()) { +        for (unsigned char C : reverse(Str.bytes())) { +          SingleChar = static_cast<uint64_t>(C); +          StrVal = (StrVal << 8) | SingleChar; +        } +      } else { +        for (unsigned char C : Str.bytes()) { +          SingleChar = static_cast<uint64_t>(C); +          StrVal = (StrVal << 8) | SingleChar; +        } +        // Append NULL at the end. +        SingleChar = 0; +        StrVal = (StrVal << 8) | SingleChar; +      } + +      Constant *Res = ConstantInt::get(CE->getContext(), StrVal); +      if (Ty->isFloatingPointTy()) +        Res = ConstantExpr::getBitCast(Res, Ty); +      return Res; +    } +  } + +  // If this load comes from anywhere in a constant global, and if the global +  // is all undef or zero, we know what it loads. +  if (auto *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(CE, DL))) { +    if (GV->isConstant() && GV->hasDefinitiveInitializer()) { +      if (GV->getInitializer()->isNullValue()) +        return Constant::getNullValue(Ty); +      if (isa<UndefValue>(GV->getInitializer())) +        return UndefValue::get(Ty); +    } +  } + +  // Try hard to fold loads from bitcasted strange and non-type-safe things. +  return FoldReinterpretLoadFromConstPtr(CE, Ty, DL); +} + +namespace { + +Constant *ConstantFoldLoadInst(const LoadInst *LI, const DataLayout &DL) { +  if (LI->isVolatile()) return nullptr; + +  if (auto *C = dyn_cast<Constant>(LI->getOperand(0))) +    return ConstantFoldLoadFromConstPtr(C, LI->getType(), DL); + +  return nullptr; +} + +/// One of Op0/Op1 is a constant expression. +/// Attempt to symbolically evaluate the result of a binary operator merging +/// these together.  If target data info is available, it is provided as DL, +/// otherwise DL is null. +Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, Constant *Op1, +                                    const DataLayout &DL) { +  // SROA + +  // Fold (and 0xffffffff00000000, (shl x, 32)) -> shl. +  // Fold (lshr (or X, Y), 32) -> (lshr [X/Y], 32) if one doesn't contribute +  // bits. + +  if (Opc == Instruction::And) { +    KnownBits Known0 = computeKnownBits(Op0, DL); +    KnownBits Known1 = computeKnownBits(Op1, DL); +    if ((Known1.One | Known0.Zero).isAllOnesValue()) { +      // All the bits of Op0 that the 'and' could be masking are already zero. +      return Op0; +    } +    if ((Known0.One | Known1.Zero).isAllOnesValue()) { +      // All the bits of Op1 that the 'and' could be masking are already zero. +      return Op1; +    } + +    Known0.Zero |= Known1.Zero; +    Known0.One &= Known1.One; +    if (Known0.isConstant()) +      return ConstantInt::get(Op0->getType(), Known0.getConstant()); +  } + +  // If the constant expr is something like &A[123] - &A[4].f, fold this into a +  // constant.  This happens frequently when iterating over a global array. +  if (Opc == Instruction::Sub) { +    GlobalValue *GV1, *GV2; +    APInt Offs1, Offs2; + +    if (IsConstantOffsetFromGlobal(Op0, GV1, Offs1, DL)) +      if (IsConstantOffsetFromGlobal(Op1, GV2, Offs2, DL) && GV1 == GV2) { +        unsigned OpSize = DL.getTypeSizeInBits(Op0->getType()); + +        // (&GV+C1) - (&GV+C2) -> C1-C2, pointer arithmetic cannot overflow. +        // PtrToInt may change the bitwidth so we have convert to the right size +        // first. 
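+        // Illustrative example (editor's note): subtracting the ptrtoint of
+        // &A[1] from the ptrtoint of &A[3], where @A holds i32 elements,
+        // folds to the constant 8 (byte offsets 12 and 4 within @A).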
+        return ConstantInt::get(Op0->getType(), Offs1.zextOrTrunc(OpSize) - +                                                Offs2.zextOrTrunc(OpSize)); +      } +  } + +  return nullptr; +} + +/// If array indices are not pointer-sized integers, explicitly cast them so +/// that they aren't implicitly casted by the getelementptr. +Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops, +                         Type *ResultTy, Optional<unsigned> InRangeIndex, +                         const DataLayout &DL, const TargetLibraryInfo *TLI) { +  Type *IntPtrTy = DL.getIntPtrType(ResultTy); +  Type *IntPtrScalarTy = IntPtrTy->getScalarType(); + +  bool Any = false; +  SmallVector<Constant*, 32> NewIdxs; +  for (unsigned i = 1, e = Ops.size(); i != e; ++i) { +    if ((i == 1 || +         !isa<StructType>(GetElementPtrInst::getIndexedType( +             SrcElemTy, Ops.slice(1, i - 1)))) && +        Ops[i]->getType()->getScalarType() != IntPtrScalarTy) { +      Any = true; +      Type *NewType = Ops[i]->getType()->isVectorTy() +                          ? IntPtrTy +                          : IntPtrTy->getScalarType(); +      NewIdxs.push_back(ConstantExpr::getCast(CastInst::getCastOpcode(Ops[i], +                                                                      true, +                                                                      NewType, +                                                                      true), +                                              Ops[i], NewType)); +    } else +      NewIdxs.push_back(Ops[i]); +  } + +  if (!Any) +    return nullptr; + +  Constant *C = ConstantExpr::getGetElementPtr( +      SrcElemTy, Ops[0], NewIdxs, /*InBounds=*/false, InRangeIndex); +  if (Constant *Folded = ConstantFoldConstant(C, DL, TLI)) +    C = Folded; + +  return C; +} + +/// Strip the pointer casts, but preserve the address space information. +Constant* StripPtrCastKeepAS(Constant* Ptr, Type *&ElemTy) { +  assert(Ptr->getType()->isPointerTy() && "Not a pointer type"); +  auto *OldPtrTy = cast<PointerType>(Ptr->getType()); +  Ptr = Ptr->stripPointerCasts(); +  auto *NewPtrTy = cast<PointerType>(Ptr->getType()); + +  ElemTy = NewPtrTy->getPointerElementType(); + +  // Preserve the address space number of the pointer. +  if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) { +    NewPtrTy = ElemTy->getPointerTo(OldPtrTy->getAddressSpace()); +    Ptr = ConstantExpr::getPointerCast(Ptr, NewPtrTy); +  } +  return Ptr; +} + +/// If we can symbolically evaluate the GEP constant expression, do so. 
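+/// Editor's illustration (not part of the original comment): for example,
+///   getelementptr i8, i8* bitcast ([4 x i32]* @g to i8*), i64 8
+/// can be re-expressed over the global's own type as
+///   bitcast (i32* getelementptr ([4 x i32], [4 x i32]* @g, i64 0, i64 2) to i8*)
+/// which lets later passes reason about in-bounds array indices again.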
+Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, +                                  ArrayRef<Constant *> Ops, +                                  const DataLayout &DL, +                                  const TargetLibraryInfo *TLI) { +  const GEPOperator *InnermostGEP = GEP; +  bool InBounds = GEP->isInBounds(); + +  Type *SrcElemTy = GEP->getSourceElementType(); +  Type *ResElemTy = GEP->getResultElementType(); +  Type *ResTy = GEP->getType(); +  if (!SrcElemTy->isSized()) +    return nullptr; + +  if (Constant *C = CastGEPIndices(SrcElemTy, Ops, ResTy, +                                   GEP->getInRangeIndex(), DL, TLI)) +    return C; + +  Constant *Ptr = Ops[0]; +  if (!Ptr->getType()->isPointerTy()) +    return nullptr; + +  Type *IntPtrTy = DL.getIntPtrType(Ptr->getType()); + +  // If this is a constant expr gep that is effectively computing an +  // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12' +  for (unsigned i = 1, e = Ops.size(); i != e; ++i) +      if (!isa<ConstantInt>(Ops[i])) { + +        // If this is "gep i8* Ptr, (sub 0, V)", fold this as: +        // "inttoptr (sub (ptrtoint Ptr), V)" +        if (Ops.size() == 2 && ResElemTy->isIntegerTy(8)) { +          auto *CE = dyn_cast<ConstantExpr>(Ops[1]); +          assert((!CE || CE->getType() == IntPtrTy) && +                 "CastGEPIndices didn't canonicalize index types!"); +          if (CE && CE->getOpcode() == Instruction::Sub && +              CE->getOperand(0)->isNullValue()) { +            Constant *Res = ConstantExpr::getPtrToInt(Ptr, CE->getType()); +            Res = ConstantExpr::getSub(Res, CE->getOperand(1)); +            Res = ConstantExpr::getIntToPtr(Res, ResTy); +            if (auto *FoldedRes = ConstantFoldConstant(Res, DL, TLI)) +              Res = FoldedRes; +            return Res; +          } +        } +        return nullptr; +      } + +  unsigned BitWidth = DL.getTypeSizeInBits(IntPtrTy); +  APInt Offset = +      APInt(BitWidth, +            DL.getIndexedOffsetInType( +                SrcElemTy, +                makeArrayRef((Value * const *)Ops.data() + 1, Ops.size() - 1))); +  Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy); + +  // If this is a GEP of a GEP, fold it all into a single GEP. +  while (auto *GEP = dyn_cast<GEPOperator>(Ptr)) { +    InnermostGEP = GEP; +    InBounds &= GEP->isInBounds(); + +    SmallVector<Value *, 4> NestedOps(GEP->op_begin() + 1, GEP->op_end()); + +    // Do not try the incorporate the sub-GEP if some index is not a number. +    bool AllConstantInt = true; +    for (Value *NestedOp : NestedOps) +      if (!isa<ConstantInt>(NestedOp)) { +        AllConstantInt = false; +        break; +      } +    if (!AllConstantInt) +      break; + +    Ptr = cast<Constant>(GEP->getOperand(0)); +    SrcElemTy = GEP->getSourceElementType(); +    Offset += APInt(BitWidth, DL.getIndexedOffsetInType(SrcElemTy, NestedOps)); +    Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy); +  } + +  // If the base value for this address is a literal integer value, fold the +  // getelementptr to the resulting integer value casted to the pointer type. 
+  APInt BasePtr(BitWidth, 0); +  if (auto *CE = dyn_cast<ConstantExpr>(Ptr)) { +    if (CE->getOpcode() == Instruction::IntToPtr) { +      if (auto *Base = dyn_cast<ConstantInt>(CE->getOperand(0))) +        BasePtr = Base->getValue().zextOrTrunc(BitWidth); +    } +  } + +  auto *PTy = cast<PointerType>(Ptr->getType()); +  if ((Ptr->isNullValue() || BasePtr != 0) && +      !DL.isNonIntegralPointerType(PTy)) { +    Constant *C = ConstantInt::get(Ptr->getContext(), Offset + BasePtr); +    return ConstantExpr::getIntToPtr(C, ResTy); +  } + +  // Otherwise form a regular getelementptr. Recompute the indices so that +  // we eliminate over-indexing of the notional static type array bounds. +  // This makes it easy to determine if the getelementptr is "inbounds". +  // Also, this helps GlobalOpt do SROA on GlobalVariables. +  Type *Ty = PTy; +  SmallVector<Constant *, 32> NewIdxs; + +  do { +    if (!Ty->isStructTy()) { +      if (Ty->isPointerTy()) { +        // The only pointer indexing we'll do is on the first index of the GEP. +        if (!NewIdxs.empty()) +          break; + +        Ty = SrcElemTy; + +        // Only handle pointers to sized types, not pointers to functions. +        if (!Ty->isSized()) +          return nullptr; +      } else if (auto *ATy = dyn_cast<SequentialType>(Ty)) { +        Ty = ATy->getElementType(); +      } else { +        // We've reached some non-indexable type. +        break; +      } + +      // Determine which element of the array the offset points into. +      APInt ElemSize(BitWidth, DL.getTypeAllocSize(Ty)); +      if (ElemSize == 0) { +        // The element size is 0. This may be [0 x Ty]*, so just use a zero +        // index for this level and proceed to the next level to see if it can +        // accommodate the offset. +        NewIdxs.push_back(ConstantInt::get(IntPtrTy, 0)); +      } else { +        // The element size is non-zero divide the offset by the element +        // size (rounding down), to compute the index at this level. +        bool Overflow; +        APInt NewIdx = Offset.sdiv_ov(ElemSize, Overflow); +        if (Overflow) +          break; +        Offset -= NewIdx * ElemSize; +        NewIdxs.push_back(ConstantInt::get(IntPtrTy, NewIdx)); +      } +    } else { +      auto *STy = cast<StructType>(Ty); +      // If we end up with an offset that isn't valid for this struct type, we +      // can't re-form this GEP in a regular form, so bail out. The pointer +      // operand likely went through casts that are necessary to make the GEP +      // sensible. +      const StructLayout &SL = *DL.getStructLayout(STy); +      if (Offset.isNegative() || Offset.uge(SL.getSizeInBytes())) +        break; + +      // Determine which field of the struct the offset points into. The +      // getZExtValue is fine as we've already ensured that the offset is +      // within the range representable by the StructLayout API. +      unsigned ElIdx = SL.getElementContainingOffset(Offset.getZExtValue()); +      NewIdxs.push_back(ConstantInt::get(Type::getInt32Ty(Ty->getContext()), +                                         ElIdx)); +      Offset -= APInt(BitWidth, SL.getElementOffset(ElIdx)); +      Ty = STy->getTypeAtIndex(ElIdx); +    } +  } while (Ty != ResElemTy); + +  // If we haven't used up the entire offset by descending the static +  // type, then the offset is pointing into the middle of an indivisible +  // member, so we can't simplify it. 
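+  // (Editor's note) For instance, a leftover byte offset of 2 into an i64
+  // member cannot be expressed as element indices, so we give up below.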
+  if (Offset != 0) +    return nullptr; + +  // Preserve the inrange index from the innermost GEP if possible. We must +  // have calculated the same indices up to and including the inrange index. +  Optional<unsigned> InRangeIndex; +  if (Optional<unsigned> LastIRIndex = InnermostGEP->getInRangeIndex()) +    if (SrcElemTy == InnermostGEP->getSourceElementType() && +        NewIdxs.size() > *LastIRIndex) { +      InRangeIndex = LastIRIndex; +      for (unsigned I = 0; I <= *LastIRIndex; ++I) +        if (NewIdxs[I] != InnermostGEP->getOperand(I + 1)) { +          InRangeIndex = None; +          break; +        } +    } + +  // Create a GEP. +  Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs, +                                               InBounds, InRangeIndex); +  assert(C->getType()->getPointerElementType() == Ty && +         "Computed GetElementPtr has unexpected type!"); + +  // If we ended up indexing a member with a type that doesn't match +  // the type of what the original indices indexed, add a cast. +  if (Ty != ResElemTy) +    C = FoldBitCast(C, ResTy, DL); + +  return C; +} + +/// Attempt to constant fold an instruction with the +/// specified opcode and operands.  If successful, the constant result is +/// returned, if not, null is returned.  Note that this function can fail when +/// attempting to fold instructions like loads and stores, which have no +/// constant expression form. +/// +/// TODO: This function neither utilizes nor preserves nsw/nuw/inbounds/inrange +/// etc information, due to only being passed an opcode and operands. Constant +/// folding using this function strips this information. +/// +Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, +                                       ArrayRef<Constant *> Ops, +                                       const DataLayout &DL, +                                       const TargetLibraryInfo *TLI) { +  Type *DestTy = InstOrCE->getType(); + +  // Handle easy binops first. 
+  if (Instruction::isBinaryOp(Opcode)) +    return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL); + +  if (Instruction::isCast(Opcode)) +    return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL); + +  if (auto *GEP = dyn_cast<GEPOperator>(InstOrCE)) { +    if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI)) +      return C; + +    return ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), Ops[0], +                                          Ops.slice(1), GEP->isInBounds(), +                                          GEP->getInRangeIndex()); +  } + +  if (auto *CE = dyn_cast<ConstantExpr>(InstOrCE)) +    return CE->getWithOperands(Ops); + +  switch (Opcode) { +  default: return nullptr; +  case Instruction::ICmp: +  case Instruction::FCmp: llvm_unreachable("Invalid for compares"); +  case Instruction::Call: +    if (auto *F = dyn_cast<Function>(Ops.back())) { +      ImmutableCallSite CS(cast<CallInst>(InstOrCE)); +      if (canConstantFoldCallTo(CS, F)) +        return ConstantFoldCall(CS, F, Ops.slice(0, Ops.size() - 1), TLI); +    } +    return nullptr; +  case Instruction::Select: +    return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]); +  case Instruction::ExtractElement: +    return ConstantExpr::getExtractElement(Ops[0], Ops[1]); +  case Instruction::InsertElement: +    return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]); +  case Instruction::ShuffleVector: +    return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]); +  } +} + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// Constant Folding public APIs +//===----------------------------------------------------------------------===// + +namespace { + +Constant * +ConstantFoldConstantImpl(const Constant *C, const DataLayout &DL, +                         const TargetLibraryInfo *TLI, +                         SmallDenseMap<Constant *, Constant *> &FoldedOps) { +  if (!isa<ConstantVector>(C) && !isa<ConstantExpr>(C)) +    return nullptr; + +  SmallVector<Constant *, 8> Ops; +  for (const Use &NewU : C->operands()) { +    auto *NewC = cast<Constant>(&NewU); +    // Recursively fold the ConstantExpr's operands. If we have already folded +    // a ConstantExpr, we don't have to process it again. +    if (isa<ConstantVector>(NewC) || isa<ConstantExpr>(NewC)) { +      auto It = FoldedOps.find(NewC); +      if (It == FoldedOps.end()) { +        if (auto *FoldedC = +                ConstantFoldConstantImpl(NewC, DL, TLI, FoldedOps)) { +          FoldedOps.insert({NewC, FoldedC}); +          NewC = FoldedC; +        } else { +          FoldedOps.insert({NewC, NewC}); +        } +      } else { +        NewC = It->second; +      } +    } +    Ops.push_back(NewC); +  } + +  if (auto *CE = dyn_cast<ConstantExpr>(C)) { +    if (CE->isCompare()) +      return ConstantFoldCompareInstOperands(CE->getPredicate(), Ops[0], Ops[1], +                                             DL, TLI); + +    return ConstantFoldInstOperandsImpl(CE, CE->getOpcode(), Ops, DL, TLI); +  } + +  assert(isa<ConstantVector>(C)); +  return ConstantVector::get(Ops); +} + +} // end anonymous namespace + +Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL, +                                        const TargetLibraryInfo *TLI) { +  // Handle PHI nodes quickly here... 
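+  // (Editor's note) e.g. %p = phi i32 [ 7, %a ], [ 7, %b ], [ undef, %c ]
+  // folds to i32 7; any non-constant or disagreeing incoming value defeats
+  // the fold and null is returned.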
+  if (auto *PN = dyn_cast<PHINode>(I)) { +    Constant *CommonValue = nullptr; + +    SmallDenseMap<Constant *, Constant *> FoldedOps; +    for (Value *Incoming : PN->incoming_values()) { +      // If the incoming value is undef then skip it.  Note that while we could +      // skip the value if it is equal to the phi node itself we choose not to +      // because that would break the rule that constant folding only applies if +      // all operands are constants. +      if (isa<UndefValue>(Incoming)) +        continue; +      // If the incoming value is not a constant, then give up. +      auto *C = dyn_cast<Constant>(Incoming); +      if (!C) +        return nullptr; +      // Fold the PHI's operands. +      if (auto *FoldedC = ConstantFoldConstantImpl(C, DL, TLI, FoldedOps)) +        C = FoldedC; +      // If the incoming value is a different constant to +      // the one we saw previously, then give up. +      if (CommonValue && C != CommonValue) +        return nullptr; +      CommonValue = C; +    } + +    // If we reach here, all incoming values are the same constant or undef. +    return CommonValue ? CommonValue : UndefValue::get(PN->getType()); +  } + +  // Scan the operand list, checking to see if they are all constants, if so, +  // hand off to ConstantFoldInstOperandsImpl. +  if (!all_of(I->operands(), [](Use &U) { return isa<Constant>(U); })) +    return nullptr; + +  SmallDenseMap<Constant *, Constant *> FoldedOps; +  SmallVector<Constant *, 8> Ops; +  for (const Use &OpU : I->operands()) { +    auto *Op = cast<Constant>(&OpU); +    // Fold the Instruction's operands. +    if (auto *FoldedOp = ConstantFoldConstantImpl(Op, DL, TLI, FoldedOps)) +      Op = FoldedOp; + +    Ops.push_back(Op); +  } + +  if (const auto *CI = dyn_cast<CmpInst>(I)) +    return ConstantFoldCompareInstOperands(CI->getPredicate(), Ops[0], Ops[1], +                                           DL, TLI); + +  if (const auto *LI = dyn_cast<LoadInst>(I)) +    return ConstantFoldLoadInst(LI, DL); + +  if (auto *IVI = dyn_cast<InsertValueInst>(I)) { +    return ConstantExpr::getInsertValue( +                                cast<Constant>(IVI->getAggregateOperand()), +                                cast<Constant>(IVI->getInsertedValueOperand()), +                                IVI->getIndices()); +  } + +  if (auto *EVI = dyn_cast<ExtractValueInst>(I)) { +    return ConstantExpr::getExtractValue( +                                    cast<Constant>(EVI->getAggregateOperand()), +                                    EVI->getIndices()); +  } + +  return ConstantFoldInstOperands(I, Ops, DL, TLI); +} + +Constant *llvm::ConstantFoldConstant(const Constant *C, const DataLayout &DL, +                                     const TargetLibraryInfo *TLI) { +  SmallDenseMap<Constant *, Constant *> FoldedOps; +  return ConstantFoldConstantImpl(C, DL, TLI, FoldedOps); +} + +Constant *llvm::ConstantFoldInstOperands(Instruction *I, +                                         ArrayRef<Constant *> Ops, +                                         const DataLayout &DL, +                                         const TargetLibraryInfo *TLI) { +  return ConstantFoldInstOperandsImpl(I, I->getOpcode(), Ops, DL, TLI); +} + +Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate, +                                                Constant *Ops0, Constant *Ops1, +                                                const DataLayout &DL, +                                                const TargetLibraryInfo *TLI) { +  // fold: icmp 
(inttoptr x), null         -> icmp x, 0
+  // fold: icmp null, (inttoptr x)         -> icmp 0, x
+  // fold: icmp (ptrtoint x), 0            -> icmp x, null
+  // fold: icmp 0, (ptrtoint x)            -> icmp null, x
+  // fold: icmp (inttoptr x), (inttoptr y) -> icmp trunc/zext x, trunc/zext y
+  // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y
+  //
+  // FIXME: The following comment is out of date and the DataLayout is here now.
+  // ConstantExpr::getCompare cannot do this, because it doesn't have DL
+  // around to know if bit truncation is happening.
+  if (auto *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
+    if (Ops1->isNullValue()) {
+      if (CE0->getOpcode() == Instruction::IntToPtr) {
+        Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
+        // Convert the integer value to the right size to ensure we get the
+        // proper extension or truncation.
+        Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+                                                   IntPtrTy, false);
+        Constant *Null = Constant::getNullValue(C->getType());
+        return ConstantFoldCompareInstOperands(Predicate, C, Null, DL, TLI);
+      }
+
+      // Only do this transformation if the int is IntPtrTy in size, otherwise
+      // there is a truncation or extension that we aren't modeling.
+      if (CE0->getOpcode() == Instruction::PtrToInt) {
+        Type *IntPtrTy = DL.getIntPtrType(CE0->getOperand(0)->getType());
+        if (CE0->getType() == IntPtrTy) {
+          Constant *C = CE0->getOperand(0);
+          Constant *Null = Constant::getNullValue(C->getType());
+          return ConstantFoldCompareInstOperands(Predicate, C, Null, DL, TLI);
+        }
+      }
+    }
+
+    if (auto *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
+      if (CE0->getOpcode() == CE1->getOpcode()) {
+        if (CE0->getOpcode() == Instruction::IntToPtr) {
+          Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
+
+          // Convert the integer value to the right size to ensure we get the
+          // proper extension or truncation.
+          Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+                                                      IntPtrTy, false);
+          Constant *C1 = ConstantExpr::getIntegerCast(CE1->getOperand(0),
+                                                      IntPtrTy, false);
+          return ConstantFoldCompareInstOperands(Predicate, C0, C1, DL, TLI);
+        }
+
+        // Only do this transformation if the int is IntPtrTy in size, otherwise
+        // there is a truncation or extension that we aren't modeling.
+        if (CE0->getOpcode() == Instruction::PtrToInt) { +          Type *IntPtrTy = DL.getIntPtrType(CE0->getOperand(0)->getType()); +          if (CE0->getType() == IntPtrTy && +              CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()) { +            return ConstantFoldCompareInstOperands( +                Predicate, CE0->getOperand(0), CE1->getOperand(0), DL, TLI); +          } +        } +      } +    } + +    // icmp eq (or x, y), 0 -> (icmp eq x, 0) & (icmp eq y, 0) +    // icmp ne (or x, y), 0 -> (icmp ne x, 0) | (icmp ne y, 0) +    if ((Predicate == ICmpInst::ICMP_EQ || Predicate == ICmpInst::ICMP_NE) && +        CE0->getOpcode() == Instruction::Or && Ops1->isNullValue()) { +      Constant *LHS = ConstantFoldCompareInstOperands( +          Predicate, CE0->getOperand(0), Ops1, DL, TLI); +      Constant *RHS = ConstantFoldCompareInstOperands( +          Predicate, CE0->getOperand(1), Ops1, DL, TLI); +      unsigned OpC = +        Predicate == ICmpInst::ICMP_EQ ? Instruction::And : Instruction::Or; +      return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL); +    } +  } else if (isa<ConstantExpr>(Ops1)) { +    // If RHS is a constant expression, but the left side isn't, swap the +    // operands and try again. +    Predicate = ICmpInst::getSwappedPredicate((ICmpInst::Predicate)Predicate); +    return ConstantFoldCompareInstOperands(Predicate, Ops1, Ops0, DL, TLI); +  } + +  return ConstantExpr::getCompare(Predicate, Ops0, Ops1); +} + +Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS, +                                             Constant *RHS, +                                             const DataLayout &DL) { +  assert(Instruction::isBinaryOp(Opcode)); +  if (isa<ConstantExpr>(LHS) || isa<ConstantExpr>(RHS)) +    if (Constant *C = SymbolicallyEvaluateBinop(Opcode, LHS, RHS, DL)) +      return C; + +  return ConstantExpr::get(Opcode, LHS, RHS); +} + +Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C, +                                        Type *DestTy, const DataLayout &DL) { +  assert(Instruction::isCast(Opcode)); +  switch (Opcode) { +  default: +    llvm_unreachable("Missing case"); +  case Instruction::PtrToInt: +    // If the input is a inttoptr, eliminate the pair.  This requires knowing +    // the width of a pointer, so it can't be done in ConstantExpr::getCast. +    if (auto *CE = dyn_cast<ConstantExpr>(C)) { +      if (CE->getOpcode() == Instruction::IntToPtr) { +        Constant *Input = CE->getOperand(0); +        unsigned InWidth = Input->getType()->getScalarSizeInBits(); +        unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType()); +        if (PtrWidth < InWidth) { +          Constant *Mask = +            ConstantInt::get(CE->getContext(), +                             APInt::getLowBitsSet(InWidth, PtrWidth)); +          Input = ConstantExpr::getAnd(Input, Mask); +        } +        // Do a zext or trunc to get to the dest size. +        return ConstantExpr::getIntegerCast(Input, DestTy, false); +      } +    } +    return ConstantExpr::getCast(Opcode, C, DestTy); +  case Instruction::IntToPtr: +    // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if +    // the int size is >= the ptr size and the address spaces are the same. +    // This requires knowing the width of a pointer, so it can't be done in +    // ConstantExpr::getCast. 
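+    // Illustrative example (editor's note):
+    //   inttoptr (i64 ptrtoint (i32* @g to i64) to i8*)
+    // becomes a plain bitcast of @g to i8* when pointers are 64 bits wide and
+    // both pointers are in address space 0.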
+    if (auto *CE = dyn_cast<ConstantExpr>(C)) { +      if (CE->getOpcode() == Instruction::PtrToInt) { +        Constant *SrcPtr = CE->getOperand(0); +        unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType()); +        unsigned MidIntSize = CE->getType()->getScalarSizeInBits(); + +        if (MidIntSize >= SrcPtrSize) { +          unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace(); +          if (SrcAS == DestTy->getPointerAddressSpace()) +            return FoldBitCast(CE->getOperand(0), DestTy, DL); +        } +      } +    } + +    return ConstantExpr::getCast(Opcode, C, DestTy); +  case Instruction::Trunc: +  case Instruction::ZExt: +  case Instruction::SExt: +  case Instruction::FPTrunc: +  case Instruction::FPExt: +  case Instruction::UIToFP: +  case Instruction::SIToFP: +  case Instruction::FPToUI: +  case Instruction::FPToSI: +  case Instruction::AddrSpaceCast: +      return ConstantExpr::getCast(Opcode, C, DestTy); +  case Instruction::BitCast: +    return FoldBitCast(C, DestTy, DL); +  } +} + +Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, +                                                       ConstantExpr *CE) { +  if (!CE->getOperand(1)->isNullValue()) +    return nullptr;  // Do not allow stepping over the value! + +  // Loop over all of the operands, tracking down which value we are +  // addressing. +  for (unsigned i = 2, e = CE->getNumOperands(); i != e; ++i) { +    C = C->getAggregateElement(CE->getOperand(i)); +    if (!C) +      return nullptr; +  } +  return C; +} + +Constant * +llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, +                                        ArrayRef<Constant *> Indices) { +  // Loop over all of the operands, tracking down which value we are +  // addressing. 
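+  // (Editor's note) e.g. for an initializer { i32 1, [2 x i32] [i32 2, i32 3] }
+  // and indices (1, 0), this selects the array member and then i32 2.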
+  for (Constant *Index : Indices) { +    C = C->getAggregateElement(Index); +    if (!C) +      return nullptr; +  } +  return C; +} + +//===----------------------------------------------------------------------===// +//  Constant Folding for Calls +// + +bool llvm::canConstantFoldCallTo(ImmutableCallSite CS, const Function *F) { +  if (CS.isNoBuiltin() || CS.isStrictFP()) +    return false; +  switch (F->getIntrinsicID()) { +  case Intrinsic::fabs: +  case Intrinsic::minnum: +  case Intrinsic::maxnum: +  case Intrinsic::log: +  case Intrinsic::log2: +  case Intrinsic::log10: +  case Intrinsic::exp: +  case Intrinsic::exp2: +  case Intrinsic::floor: +  case Intrinsic::ceil: +  case Intrinsic::sqrt: +  case Intrinsic::sin: +  case Intrinsic::cos: +  case Intrinsic::trunc: +  case Intrinsic::rint: +  case Intrinsic::nearbyint: +  case Intrinsic::pow: +  case Intrinsic::powi: +  case Intrinsic::bswap: +  case Intrinsic::ctpop: +  case Intrinsic::ctlz: +  case Intrinsic::cttz: +  case Intrinsic::fma: +  case Intrinsic::fmuladd: +  case Intrinsic::copysign: +  case Intrinsic::launder_invariant_group: +  case Intrinsic::strip_invariant_group: +  case Intrinsic::round: +  case Intrinsic::masked_load: +  case Intrinsic::sadd_with_overflow: +  case Intrinsic::uadd_with_overflow: +  case Intrinsic::ssub_with_overflow: +  case Intrinsic::usub_with_overflow: +  case Intrinsic::smul_with_overflow: +  case Intrinsic::umul_with_overflow: +  case Intrinsic::convert_from_fp16: +  case Intrinsic::convert_to_fp16: +  case Intrinsic::bitreverse: +  case Intrinsic::x86_sse_cvtss2si: +  case Intrinsic::x86_sse_cvtss2si64: +  case Intrinsic::x86_sse_cvttss2si: +  case Intrinsic::x86_sse_cvttss2si64: +  case Intrinsic::x86_sse2_cvtsd2si: +  case Intrinsic::x86_sse2_cvtsd2si64: +  case Intrinsic::x86_sse2_cvttsd2si: +  case Intrinsic::x86_sse2_cvttsd2si64: +    return true; +  default: +    return false; +  case Intrinsic::not_intrinsic: break; +  } + +  if (!F->hasName()) +    return false; +  StringRef Name = F->getName(); + +  // In these cases, the check of the length is required.  We don't want to +  // return true for a name like "cos\0blah" which strcmp would return equal to +  // "cos", but has length 8. +  switch (Name[0]) { +  default: +    return false; +  case 'a': +    return Name == "acos" || Name == "asin" || Name == "atan" || +           Name == "atan2" || Name == "acosf" || Name == "asinf" || +           Name == "atanf" || Name == "atan2f"; +  case 'c': +    return Name == "ceil" || Name == "cos" || Name == "cosh" || +           Name == "ceilf" || Name == "cosf" || Name == "coshf"; +  case 'e': +    return Name == "exp" || Name == "exp2" || Name == "expf" || Name == "exp2f"; +  case 'f': +    return Name == "fabs" || Name == "floor" || Name == "fmod" || +           Name == "fabsf" || Name == "floorf" || Name == "fmodf"; +  case 'l': +    return Name == "log" || Name == "log10" || Name == "logf" || +           Name == "log10f"; +  case 'p': +    return Name == "pow" || Name == "powf"; +  case 'r': +    return Name == "round" || Name == "roundf"; +  case 's': +    return Name == "sin" || Name == "sinh" || Name == "sqrt" || +           Name == "sinf" || Name == "sinhf" || Name == "sqrtf"; +  case 't': +    return Name == "tan" || Name == "tanh" || Name == "tanf" || Name == "tanhf"; +  case '_': + +    // Check for various function names that get used for the math functions +    // when the header files are preprocessed with the macro +    // __FINITE_MATH_ONLY__ enabled. 
+    // The '12' here is the length of the shortest name that can match. +    // We need to check the size before looking at Name[1] and Name[2] +    // so we may as well check a limit that will eliminate mismatches. +    if (Name.size() < 12 || Name[1] != '_') +      return false; +    switch (Name[2]) { +    default: +      return false; +    case 'a': +      return Name == "__acos_finite" || Name == "__acosf_finite" || +             Name == "__asin_finite" || Name == "__asinf_finite" || +             Name == "__atan2_finite" || Name == "__atan2f_finite"; +    case 'c': +      return Name == "__cosh_finite" || Name == "__coshf_finite"; +    case 'e': +      return Name == "__exp_finite" || Name == "__expf_finite" || +             Name == "__exp2_finite" || Name == "__exp2f_finite"; +    case 'l': +      return Name == "__log_finite" || Name == "__logf_finite" || +             Name == "__log10_finite" || Name == "__log10f_finite"; +    case 'p': +      return Name == "__pow_finite" || Name == "__powf_finite"; +    case 's': +      return Name == "__sinh_finite" || Name == "__sinhf_finite"; +    } +  } +} + +namespace { + +Constant *GetConstantFoldFPValue(double V, Type *Ty) { +  if (Ty->isHalfTy()) { +    APFloat APF(V); +    bool unused; +    APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &unused); +    return ConstantFP::get(Ty->getContext(), APF); +  } +  if (Ty->isFloatTy()) +    return ConstantFP::get(Ty->getContext(), APFloat((float)V)); +  if (Ty->isDoubleTy()) +    return ConstantFP::get(Ty->getContext(), APFloat(V)); +  llvm_unreachable("Can only constant fold half/float/double"); +} + +/// Clear the floating-point exception state. +inline void llvm_fenv_clearexcept() { +#if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT +  feclearexcept(FE_ALL_EXCEPT); +#endif +  errno = 0; +} + +/// Test if a floating-point exception was raised. +inline bool llvm_fenv_testexcept() { +  int errno_val = errno; +  if (errno_val == ERANGE || errno_val == EDOM) +    return true; +#if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT && HAVE_DECL_FE_INEXACT +  if (fetestexcept(FE_ALL_EXCEPT & ~FE_INEXACT)) +    return true; +#endif +  return false; +} + +Constant *ConstantFoldFP(double (*NativeFP)(double), double V, Type *Ty) { +  llvm_fenv_clearexcept(); +  V = NativeFP(V); +  if (llvm_fenv_testexcept()) { +    llvm_fenv_clearexcept(); +    return nullptr; +  } + +  return GetConstantFoldFPValue(V, Ty); +} + +Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V, +                               double W, Type *Ty) { +  llvm_fenv_clearexcept(); +  V = NativeFP(V, W); +  if (llvm_fenv_testexcept()) { +    llvm_fenv_clearexcept(); +    return nullptr; +  } + +  return GetConstantFoldFPValue(V, Ty); +} + +/// Attempt to fold an SSE floating point to integer conversion of a constant +/// floating point. If roundTowardZero is false, the default IEEE rounding is +/// used (toward nearest, ties to even). This matches the behavior of the +/// non-truncating SSE instructions in the default rounding mode. The desired +/// integer type Ty is used to select how many bits are available for the +/// result. Returns null if the conversion cannot be performed, otherwise +/// returns the Constant value resulting from the conversion. +Constant *ConstantFoldSSEConvertToInt(const APFloat &Val, bool roundTowardZero, +                                      Type *Ty) { +  // All of these conversion intrinsics form an integer of at most 64bits. 
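+  // Illustrative example (editor's note): cvttss2si on 3.7f yields 3 (round
+  // toward zero), while cvtss2si rounds to nearest and yields 4.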
+  unsigned ResultWidth = Ty->getIntegerBitWidth(); +  assert(ResultWidth <= 64 && +         "Can only constant fold conversions to 64 and 32 bit ints"); + +  uint64_t UIntVal; +  bool isExact = false; +  APFloat::roundingMode mode = roundTowardZero? APFloat::rmTowardZero +                                              : APFloat::rmNearestTiesToEven; +  APFloat::opStatus status = +      Val.convertToInteger(makeMutableArrayRef(UIntVal), ResultWidth, +                           /*isSigned=*/true, mode, &isExact); +  if (status != APFloat::opOK && +      (!roundTowardZero || status != APFloat::opInexact)) +    return nullptr; +  return ConstantInt::get(Ty, UIntVal, /*isSigned=*/true); +} + +double getValueAsDouble(ConstantFP *Op) { +  Type *Ty = Op->getType(); + +  if (Ty->isFloatTy()) +    return Op->getValueAPF().convertToFloat(); + +  if (Ty->isDoubleTy()) +    return Op->getValueAPF().convertToDouble(); + +  bool unused; +  APFloat APF = Op->getValueAPF(); +  APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &unused); +  return APF.convertToDouble(); +} + +Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, +                                 ArrayRef<Constant *> Operands, +                                 const TargetLibraryInfo *TLI, +                                 ImmutableCallSite CS) { +  if (Operands.size() == 1) { +    if (isa<UndefValue>(Operands[0])) { +      // cosine(arg) is between -1 and 1. cosine(invalid arg) is NaN +      if (IntrinsicID == Intrinsic::cos) +        return Constant::getNullValue(Ty); +      if (IntrinsicID == Intrinsic::bswap || +          IntrinsicID == Intrinsic::bitreverse || +          IntrinsicID == Intrinsic::launder_invariant_group || +          IntrinsicID == Intrinsic::strip_invariant_group) +        return Operands[0]; +    } + +    if (isa<ConstantPointerNull>(Operands[0])) { +      // launder(null) == null == strip(null) iff in addrspace 0 +      if (IntrinsicID == Intrinsic::launder_invariant_group || +          IntrinsicID == Intrinsic::strip_invariant_group) { +        // If instruction is not yet put in a basic block (e.g. when cloning +        // a function during inlining), CS caller may not be available. +        // So check CS's BB first before querying CS.getCaller. +        const Function *Caller = CS.getParent() ? 
CS.getCaller() : nullptr; +        if (Caller && +            !NullPointerIsDefined( +                Caller, Operands[0]->getType()->getPointerAddressSpace())) { +          return Operands[0]; +        } +        return nullptr; +      } +    } + +    if (auto *Op = dyn_cast<ConstantFP>(Operands[0])) { +      if (IntrinsicID == Intrinsic::convert_to_fp16) { +        APFloat Val(Op->getValueAPF()); + +        bool lost = false; +        Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &lost); + +        return ConstantInt::get(Ty->getContext(), Val.bitcastToAPInt()); +      } + +      if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) +        return nullptr; + +      if (IntrinsicID == Intrinsic::round) { +        APFloat V = Op->getValueAPF(); +        V.roundToIntegral(APFloat::rmNearestTiesToAway); +        return ConstantFP::get(Ty->getContext(), V); +      } + +      if (IntrinsicID == Intrinsic::floor) { +        APFloat V = Op->getValueAPF(); +        V.roundToIntegral(APFloat::rmTowardNegative); +        return ConstantFP::get(Ty->getContext(), V); +      } + +      if (IntrinsicID == Intrinsic::ceil) { +        APFloat V = Op->getValueAPF(); +        V.roundToIntegral(APFloat::rmTowardPositive); +        return ConstantFP::get(Ty->getContext(), V); +      } + +      if (IntrinsicID == Intrinsic::trunc) { +        APFloat V = Op->getValueAPF(); +        V.roundToIntegral(APFloat::rmTowardZero); +        return ConstantFP::get(Ty->getContext(), V); +      } + +      if (IntrinsicID == Intrinsic::rint) { +        APFloat V = Op->getValueAPF(); +        V.roundToIntegral(APFloat::rmNearestTiesToEven); +        return ConstantFP::get(Ty->getContext(), V); +      } + +      if (IntrinsicID == Intrinsic::nearbyint) { +        APFloat V = Op->getValueAPF(); +        V.roundToIntegral(APFloat::rmNearestTiesToEven); +        return ConstantFP::get(Ty->getContext(), V); +      } + +      /// We only fold functions with finite arguments. Folding NaN and inf is +      /// likely to be aborted with an exception anyway, and some host libms +      /// have known errors raising exceptions. +      if (Op->getValueAPF().isNaN() || Op->getValueAPF().isInfinity()) +        return nullptr; + +      /// Currently APFloat versions of these functions do not exist, so we use +      /// the host native double versions.  Float versions are not called +      /// directly but for all these it is true (float)(f((double)arg)) == +      /// f(arg).  Long double not supported yet. 
+      double V = getValueAsDouble(Op); + +      switch (IntrinsicID) { +        default: break; +        case Intrinsic::fabs: +          return ConstantFoldFP(fabs, V, Ty); +        case Intrinsic::log2: +          return ConstantFoldFP(Log2, V, Ty); +        case Intrinsic::log: +          return ConstantFoldFP(log, V, Ty); +        case Intrinsic::log10: +          return ConstantFoldFP(log10, V, Ty); +        case Intrinsic::exp: +          return ConstantFoldFP(exp, V, Ty); +        case Intrinsic::exp2: +          return ConstantFoldFP(exp2, V, Ty); +        case Intrinsic::sin: +          return ConstantFoldFP(sin, V, Ty); +        case Intrinsic::cos: +          return ConstantFoldFP(cos, V, Ty); +        case Intrinsic::sqrt: +          return ConstantFoldFP(sqrt, V, Ty); +      } + +      if (!TLI) +        return nullptr; + +      char NameKeyChar = Name[0]; +      if (Name[0] == '_' && Name.size() > 2 && Name[1] == '_') +        NameKeyChar = Name[2]; + +      switch (NameKeyChar) { +      case 'a': +        if ((Name == "acos" && TLI->has(LibFunc_acos)) || +            (Name == "acosf" && TLI->has(LibFunc_acosf)) || +            (Name == "__acos_finite" && TLI->has(LibFunc_acos_finite)) || +            (Name == "__acosf_finite" && TLI->has(LibFunc_acosf_finite))) +          return ConstantFoldFP(acos, V, Ty); +        else if ((Name == "asin" && TLI->has(LibFunc_asin)) || +                 (Name == "asinf" && TLI->has(LibFunc_asinf)) || +                 (Name == "__asin_finite" && TLI->has(LibFunc_asin_finite)) || +                 (Name == "__asinf_finite" && TLI->has(LibFunc_asinf_finite))) +          return ConstantFoldFP(asin, V, Ty); +        else if ((Name == "atan" && TLI->has(LibFunc_atan)) || +                 (Name == "atanf" && TLI->has(LibFunc_atanf))) +          return ConstantFoldFP(atan, V, Ty); +        break; +      case 'c': +        if ((Name == "ceil" && TLI->has(LibFunc_ceil)) || +            (Name == "ceilf" && TLI->has(LibFunc_ceilf))) +          return ConstantFoldFP(ceil, V, Ty); +        else if ((Name == "cos" && TLI->has(LibFunc_cos)) || +                 (Name == "cosf" && TLI->has(LibFunc_cosf))) +          return ConstantFoldFP(cos, V, Ty); +        else if ((Name == "cosh" && TLI->has(LibFunc_cosh)) || +                 (Name == "coshf" && TLI->has(LibFunc_coshf)) || +                 (Name == "__cosh_finite" && TLI->has(LibFunc_cosh_finite)) || +                 (Name == "__coshf_finite" && TLI->has(LibFunc_coshf_finite))) +          return ConstantFoldFP(cosh, V, Ty); +        break; +      case 'e': +        if ((Name == "exp" && TLI->has(LibFunc_exp)) || +            (Name == "expf" && TLI->has(LibFunc_expf)) || +            (Name == "__exp_finite" && TLI->has(LibFunc_exp_finite)) || +            (Name == "__expf_finite" && TLI->has(LibFunc_expf_finite))) +          return ConstantFoldFP(exp, V, Ty); +        if ((Name == "exp2" && TLI->has(LibFunc_exp2)) || +            (Name == "exp2f" && TLI->has(LibFunc_exp2f)) || +            (Name == "__exp2_finite" && TLI->has(LibFunc_exp2_finite)) || +            (Name == "__exp2f_finite" && TLI->has(LibFunc_exp2f_finite))) +          // Constant fold exp2(x) as pow(2,x) in case the host doesn't have a +          // C99 library. 
+          return ConstantFoldBinaryFP(pow, 2.0, V, Ty); +        break; +      case 'f': +        if ((Name == "fabs" && TLI->has(LibFunc_fabs)) || +            (Name == "fabsf" && TLI->has(LibFunc_fabsf))) +          return ConstantFoldFP(fabs, V, Ty); +        else if ((Name == "floor" && TLI->has(LibFunc_floor)) || +                 (Name == "floorf" && TLI->has(LibFunc_floorf))) +          return ConstantFoldFP(floor, V, Ty); +        break; +      case 'l': +        if ((Name == "log" && V > 0 && TLI->has(LibFunc_log)) || +            (Name == "logf" && V > 0 && TLI->has(LibFunc_logf)) || +            (Name == "__log_finite" && V > 0 && +              TLI->has(LibFunc_log_finite)) || +            (Name == "__logf_finite" && V > 0 && +              TLI->has(LibFunc_logf_finite))) +          return ConstantFoldFP(log, V, Ty); +        else if ((Name == "log10" && V > 0 && TLI->has(LibFunc_log10)) || +                 (Name == "log10f" && V > 0 && TLI->has(LibFunc_log10f)) || +                 (Name == "__log10_finite" && V > 0 && +                   TLI->has(LibFunc_log10_finite)) || +                 (Name == "__log10f_finite" && V > 0 && +                   TLI->has(LibFunc_log10f_finite))) +          return ConstantFoldFP(log10, V, Ty); +        break; +      case 'r': +        if ((Name == "round" && TLI->has(LibFunc_round)) || +            (Name == "roundf" && TLI->has(LibFunc_roundf))) +          return ConstantFoldFP(round, V, Ty); +        break; +      case 's': +        if ((Name == "sin" && TLI->has(LibFunc_sin)) || +            (Name == "sinf" && TLI->has(LibFunc_sinf))) +          return ConstantFoldFP(sin, V, Ty); +        else if ((Name == "sinh" && TLI->has(LibFunc_sinh)) || +                 (Name == "sinhf" && TLI->has(LibFunc_sinhf)) || +                 (Name == "__sinh_finite" && TLI->has(LibFunc_sinh_finite)) || +                 (Name == "__sinhf_finite" && TLI->has(LibFunc_sinhf_finite))) +          return ConstantFoldFP(sinh, V, Ty); +        else if ((Name == "sqrt" && V >= 0 && TLI->has(LibFunc_sqrt)) || +                 (Name == "sqrtf" && V >= 0 && TLI->has(LibFunc_sqrtf))) +          return ConstantFoldFP(sqrt, V, Ty); +        break; +      case 't': +        if ((Name == "tan" && TLI->has(LibFunc_tan)) || +            (Name == "tanf" && TLI->has(LibFunc_tanf))) +          return ConstantFoldFP(tan, V, Ty); +        else if ((Name == "tanh" && TLI->has(LibFunc_tanh)) || +                 (Name == "tanhf" && TLI->has(LibFunc_tanhf))) +          return ConstantFoldFP(tanh, V, Ty); +        break; +      default: +        break; +      } +      return nullptr; +    } + +    if (auto *Op = dyn_cast<ConstantInt>(Operands[0])) { +      switch (IntrinsicID) { +      case Intrinsic::bswap: +        return ConstantInt::get(Ty->getContext(), Op->getValue().byteSwap()); +      case Intrinsic::ctpop: +        return ConstantInt::get(Ty, Op->getValue().countPopulation()); +      case Intrinsic::bitreverse: +        return ConstantInt::get(Ty->getContext(), Op->getValue().reverseBits()); +      case Intrinsic::convert_from_fp16: { +        APFloat Val(APFloat::IEEEhalf(), Op->getValue()); + +        bool lost = false; +        APFloat::opStatus status = Val.convert( +            Ty->getFltSemantics(), APFloat::rmNearestTiesToEven, &lost); + +        // Conversion is always precise. 
+        (void)status; +        assert(status == APFloat::opOK && !lost && +               "Precision lost during fp16 constfolding"); + +        return ConstantFP::get(Ty->getContext(), Val); +      } +      default: +        return nullptr; +      } +    } + +    // Support ConstantVector in case we have an Undef in the top. +    if (isa<ConstantVector>(Operands[0]) || +        isa<ConstantDataVector>(Operands[0])) { +      auto *Op = cast<Constant>(Operands[0]); +      switch (IntrinsicID) { +      default: break; +      case Intrinsic::x86_sse_cvtss2si: +      case Intrinsic::x86_sse_cvtss2si64: +      case Intrinsic::x86_sse2_cvtsd2si: +      case Intrinsic::x86_sse2_cvtsd2si64: +        if (ConstantFP *FPOp = +                dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) +          return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), +                                             /*roundTowardZero=*/false, Ty); +        break; +      case Intrinsic::x86_sse_cvttss2si: +      case Intrinsic::x86_sse_cvttss2si64: +      case Intrinsic::x86_sse2_cvttsd2si: +      case Intrinsic::x86_sse2_cvttsd2si64: +        if (ConstantFP *FPOp = +                dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) +          return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), +                                             /*roundTowardZero=*/true, Ty); +        break; +      } +    } + +    return nullptr; +  } + +  if (Operands.size() == 2) { +    if (auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { +      if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) +        return nullptr; +      double Op1V = getValueAsDouble(Op1); + +      if (auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { +        if (Op2->getType() != Op1->getType()) +          return nullptr; + +        double Op2V = getValueAsDouble(Op2); +        if (IntrinsicID == Intrinsic::pow) { +          return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); +        } +        if (IntrinsicID == Intrinsic::copysign) { +          APFloat V1 = Op1->getValueAPF(); +          const APFloat &V2 = Op2->getValueAPF(); +          V1.copySign(V2); +          return ConstantFP::get(Ty->getContext(), V1); +        } + +        if (IntrinsicID == Intrinsic::minnum) { +          const APFloat &C1 = Op1->getValueAPF(); +          const APFloat &C2 = Op2->getValueAPF(); +          return ConstantFP::get(Ty->getContext(), minnum(C1, C2)); +        } + +        if (IntrinsicID == Intrinsic::maxnum) { +          const APFloat &C1 = Op1->getValueAPF(); +          const APFloat &C2 = Op2->getValueAPF(); +          return ConstantFP::get(Ty->getContext(), maxnum(C1, C2)); +        } + +        if (!TLI) +          return nullptr; +        if ((Name == "pow" && TLI->has(LibFunc_pow)) || +            (Name == "powf" && TLI->has(LibFunc_powf)) || +            (Name == "__pow_finite" && TLI->has(LibFunc_pow_finite)) || +            (Name == "__powf_finite" && TLI->has(LibFunc_powf_finite))) +          return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty); +        if ((Name == "fmod" && TLI->has(LibFunc_fmod)) || +            (Name == "fmodf" && TLI->has(LibFunc_fmodf))) +          return ConstantFoldBinaryFP(fmod, Op1V, Op2V, Ty); +        if ((Name == "atan2" && TLI->has(LibFunc_atan2)) || +            (Name == "atan2f" && TLI->has(LibFunc_atan2f)) || +            (Name == "__atan2_finite" && TLI->has(LibFunc_atan2_finite)) || +            (Name == "__atan2f_finite" && TLI->has(LibFunc_atan2f_finite))) +          return 
ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); +      } else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) { +        if (IntrinsicID == Intrinsic::powi && Ty->isHalfTy()) +          return ConstantFP::get(Ty->getContext(), +                                 APFloat((float)std::pow((float)Op1V, +                                                 (int)Op2C->getZExtValue()))); +        if (IntrinsicID == Intrinsic::powi && Ty->isFloatTy()) +          return ConstantFP::get(Ty->getContext(), +                                 APFloat((float)std::pow((float)Op1V, +                                                 (int)Op2C->getZExtValue()))); +        if (IntrinsicID == Intrinsic::powi && Ty->isDoubleTy()) +          return ConstantFP::get(Ty->getContext(), +                                 APFloat((double)std::pow((double)Op1V, +                                                   (int)Op2C->getZExtValue()))); +      } +      return nullptr; +    } + +    if (auto *Op1 = dyn_cast<ConstantInt>(Operands[0])) { +      if (auto *Op2 = dyn_cast<ConstantInt>(Operands[1])) { +        switch (IntrinsicID) { +        default: break; +        case Intrinsic::sadd_with_overflow: +        case Intrinsic::uadd_with_overflow: +        case Intrinsic::ssub_with_overflow: +        case Intrinsic::usub_with_overflow: +        case Intrinsic::smul_with_overflow: +        case Intrinsic::umul_with_overflow: { +          APInt Res; +          bool Overflow; +          switch (IntrinsicID) { +          default: llvm_unreachable("Invalid case"); +          case Intrinsic::sadd_with_overflow: +            Res = Op1->getValue().sadd_ov(Op2->getValue(), Overflow); +            break; +          case Intrinsic::uadd_with_overflow: +            Res = Op1->getValue().uadd_ov(Op2->getValue(), Overflow); +            break; +          case Intrinsic::ssub_with_overflow: +            Res = Op1->getValue().ssub_ov(Op2->getValue(), Overflow); +            break; +          case Intrinsic::usub_with_overflow: +            Res = Op1->getValue().usub_ov(Op2->getValue(), Overflow); +            break; +          case Intrinsic::smul_with_overflow: +            Res = Op1->getValue().smul_ov(Op2->getValue(), Overflow); +            break; +          case Intrinsic::umul_with_overflow: +            Res = Op1->getValue().umul_ov(Op2->getValue(), Overflow); +            break; +          } +          Constant *Ops[] = { +            ConstantInt::get(Ty->getContext(), Res), +            ConstantInt::get(Type::getInt1Ty(Ty->getContext()), Overflow) +          }; +          return ConstantStruct::get(cast<StructType>(Ty), Ops); +        } +        case Intrinsic::cttz: +          if (Op2->isOne() && Op1->isZero()) // cttz(0, 1) is undef. +            return UndefValue::get(Ty); +          return ConstantInt::get(Ty, Op1->getValue().countTrailingZeros()); +        case Intrinsic::ctlz: +          if (Op2->isOne() && Op1->isZero()) // ctlz(0, 1) is undef. 
+            return UndefValue::get(Ty); +          return ConstantInt::get(Ty, Op1->getValue().countLeadingZeros()); +        } +      } + +      return nullptr; +    } +    return nullptr; +  } + +  if (Operands.size() != 3) +    return nullptr; + +  if (const auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { +    if (const auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { +      if (const auto *Op3 = dyn_cast<ConstantFP>(Operands[2])) { +        switch (IntrinsicID) { +        default: break; +        case Intrinsic::fma: +        case Intrinsic::fmuladd: { +          APFloat V = Op1->getValueAPF(); +          APFloat::opStatus s = V.fusedMultiplyAdd(Op2->getValueAPF(), +                                                   Op3->getValueAPF(), +                                                   APFloat::rmNearestTiesToEven); +          if (s != APFloat::opInvalidOp) +            return ConstantFP::get(Ty->getContext(), V); + +          return nullptr; +        } +        } +      } +    } +  } + +  return nullptr; +} + +Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID, +                                 VectorType *VTy, ArrayRef<Constant *> Operands, +                                 const DataLayout &DL, +                                 const TargetLibraryInfo *TLI, +                                 ImmutableCallSite CS) { +  SmallVector<Constant *, 4> Result(VTy->getNumElements()); +  SmallVector<Constant *, 4> Lane(Operands.size()); +  Type *Ty = VTy->getElementType(); + +  if (IntrinsicID == Intrinsic::masked_load) { +    auto *SrcPtr = Operands[0]; +    auto *Mask = Operands[2]; +    auto *Passthru = Operands[3]; + +    Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL); + +    SmallVector<Constant *, 32> NewElements; +    for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) { +      auto *MaskElt = Mask->getAggregateElement(I); +      if (!MaskElt) +        break; +      auto *PassthruElt = Passthru->getAggregateElement(I); +      auto *VecElt = VecData ? VecData->getAggregateElement(I) : nullptr; +      if (isa<UndefValue>(MaskElt)) { +        if (PassthruElt) +          NewElements.push_back(PassthruElt); +        else if (VecElt) +          NewElements.push_back(VecElt); +        else +          return nullptr; +      } +      if (MaskElt->isNullValue()) { +        if (!PassthruElt) +          return nullptr; +        NewElements.push_back(PassthruElt); +      } else if (MaskElt->isOneValue()) { +        if (!VecElt) +          return nullptr; +        NewElements.push_back(VecElt); +      } else { +        return nullptr; +      } +    } +    if (NewElements.size() != VTy->getNumElements()) +      return nullptr; +    return ConstantVector::get(NewElements); +  } + +  for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) { +    // Gather a column of constants. +    for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) { +      // These intrinsics use a scalar type for their second argument. +      if (J == 1 && +          (IntrinsicID == Intrinsic::cttz || IntrinsicID == Intrinsic::ctlz || +           IntrinsicID == Intrinsic::powi)) { +        Lane[J] = Operands[J]; +        continue; +      } + +      Constant *Agg = Operands[J]->getAggregateElement(I); +      if (!Agg) +        return nullptr; + +      Lane[J] = Agg; +    } + +    // Use the regular scalar folding to simplify this column. 
+    Constant *Folded = ConstantFoldScalarCall(Name, IntrinsicID, Ty, Lane, TLI, CS); +    if (!Folded) +      return nullptr; +    Result[I] = Folded; +  } + +  return ConstantVector::get(Result); +} + +} // end anonymous namespace + +Constant * +llvm::ConstantFoldCall(ImmutableCallSite CS, Function *F, +                       ArrayRef<Constant *> Operands, +                       const TargetLibraryInfo *TLI) { +  if (CS.isNoBuiltin() || CS.isStrictFP()) +    return nullptr; +  if (!F->hasName()) +    return nullptr; +  StringRef Name = F->getName(); + +  Type *Ty = F->getReturnType(); + +  if (auto *VTy = dyn_cast<VectorType>(Ty)) +    return ConstantFoldVectorCall(Name, F->getIntrinsicID(), VTy, Operands, +                                  F->getParent()->getDataLayout(), TLI, CS); + +  return ConstantFoldScalarCall(Name, F->getIntrinsicID(), Ty, Operands, TLI, CS); +} + +bool llvm::isMathLibCallNoop(CallSite CS, const TargetLibraryInfo *TLI) { +  // FIXME: Refactor this code; this duplicates logic in LibCallsShrinkWrap +  // (and to some extent ConstantFoldScalarCall). +  if (CS.isNoBuiltin() || CS.isStrictFP()) +    return false; +  Function *F = CS.getCalledFunction(); +  if (!F) +    return false; + +  LibFunc Func; +  if (!TLI || !TLI->getLibFunc(*F, Func)) +    return false; + +  if (CS.getNumArgOperands() == 1) { +    if (ConstantFP *OpC = dyn_cast<ConstantFP>(CS.getArgOperand(0))) { +      const APFloat &Op = OpC->getValueAPF(); +      switch (Func) { +      case LibFunc_logl: +      case LibFunc_log: +      case LibFunc_logf: +      case LibFunc_log2l: +      case LibFunc_log2: +      case LibFunc_log2f: +      case LibFunc_log10l: +      case LibFunc_log10: +      case LibFunc_log10f: +        return Op.isNaN() || (!Op.isZero() && !Op.isNegative()); + +      case LibFunc_expl: +      case LibFunc_exp: +      case LibFunc_expf: +        // FIXME: These boundaries are slightly conservative. +        if (OpC->getType()->isDoubleTy()) +          return Op.compare(APFloat(-745.0)) != APFloat::cmpLessThan && +                 Op.compare(APFloat(709.0)) != APFloat::cmpGreaterThan; +        if (OpC->getType()->isFloatTy()) +          return Op.compare(APFloat(-103.0f)) != APFloat::cmpLessThan && +                 Op.compare(APFloat(88.0f)) != APFloat::cmpGreaterThan; +        break; + +      case LibFunc_exp2l: +      case LibFunc_exp2: +      case LibFunc_exp2f: +        // FIXME: These boundaries are slightly conservative. +        if (OpC->getType()->isDoubleTy()) +          return Op.compare(APFloat(-1074.0)) != APFloat::cmpLessThan && +                 Op.compare(APFloat(1023.0)) != APFloat::cmpGreaterThan; +        if (OpC->getType()->isFloatTy()) +          return Op.compare(APFloat(-149.0f)) != APFloat::cmpLessThan && +                 Op.compare(APFloat(127.0f)) != APFloat::cmpGreaterThan; +        break; + +      case LibFunc_sinl: +      case LibFunc_sin: +      case LibFunc_sinf: +      case LibFunc_cosl: +      case LibFunc_cos: +      case LibFunc_cosf: +        return !Op.isInfinity(); + +      case LibFunc_tanl: +      case LibFunc_tan: +      case LibFunc_tanf: { +        // FIXME: Stop using the host math library. +        // FIXME: The computation isn't done in the right precision. 
+        Type *Ty = OpC->getType(); +        if (Ty->isDoubleTy() || Ty->isFloatTy() || Ty->isHalfTy()) { +          double OpV = getValueAsDouble(OpC); +          return ConstantFoldFP(tan, OpV, Ty) != nullptr; +        } +        break; +      } + +      case LibFunc_asinl: +      case LibFunc_asin: +      case LibFunc_asinf: +      case LibFunc_acosl: +      case LibFunc_acos: +      case LibFunc_acosf: +        return Op.compare(APFloat(Op.getSemantics(), "-1")) != +                   APFloat::cmpLessThan && +               Op.compare(APFloat(Op.getSemantics(), "1")) != +                   APFloat::cmpGreaterThan; + +      case LibFunc_sinh: +      case LibFunc_cosh: +      case LibFunc_sinhf: +      case LibFunc_coshf: +      case LibFunc_sinhl: +      case LibFunc_coshl: +        // FIXME: These boundaries are slightly conservative. +        if (OpC->getType()->isDoubleTy()) +          return Op.compare(APFloat(-710.0)) != APFloat::cmpLessThan && +                 Op.compare(APFloat(710.0)) != APFloat::cmpGreaterThan; +        if (OpC->getType()->isFloatTy()) +          return Op.compare(APFloat(-89.0f)) != APFloat::cmpLessThan && +                 Op.compare(APFloat(89.0f)) != APFloat::cmpGreaterThan; +        break; + +      case LibFunc_sqrtl: +      case LibFunc_sqrt: +      case LibFunc_sqrtf: +        return Op.isNaN() || Op.isZero() || !Op.isNegative(); + +      // FIXME: Add more functions: sqrt_finite, atanh, expm1, log1p, +      // maybe others? +      default: +        break; +      } +    } +  } + +  if (CS.getNumArgOperands() == 2) { +    ConstantFP *Op0C = dyn_cast<ConstantFP>(CS.getArgOperand(0)); +    ConstantFP *Op1C = dyn_cast<ConstantFP>(CS.getArgOperand(1)); +    if (Op0C && Op1C) { +      const APFloat &Op0 = Op0C->getValueAPF(); +      const APFloat &Op1 = Op1C->getValueAPF(); + +      switch (Func) { +      case LibFunc_powl: +      case LibFunc_pow: +      case LibFunc_powf: { +        // FIXME: Stop using the host math library. +        // FIXME: The computation isn't done in the right precision. +        Type *Ty = Op0C->getType(); +        if (Ty->isDoubleTy() || Ty->isFloatTy() || Ty->isHalfTy()) { +          if (Ty == Op1C->getType()) { +            double Op0V = getValueAsDouble(Op0C); +            double Op1V = getValueAsDouble(Op1C); +            return ConstantFoldBinaryFP(pow, Op0V, Op1V, Ty) != nullptr; +          } +        } +        break; +      } + +      case LibFunc_fmodl: +      case LibFunc_fmod: +      case LibFunc_fmodf: +        return Op0.isNaN() || Op1.isNaN() || +               (!Op0.isInfinity() && !Op1.isZero()); + +      default: +        break; +      } +    } +  } + +  return false; +} diff --git a/contrib/llvm/lib/Analysis/CostModel.cpp b/contrib/llvm/lib/Analysis/CostModel.cpp new file mode 100644 index 000000000000..3d55bf20bb40 --- /dev/null +++ b/contrib/llvm/lib/Analysis/CostModel.cpp @@ -0,0 +1,112 @@ +//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the cost model analysis. It provides a very basic cost +// estimation for LLVM-IR. This analysis uses the services of the codegen +// to approximate the cost of any IR instruction when lowered to machine +// instructions. 
The cost results are unit-less and the cost number represents +// the throughput of the machine assuming that all loads hit the cache, all +// branches are predicted, etc. The cost numbers can be added in order to +// compare two or more transformation alternatives. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +static cl::opt<TargetTransformInfo::TargetCostKind> CostKind( +    "cost-kind", cl::desc("Target cost kind"), +    cl::init(TargetTransformInfo::TCK_RecipThroughput), +    cl::values(clEnumValN(TargetTransformInfo::TCK_RecipThroughput, +                          "throughput", "Reciprocal throughput"), +               clEnumValN(TargetTransformInfo::TCK_Latency, +                          "latency", "Instruction latency"), +               clEnumValN(TargetTransformInfo::TCK_CodeSize, +                          "code-size", "Code size"))); + +#define CM_NAME "cost-model" +#define DEBUG_TYPE CM_NAME + +namespace { +  class CostModelAnalysis : public FunctionPass { + +  public: +    static char ID; // Class identification, replacement for typeinfo +    CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) { +      initializeCostModelAnalysisPass( +        *PassRegistry::getPassRegistry()); +    } + +    /// Returns the expected cost of the instruction. +    /// Returns -1 if the cost is unknown. +    /// Note, this method does not cache the cost calculation and it +    /// can be expensive in some cases. +    unsigned getInstructionCost(const Instruction *I) const { +      return TTI->getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput); +    } + +  private: +    void getAnalysisUsage(AnalysisUsage &AU) const override; +    bool runOnFunction(Function &F) override; +    void print(raw_ostream &OS, const Module*) const override; + +    /// The function that we analyze. +    Function *F; +    /// Target information. +    const TargetTransformInfo *TTI; +  }; +}  // End of anonymous namespace + +// Register this pass. +char CostModelAnalysis::ID = 0; +static const char cm_name[] = "Cost Model Analysis"; +INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true) +INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true) + +FunctionPass *llvm::createCostModelAnalysisPass() { +  return new CostModelAnalysis(); +} + +void +CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +} + +bool +CostModelAnalysis::runOnFunction(Function &F) { + this->F = &F; + auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); + TTI = TTIWP ? 
&TTIWP->getTTI(F) : nullptr; + + return false; +} + +void CostModelAnalysis::print(raw_ostream &OS, const Module*) const { +  if (!F) +    return; + +  for (BasicBlock &B : *F) { +    for (Instruction &Inst : B) { +      unsigned Cost = TTI->getInstructionCost(&Inst, CostKind); +      if (Cost != (unsigned)-1) +        OS << "Cost Model: Found an estimated cost of " << Cost; +      else +        OS << "Cost Model: Unknown cost"; + +      OS << " for instruction: " << Inst << "\n"; +    } +  } +} diff --git a/contrib/llvm/lib/Analysis/Delinearization.cpp b/contrib/llvm/lib/Analysis/Delinearization.cpp new file mode 100644 index 000000000000..4cafb7da16d3 --- /dev/null +++ b/contrib/llvm/lib/Analysis/Delinearization.cpp @@ -0,0 +1,130 @@ +//===---- Delinearization.cpp - MultiDimensional Index Delinearization ----===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This implements an analysis pass that tries to delinearize all GEP +// instructions in all loops using the SCEV analysis functionality. This pass is +// only used for testing purposes: if your pass needs delinearization, please +// use the on-demand SCEVAddRecExpr::delinearize() function. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DL_NAME "delinearize" +#define DEBUG_TYPE DL_NAME + +namespace { + +class Delinearization : public FunctionPass { +  Delinearization(const Delinearization &); // do not implement +protected: +  Function *F; +  LoopInfo *LI; +  ScalarEvolution *SE; + +public: +  static char ID; // Pass identification, replacement for typeid + +  Delinearization() : FunctionPass(ID) { +    initializeDelinearizationPass(*PassRegistry::getPassRegistry()); +  } +  bool runOnFunction(Function &F) override; +  void getAnalysisUsage(AnalysisUsage &AU) const override; +  void print(raw_ostream &O, const Module *M = nullptr) const override; +}; + +} // end anonymous namespace + +void Delinearization::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.addRequired<ScalarEvolutionWrapperPass>(); +} + +bool Delinearization::runOnFunction(Function &F) { +  this->F = &F; +  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); +  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  return false; +} + +void Delinearization::print(raw_ostream &O, const Module *) const { +  O << "Delinearization on function " << F->getName() << ":\n"; +  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { +    Instruction *Inst = &(*I); + +    // Only analyze loads and stores. 
+    if (!isa<StoreInst>(Inst) && !isa<LoadInst>(Inst) && +        !isa<GetElementPtrInst>(Inst)) +      continue; + +    const BasicBlock *BB = Inst->getParent(); +    // Delinearize the memory access as analyzed in all the surrounding loops. +    // Do not analyze memory accesses outside loops. +    for (Loop *L = LI->getLoopFor(BB); L != nullptr; L = L->getParentLoop()) { +      const SCEV *AccessFn = SE->getSCEVAtScope(getPointerOperand(Inst), L); + +      const SCEVUnknown *BasePointer = +          dyn_cast<SCEVUnknown>(SE->getPointerBase(AccessFn)); +      // Do not delinearize if we cannot find the base pointer. +      if (!BasePointer) +        break; +      AccessFn = SE->getMinusSCEV(AccessFn, BasePointer); + +      O << "\n"; +      O << "Inst:" << *Inst << "\n"; +      O << "In Loop with Header: " << L->getHeader()->getName() << "\n"; +      O << "AccessFunction: " << *AccessFn << "\n"; + +      SmallVector<const SCEV *, 3> Subscripts, Sizes; +      SE->delinearize(AccessFn, Subscripts, Sizes, SE->getElementSize(Inst)); +      if (Subscripts.size() == 0 || Sizes.size() == 0 || +          Subscripts.size() != Sizes.size()) { +        O << "failed to delinearize\n"; +        continue; +      } + +      O << "Base offset: " << *BasePointer << "\n"; +      O << "ArrayDecl[UnknownSize]"; +      int Size = Subscripts.size(); +      for (int i = 0; i < Size - 1; i++) +        O << "[" << *Sizes[i] << "]"; +      O << " with elements of " << *Sizes[Size - 1] << " bytes.\n"; + +      O << "ArrayRef"; +      for (int i = 0; i < Size; i++) +        O << "[" << *Subscripts[i] << "]"; +      O << "\n"; +    } +  } +} + +char Delinearization::ID = 0; +static const char delinearization_name[] = "Delinearization"; +INITIALIZE_PASS_BEGIN(Delinearization, DL_NAME, delinearization_name, true, +                      true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(Delinearization, DL_NAME, delinearization_name, true, true) + +FunctionPass *llvm::createDelinearizationPass() { return new Delinearization; } diff --git a/contrib/llvm/lib/Analysis/DemandedBits.cpp b/contrib/llvm/lib/Analysis/DemandedBits.cpp new file mode 100644 index 000000000000..e7637cd88327 --- /dev/null +++ b/contrib/llvm/lib/Analysis/DemandedBits.cpp @@ -0,0 +1,410 @@ +//===- DemandedBits.cpp - Determine demanded bits -------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements a demanded bits analysis. A demanded bit is one that +// contributes to a result; bits that are not demanded can be either zero or +// one without affecting control or data flow. For example in this sequence: +// +//   %1 = add i32 %x, %y +//   %2 = trunc i32 %1 to i16 +// +// Only the lowest 16 bits of %1 are demanded; the rest are removed by the +// trunc. 
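// A minimal usage sketch for the example above (illustrative, not taken from
// this file): it assumes the caller already has the Function containing the
// %1/%2 snippet plus an AssumptionCache and DominatorTree for it.
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"

llvm::APInt demandedBitsOf(llvm::Function &F, llvm::Instruction &AddInst,
                           llvm::AssumptionCache &AC, llvm::DominatorTree &DT) {
  llvm::DemandedBits DB(F, AC, DT);
  // For %1 above (an i32 add whose only use is a trunc to i16), this yields a
  // 32-bit APInt with just the low 16 bits set.
  return DB.getDemandedBits(&AddInst);
}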
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cstdint> + +using namespace llvm; + +#define DEBUG_TYPE "demanded-bits" + +char DemandedBitsWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits", +                      "Demanded bits analysis", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits", +                    "Demanded bits analysis", false, false) + +DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) { +  initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesCFG(); +  AU.addRequired<AssumptionCacheTracker>(); +  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.setPreservesAll(); +} + +void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const { +  DB->print(OS); +} + +static bool isAlwaysLive(Instruction *I) { +  return isa<TerminatorInst>(I) || isa<DbgInfoIntrinsic>(I) || +      I->isEHPad() || I->mayHaveSideEffects(); +} + +void DemandedBits::determineLiveOperandBits( +    const Instruction *UserI, const Instruction *I, unsigned OperandNo, +    const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2) { +  unsigned BitWidth = AB.getBitWidth(); + +  // We're called once per operand, but for some instructions, we need to +  // compute known bits of both operands in order to determine the live bits of +  // either (when both operands are instructions themselves). We don't, +  // however, want to do this twice, so we cache the result in APInts that live +  // in the caller. For the two-relevant-operands case, both operand values are +  // provided here. 
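// For intuition, a tiny illustrative sketch of what a KnownBits value carries
// (the 0xFF mask and the set low bit are made-up examples, not taken from this
// pass): set bits in Zero mark bits known to be 0, set bits in One mark bits
// known to be 1.
#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"

llvm::KnownBits exampleKnownBits() {
  llvm::KnownBits KB(32);
  // Model a 32-bit value of the form (X & 0xFF) | 1: the top 24 bits are
  // known to be zero and bit 0 is known to be one.
  KB.Zero = llvm::APInt::getHighBitsSet(32, 24);
  KB.One = llvm::APInt(32, 1);
  return KB;
}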
+  auto ComputeKnownBits = +      [&](unsigned BitWidth, const Value *V1, const Value *V2) { +        const DataLayout &DL = I->getModule()->getDataLayout(); +        Known = KnownBits(BitWidth); +        computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT); + +        if (V2) { +          Known2 = KnownBits(BitWidth); +          computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT); +        } +      }; + +  switch (UserI->getOpcode()) { +  default: break; +  case Instruction::Call: +  case Instruction::Invoke: +    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI)) +      switch (II->getIntrinsicID()) { +      default: break; +      case Intrinsic::bswap: +        // The alive bits of the input are the swapped alive bits of +        // the output. +        AB = AOut.byteSwap(); +        break; +      case Intrinsic::bitreverse: +        // The alive bits of the input are the reversed alive bits of +        // the output. +        AB = AOut.reverseBits(); +        break; +      case Intrinsic::ctlz: +        if (OperandNo == 0) { +          // We need some output bits, so we need all bits of the +          // input to the left of, and including, the leftmost bit +          // known to be one. +          ComputeKnownBits(BitWidth, I, nullptr); +          AB = APInt::getHighBitsSet(BitWidth, +                 std::min(BitWidth, Known.countMaxLeadingZeros()+1)); +        } +        break; +      case Intrinsic::cttz: +        if (OperandNo == 0) { +          // We need some output bits, so we need all bits of the +          // input to the right of, and including, the rightmost bit +          // known to be one. +          ComputeKnownBits(BitWidth, I, nullptr); +          AB = APInt::getLowBitsSet(BitWidth, +                 std::min(BitWidth, Known.countMaxTrailingZeros()+1)); +        } +        break; +      } +    break; +  case Instruction::Add: +  case Instruction::Sub: +  case Instruction::Mul: +    // Find the highest live output bit. We don't need any more input +    // bits than that (adds, and thus subtracts, ripple only to the +    // left). +    AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits()); +    break; +  case Instruction::Shl: +    if (OperandNo == 0) +      if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) { +        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); +        AB = AOut.lshr(ShiftAmt); + +        // If the shift is nuw/nsw, then the high bits are not dead +        // (because we've promised that they *must* be zero). +        const ShlOperator *S = cast<ShlOperator>(UserI); +        if (S->hasNoSignedWrap()) +          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1); +        else if (S->hasNoUnsignedWrap()) +          AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt); +      } +    break; +  case Instruction::LShr: +    if (OperandNo == 0) +      if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) { +        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); +        AB = AOut.shl(ShiftAmt); + +        // If the shift is exact, then the low bits are not dead +        // (they must be zero). 
+        if (cast<LShrOperator>(UserI)->isExact()) +          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt); +      } +    break; +  case Instruction::AShr: +    if (OperandNo == 0) +      if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) { +        uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); +        AB = AOut.shl(ShiftAmt); +        // Because the high input bit is replicated into the +        // high-order bits of the result, if we need any of those +        // bits, then we must keep the highest input bit. +        if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt)) +            .getBoolValue()) +          AB.setSignBit(); + +        // If the shift is exact, then the low bits are not dead +        // (they must be zero). +        if (cast<AShrOperator>(UserI)->isExact()) +          AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt); +      } +    break; +  case Instruction::And: +    AB = AOut; + +    // For bits that are known zero, the corresponding bits in the +    // other operand are dead (unless they're both zero, in which +    // case they can't both be dead, so just mark the LHS bits as +    // dead). +    if (OperandNo == 0) { +      ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); +      AB &= ~Known2.Zero; +    } else { +      if (!isa<Instruction>(UserI->getOperand(0))) +        ComputeKnownBits(BitWidth, UserI->getOperand(0), I); +      AB &= ~(Known.Zero & ~Known2.Zero); +    } +    break; +  case Instruction::Or: +    AB = AOut; + +    // For bits that are known one, the corresponding bits in the +    // other operand are dead (unless they're both one, in which +    // case they can't both be dead, so just mark the LHS bits as +    // dead). +    if (OperandNo == 0) { +      ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); +      AB &= ~Known2.One; +    } else { +      if (!isa<Instruction>(UserI->getOperand(0))) +        ComputeKnownBits(BitWidth, UserI->getOperand(0), I); +      AB &= ~(Known.One & ~Known2.One); +    } +    break; +  case Instruction::Xor: +  case Instruction::PHI: +    AB = AOut; +    break; +  case Instruction::Trunc: +    AB = AOut.zext(BitWidth); +    break; +  case Instruction::ZExt: +    AB = AOut.trunc(BitWidth); +    break; +  case Instruction::SExt: +    AB = AOut.trunc(BitWidth); +    // Because the high input bit is replicated into the +    // high-order bits of the result, if we need any of those +    // bits, then we must keep the highest input bit. +    if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(), +                                      AOut.getBitWidth() - BitWidth)) +        .getBoolValue()) +      AB.setSignBit(); +    break; +  case Instruction::Select: +    if (OperandNo != 0) +      AB = AOut; +    break; +  } +} + +bool DemandedBitsWrapperPass::runOnFunction(Function &F) { +  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  DB.emplace(F, AC, DT); +  return false; +} + +void DemandedBitsWrapperPass::releaseMemory() { +  DB.reset(); +} + +void DemandedBits::performAnalysis() { +  if (Analyzed) +    // Analysis already completed for this function. +    return; +  Analyzed = true; + +  Visited.clear(); +  AliveBits.clear(); + +  SmallVector<Instruction*, 128> Worklist; + +  // Collect the set of "root" instructions that are known live. 
+  for (Instruction &I : instructions(F)) {
+    if (!isAlwaysLive(&I))
+      continue;
+
+    LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
+    // For integer-valued instructions, set up an initial empty set of alive
+    // bits and add the instruction to the work list. For other instructions
+    // add their operands to the work list (for integer-valued operands, mark
+    // all bits as live).
+    if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
+      if (AliveBits.try_emplace(&I, IT->getBitWidth(), 0).second)
+        Worklist.push_back(&I);
+
+      continue;
+    }
+
+    // Non-integer-typed instructions...
+    for (Use &OI : I.operands()) {
+      if (Instruction *J = dyn_cast<Instruction>(OI)) {
+        if (IntegerType *IT = dyn_cast<IntegerType>(J->getType()))
+          AliveBits[J] = APInt::getAllOnesValue(IT->getBitWidth());
+        Worklist.push_back(J);
+      }
+    }
+    // To save memory, we don't add I to the Visited set here. Instead, we
+    // check isAlwaysLive on every instruction when searching for dead
+    // instructions later (we need to check isAlwaysLive for the
+    // integer-typed instructions anyway).
+  }
+
+  // Propagate liveness backwards to operands.
+  while (!Worklist.empty()) {
+    Instruction *UserI = Worklist.pop_back_val();
+
+    LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
+    APInt AOut;
+    if (UserI->getType()->isIntegerTy()) {
+      AOut = AliveBits[UserI];
+      LLVM_DEBUG(dbgs() << " Alive Out: " << AOut);
+    }
+    LLVM_DEBUG(dbgs() << "\n");
+
+    if (!UserI->getType()->isIntegerTy())
+      Visited.insert(UserI);
+
+    KnownBits Known, Known2;
+    // Compute the set of alive bits for each operand. These are anded into the
+    // existing set, if any, and if that changes the set of alive bits, the
+    // operand is added to the work-list.
+    for (Use &OI : UserI->operands()) {
+      if (Instruction *I = dyn_cast<Instruction>(OI)) {
+        if (IntegerType *IT = dyn_cast<IntegerType>(I->getType())) {
+          unsigned BitWidth = IT->getBitWidth();
+          APInt AB = APInt::getAllOnesValue(BitWidth);
+          if (UserI->getType()->isIntegerTy() && !AOut &&
+              !isAlwaysLive(UserI)) {
+            // If all bits of the output are dead, then all bits of the input
+            // are dead as well.
+            AB = APInt(BitWidth, 0);
+          } else {
+            // Bits of each operand that are used to compute alive bits of the
+            // output are alive, all others are dead.
+            determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB,
+                                     Known, Known2);
+          }
+
+          // If we've added to the set of alive bits (or the operand has not
+          // been previously visited), then re-queue the operand to be visited
+          // again.
+          APInt ABPrev(BitWidth, 0); +          auto ABI = AliveBits.find(I); +          if (ABI != AliveBits.end()) +            ABPrev = ABI->second; + +          APInt ABNew = AB | ABPrev; +          if (ABNew != ABPrev || ABI == AliveBits.end()) { +            AliveBits[I] = std::move(ABNew); +            Worklist.push_back(I); +          } +        } else if (!Visited.count(I)) { +          Worklist.push_back(I); +        } +      } +    } +  } +} + +APInt DemandedBits::getDemandedBits(Instruction *I) { +  performAnalysis(); + +  const DataLayout &DL = I->getModule()->getDataLayout(); +  auto Found = AliveBits.find(I); +  if (Found != AliveBits.end()) +    return Found->second; +  return APInt::getAllOnesValue(DL.getTypeSizeInBits(I->getType())); +} + +bool DemandedBits::isInstructionDead(Instruction *I) { +  performAnalysis(); + +  return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() && +    !isAlwaysLive(I); +} + +void DemandedBits::print(raw_ostream &OS) { +  performAnalysis(); +  for (auto &KV : AliveBits) { +    OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue()) +       << " for " << *KV.first << '\n'; +  } +} + +FunctionPass *llvm::createDemandedBitsWrapperPass() { +  return new DemandedBitsWrapperPass(); +} + +AnalysisKey DemandedBitsAnalysis::Key; + +DemandedBits DemandedBitsAnalysis::run(Function &F, +                                             FunctionAnalysisManager &AM) { +  auto &AC = AM.getResult<AssumptionAnalysis>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  return DemandedBits(F, AC, DT); +} + +PreservedAnalyses DemandedBitsPrinterPass::run(Function &F, +                                               FunctionAnalysisManager &AM) { +  AM.getResult<DemandedBitsAnalysis>(F).print(OS); +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp b/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp new file mode 100644 index 000000000000..79c2728d5620 --- /dev/null +++ b/contrib/llvm/lib/Analysis/DependenceAnalysis.cpp @@ -0,0 +1,3981 @@ +//===-- DependenceAnalysis.cpp - DA Implementation --------------*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// DependenceAnalysis is an LLVM pass that analyses dependences between memory +// accesses. Currently, it is an (incomplete) implementation of the approach +// described in +// +//            Practical Dependence Testing +//            Goff, Kennedy, Tseng +//            PLDI 1991 +// +// There's a single entry point that analyzes the dependence between a pair +// of memory references in a function, returning either NULL, for no dependence, +// or a more-or-less detailed description of the dependence between them. +// +// Currently, the implementation cannot propagate constraints between +// coupled RDIV subscripts and lacks a multi-subscript MIV test. +// Both of these are conservative weaknesses; +// that is, not a source of correctness problems. +// +// Since Clang linearizes some array subscripts, the dependence +// analysis is using SCEV->delinearize to recover the representation of multiple +// subscripts, and thus avoid the more expensive and less precise MIV tests. The +// delinearization is controlled by the flag -da-delinearize. 
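// A minimal sketch of the single entry point described above (illustrative;
// assumes DI was built for the function containing Src and Dst, both of which
// access memory):
#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/IR/Instruction.h"

bool provablyIndependent(llvm::DependenceInfo &DI, llvm::Instruction *Src,
                         llvm::Instruction *Dst) {
  auto D = DI.depends(Src, Dst, /*PossiblyLoopIndependent=*/true);
  // A null result means no dependence; otherwise D describes the dependence,
  // possibly conservatively (Dependence::dump below shows the printed form).
  return D == nullptr;
}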
+// +// We should pay some careful attention to the possibility of integer overflow +// in the implementation of the various tests. This could happen with Add, +// Subtract, or Multiply, with both APInt's and SCEV's. +// +// Some non-linear subscript pairs can be handled by the GCD test +// (and perhaps other tests). +// Should explore how often these things occur. +// +// Finally, it seems like certain test cases expose weaknesses in the SCEV +// simplification, especially in the handling of sign and zero extensions. +// It could be useful to spend time exploring these. +// +// Please note that this is work in progress and the interface is subject to +// change. +// +//===----------------------------------------------------------------------===// +//                                                                            // +//                   In memory of Ken Kennedy, 1945 - 2007                    // +//                                                                            // +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "da" + +//===----------------------------------------------------------------------===// +// statistics + +STATISTIC(TotalArrayPairs, "Array pairs tested"); +STATISTIC(SeparableSubscriptPairs, "Separable subscript pairs"); +STATISTIC(CoupledSubscriptPairs, "Coupled subscript pairs"); +STATISTIC(NonlinearSubscriptPairs, "Nonlinear subscript pairs"); +STATISTIC(ZIVapplications, "ZIV applications"); +STATISTIC(ZIVindependence, "ZIV independence"); +STATISTIC(StrongSIVapplications, "Strong SIV applications"); +STATISTIC(StrongSIVsuccesses, "Strong SIV successes"); +STATISTIC(StrongSIVindependence, "Strong SIV independence"); +STATISTIC(WeakCrossingSIVapplications, "Weak-Crossing SIV applications"); +STATISTIC(WeakCrossingSIVsuccesses, "Weak-Crossing SIV successes"); +STATISTIC(WeakCrossingSIVindependence, "Weak-Crossing SIV independence"); +STATISTIC(ExactSIVapplications, "Exact SIV applications"); +STATISTIC(ExactSIVsuccesses, "Exact SIV successes"); +STATISTIC(ExactSIVindependence, "Exact SIV independence"); +STATISTIC(WeakZeroSIVapplications, "Weak-Zero SIV applications"); +STATISTIC(WeakZeroSIVsuccesses, "Weak-Zero SIV successes"); +STATISTIC(WeakZeroSIVindependence, "Weak-Zero SIV independence"); +STATISTIC(ExactRDIVapplications, "Exact RDIV applications"); +STATISTIC(ExactRDIVindependence, "Exact RDIV independence"); +STATISTIC(SymbolicRDIVapplications, "Symbolic RDIV applications"); +STATISTIC(SymbolicRDIVindependence, "Symbolic RDIV independence"); +STATISTIC(DeltaApplications, "Delta applications"); +STATISTIC(DeltaSuccesses, "Delta successes"); +STATISTIC(DeltaIndependence, "Delta independence"); +STATISTIC(DeltaPropagations, "Delta propagations"); +STATISTIC(GCDapplications, "GCD applications"); +STATISTIC(GCDsuccesses, "GCD successes"); 
+STATISTIC(GCDindependence, "GCD independence"); +STATISTIC(BanerjeeApplications, "Banerjee applications"); +STATISTIC(BanerjeeIndependence, "Banerjee independence"); +STATISTIC(BanerjeeSuccesses, "Banerjee successes"); + +static cl::opt<bool> +    Delinearize("da-delinearize", cl::init(true), cl::Hidden, cl::ZeroOrMore, +                cl::desc("Try to delinearize array references.")); + +//===----------------------------------------------------------------------===// +// basics + +DependenceAnalysis::Result +DependenceAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { +  auto &AA = FAM.getResult<AAManager>(F); +  auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F); +  auto &LI = FAM.getResult<LoopAnalysis>(F); +  return DependenceInfo(&F, &AA, &SE, &LI); +} + +AnalysisKey DependenceAnalysis::Key; + +INITIALIZE_PASS_BEGIN(DependenceAnalysisWrapperPass, "da", +                      "Dependence Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(DependenceAnalysisWrapperPass, "da", "Dependence Analysis", +                    true, true) + +char DependenceAnalysisWrapperPass::ID = 0; + +FunctionPass *llvm::createDependenceAnalysisWrapperPass() { +  return new DependenceAnalysisWrapperPass(); +} + +bool DependenceAnalysisWrapperPass::runOnFunction(Function &F) { +  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +  auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); +  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  info.reset(new DependenceInfo(&F, &AA, &SE, &LI)); +  return false; +} + +DependenceInfo &DependenceAnalysisWrapperPass::getDI() const { return *info; } + +void DependenceAnalysisWrapperPass::releaseMemory() { info.reset(); } + +void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequiredTransitive<AAResultsWrapperPass>(); +  AU.addRequiredTransitive<ScalarEvolutionWrapperPass>(); +  AU.addRequiredTransitive<LoopInfoWrapperPass>(); +} + + +// Used to test the dependence analyzer. +// Looks through the function, noting loads and stores. +// Calls depends() on every possible pair and prints out the result. +// Ignores all other instructions. 
+static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA) {
+  auto *F = DA->getFunction();
+  for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE;
+       ++SrcI) {
+    if (isa<StoreInst>(*SrcI) || isa<LoadInst>(*SrcI)) {
+      for (inst_iterator DstI = SrcI, DstE = inst_end(F);
+           DstI != DstE; ++DstI) {
+        if (isa<StoreInst>(*DstI) || isa<LoadInst>(*DstI)) {
+          OS << "da analyze - ";
+          if (auto D = DA->depends(&*SrcI, &*DstI, true)) {
+            D->dump(OS);
+            for (unsigned Level = 1; Level <= D->getLevels(); Level++) {
+              if (D->isSplitable(Level)) {
+                OS << "da analyze - split level = " << Level;
+                OS << ", iteration = " << *DA->getSplitIteration(*D, Level);
+                OS << "!\n";
+              }
+            }
+          }
+          else
+            OS << "none!\n";
+        }
+      }
+    }
+  }
+}
+
+void DependenceAnalysisWrapperPass::print(raw_ostream &OS,
+                                          const Module *) const {
+  dumpExampleDependence(OS, info.get());
+}
+
+//===----------------------------------------------------------------------===//
+// Dependence methods
+
+// Returns true if this is an input dependence.
+bool Dependence::isInput() const {
+  return Src->mayReadFromMemory() && Dst->mayReadFromMemory();
+}
+
+
+// Returns true if this is an output dependence.
+bool Dependence::isOutput() const {
+  return Src->mayWriteToMemory() && Dst->mayWriteToMemory();
+}
+
+
+// Returns true if this is a flow (aka true) dependence.
+bool Dependence::isFlow() const {
+  return Src->mayWriteToMemory() && Dst->mayReadFromMemory();
+}
+
+
+// Returns true if this is an anti dependence.
+bool Dependence::isAnti() const {
+  return Src->mayReadFromMemory() && Dst->mayWriteToMemory();
+}
+
+
+// Returns true if a particular level is scalar; that is,
+// if no subscript in the source or destination mentions the induction
+// variable associated with the loop at this level.
+// Leave this out of line, so it will serve as a virtual method anchor.
+bool Dependence::isScalar(unsigned level) const {
+  return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// FullDependence methods
+
+FullDependence::FullDependence(Instruction *Source, Instruction *Destination,
+                               bool PossiblyLoopIndependent,
+                               unsigned CommonLevels)
+    : Dependence(Source, Destination), Levels(CommonLevels),
+      LoopIndependent(PossiblyLoopIndependent) {
+  Consistent = true;
+  if (CommonLevels)
+    DV = make_unique<DVEntry[]>(CommonLevels);
+}
+
+// The rest are simple getters that hide the implementation.
+
+// getDirection - Returns the direction associated with a particular level.
+unsigned FullDependence::getDirection(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].Direction;
+}
+
+
+// Returns the distance (or NULL) associated with a particular level.
+const SCEV *FullDependence::getDistance(unsigned Level) const {
+  assert(0 < Level && Level <= Levels && "Level out of range");
+  return DV[Level - 1].Distance;
+}
+
+
+// Returns true if a particular level is scalar; that is,
+// if no subscript in the source or destination mentions the induction
+// variable associated with the loop at this level.
+bool FullDependence::isScalar(unsigned Level) const { +  assert(0 < Level && Level <= Levels && "Level out of range"); +  return DV[Level - 1].Scalar; +} + + +// Returns true if peeling the first iteration from this loop +// will break this dependence. +bool FullDependence::isPeelFirst(unsigned Level) const { +  assert(0 < Level && Level <= Levels && "Level out of range"); +  return DV[Level - 1].PeelFirst; +} + + +// Returns true if peeling the last iteration from this loop +// will break this dependence. +bool FullDependence::isPeelLast(unsigned Level) const { +  assert(0 < Level && Level <= Levels && "Level out of range"); +  return DV[Level - 1].PeelLast; +} + + +// Returns true if splitting this loop will break the dependence. +bool FullDependence::isSplitable(unsigned Level) const { +  assert(0 < Level && Level <= Levels && "Level out of range"); +  return DV[Level - 1].Splitable; +} + + +//===----------------------------------------------------------------------===// +// DependenceInfo::Constraint methods + +// If constraint is a point <X, Y>, returns X. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getX() const { +  assert(Kind == Point && "Kind should be Point"); +  return A; +} + + +// If constraint is a point <X, Y>, returns Y. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getY() const { +  assert(Kind == Point && "Kind should be Point"); +  return B; +} + + +// If constraint is a line AX + BY = C, returns A. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getA() const { +  assert((Kind == Line || Kind == Distance) && +         "Kind should be Line (or Distance)"); +  return A; +} + + +// If constraint is a line AX + BY = C, returns B. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getB() const { +  assert((Kind == Line || Kind == Distance) && +         "Kind should be Line (or Distance)"); +  return B; +} + + +// If constraint is a line AX + BY = C, returns C. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getC() const { +  assert((Kind == Line || Kind == Distance) && +         "Kind should be Line (or Distance)"); +  return C; +} + + +// If constraint is a distance, returns D. +// Otherwise assert. +const SCEV *DependenceInfo::Constraint::getD() const { +  assert(Kind == Distance && "Kind should be Distance"); +  return SE->getNegativeSCEV(C); +} + + +// Returns the loop associated with this constraint. 
+const Loop *DependenceInfo::Constraint::getAssociatedLoop() const { +  assert((Kind == Distance || Kind == Line || Kind == Point) && +         "Kind should be Distance, Line, or Point"); +  return AssociatedLoop; +} + +void DependenceInfo::Constraint::setPoint(const SCEV *X, const SCEV *Y, +                                          const Loop *CurLoop) { +  Kind = Point; +  A = X; +  B = Y; +  AssociatedLoop = CurLoop; +} + +void DependenceInfo::Constraint::setLine(const SCEV *AA, const SCEV *BB, +                                         const SCEV *CC, const Loop *CurLoop) { +  Kind = Line; +  A = AA; +  B = BB; +  C = CC; +  AssociatedLoop = CurLoop; +} + +void DependenceInfo::Constraint::setDistance(const SCEV *D, +                                             const Loop *CurLoop) { +  Kind = Distance; +  A = SE->getOne(D->getType()); +  B = SE->getNegativeSCEV(A); +  C = SE->getNegativeSCEV(D); +  AssociatedLoop = CurLoop; +} + +void DependenceInfo::Constraint::setEmpty() { Kind = Empty; } + +void DependenceInfo::Constraint::setAny(ScalarEvolution *NewSE) { +  SE = NewSE; +  Kind = Any; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +// For debugging purposes. Dumps the constraint out to OS. +LLVM_DUMP_METHOD void DependenceInfo::Constraint::dump(raw_ostream &OS) const { +  if (isEmpty()) +    OS << " Empty\n"; +  else if (isAny()) +    OS << " Any\n"; +  else if (isPoint()) +    OS << " Point is <" << *getX() << ", " << *getY() << ">\n"; +  else if (isDistance()) +    OS << " Distance is " << *getD() << +      " (" << *getA() << "*X + " << *getB() << "*Y = " << *getC() << ")\n"; +  else if (isLine()) +    OS << " Line is " << *getA() << "*X + " << +      *getB() << "*Y = " << *getC() << "\n"; +  else +    llvm_unreachable("unknown constraint type in Constraint::dump"); +} +#endif + + +// Updates X with the intersection +// of the Constraints X and Y. Returns true if X has changed. +// Corresponds to Figure 4 from the paper +// +//            Practical Dependence Testing +//            Goff, Kennedy, Tseng +//            PLDI 1991 +bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) { +  ++DeltaApplications; +  LLVM_DEBUG(dbgs() << "\tintersect constraints\n"); +  LLVM_DEBUG(dbgs() << "\t    X ="; X->dump(dbgs())); +  LLVM_DEBUG(dbgs() << "\t    Y ="; Y->dump(dbgs())); +  assert(!Y->isPoint() && "Y must not be a Point"); +  if (X->isAny()) { +    if (Y->isAny()) +      return false; +    *X = *Y; +    return true; +  } +  if (X->isEmpty()) +    return false; +  if (Y->isEmpty()) { +    X->setEmpty(); +    return true; +  } + +  if (X->isDistance() && Y->isDistance()) { +    LLVM_DEBUG(dbgs() << "\t    intersect 2 distances\n"); +    if (isKnownPredicate(CmpInst::ICMP_EQ, X->getD(), Y->getD())) +      return false; +    if (isKnownPredicate(CmpInst::ICMP_NE, X->getD(), Y->getD())) { +      X->setEmpty(); +      ++DeltaSuccesses; +      return true; +    } +    // Hmmm, interesting situation. +    // I guess if either is constant, keep it and ignore the other. +    if (isa<SCEVConstant>(Y->getD())) { +      *X = *Y; +      return true; +    } +    return false; +  } + +  // At this point, the pseudo-code in Figure 4 of the paper +  // checks if (X->isPoint() && Y->isPoint()). +  // This case can't occur in our implementation, +  // since a Point can only arise as the result of intersecting +  // two Line constraints, and the right-hand value, Y, is never +  // the result of an intersection. 
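// A worked example (illustrative, with made-up coefficients) of the two-line
// intersection performed below using the same Cramer's-rule quotients the
// code builds with APInt: X + Y = 4 and X - Y = 2 meet at the integer point
// (3, 1).
#include <cassert>
#include <utility>

std::pair<int, int> intersectTwoLinesExample() {
  int A1 = 1, B1 = 1, C1 = 4;   // X + Y = 4
  int A2 = 1, B2 = -1, C2 = 2;  // X - Y = 2
  int Xtop = C1 * B2 - C2 * B1; // C1B2 - C2B1 = -6
  int Xbot = A1 * B2 - A2 * B1; // A1B2 - A2B1 = -2
  int Ytop = C1 * A2 - C2 * A1; // C1A2 - C2A1 =  2
  int Ybot = A2 * B1 - A1 * B2; // A2B1 - A1B2 =  2
  assert(Xtop % Xbot == 0 && Ytop % Ybot == 0 && "not an integer intersection");
  return {Xtop / Xbot, Ytop / Ybot}; // X = 3, Y = 1
}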
+  assert(!(X->isPoint() && Y->isPoint()) && +         "We shouldn't ever see X->isPoint() && Y->isPoint()"); + +  if (X->isLine() && Y->isLine()) { +    LLVM_DEBUG(dbgs() << "\t    intersect 2 lines\n"); +    const SCEV *Prod1 = SE->getMulExpr(X->getA(), Y->getB()); +    const SCEV *Prod2 = SE->getMulExpr(X->getB(), Y->getA()); +    if (isKnownPredicate(CmpInst::ICMP_EQ, Prod1, Prod2)) { +      // slopes are equal, so lines are parallel +      LLVM_DEBUG(dbgs() << "\t\tsame slope\n"); +      Prod1 = SE->getMulExpr(X->getC(), Y->getB()); +      Prod2 = SE->getMulExpr(X->getB(), Y->getC()); +      if (isKnownPredicate(CmpInst::ICMP_EQ, Prod1, Prod2)) +        return false; +      if (isKnownPredicate(CmpInst::ICMP_NE, Prod1, Prod2)) { +        X->setEmpty(); +        ++DeltaSuccesses; +        return true; +      } +      return false; +    } +    if (isKnownPredicate(CmpInst::ICMP_NE, Prod1, Prod2)) { +      // slopes differ, so lines intersect +      LLVM_DEBUG(dbgs() << "\t\tdifferent slopes\n"); +      const SCEV *C1B2 = SE->getMulExpr(X->getC(), Y->getB()); +      const SCEV *C1A2 = SE->getMulExpr(X->getC(), Y->getA()); +      const SCEV *C2B1 = SE->getMulExpr(Y->getC(), X->getB()); +      const SCEV *C2A1 = SE->getMulExpr(Y->getC(), X->getA()); +      const SCEV *A1B2 = SE->getMulExpr(X->getA(), Y->getB()); +      const SCEV *A2B1 = SE->getMulExpr(Y->getA(), X->getB()); +      const SCEVConstant *C1A2_C2A1 = +        dyn_cast<SCEVConstant>(SE->getMinusSCEV(C1A2, C2A1)); +      const SCEVConstant *C1B2_C2B1 = +        dyn_cast<SCEVConstant>(SE->getMinusSCEV(C1B2, C2B1)); +      const SCEVConstant *A1B2_A2B1 = +        dyn_cast<SCEVConstant>(SE->getMinusSCEV(A1B2, A2B1)); +      const SCEVConstant *A2B1_A1B2 = +        dyn_cast<SCEVConstant>(SE->getMinusSCEV(A2B1, A1B2)); +      if (!C1B2_C2B1 || !C1A2_C2A1 || +          !A1B2_A2B1 || !A2B1_A1B2) +        return false; +      APInt Xtop = C1B2_C2B1->getAPInt(); +      APInt Xbot = A1B2_A2B1->getAPInt(); +      APInt Ytop = C1A2_C2A1->getAPInt(); +      APInt Ybot = A2B1_A1B2->getAPInt(); +      LLVM_DEBUG(dbgs() << "\t\tXtop = " << Xtop << "\n"); +      LLVM_DEBUG(dbgs() << "\t\tXbot = " << Xbot << "\n"); +      LLVM_DEBUG(dbgs() << "\t\tYtop = " << Ytop << "\n"); +      LLVM_DEBUG(dbgs() << "\t\tYbot = " << Ybot << "\n"); +      APInt Xq = Xtop; // these need to be initialized, even +      APInt Xr = Xtop; // though they're just going to be overwritten +      APInt::sdivrem(Xtop, Xbot, Xq, Xr); +      APInt Yq = Ytop; +      APInt Yr = Ytop; +      APInt::sdivrem(Ytop, Ybot, Yq, Yr); +      if (Xr != 0 || Yr != 0) { +        X->setEmpty(); +        ++DeltaSuccesses; +        return true; +      } +      LLVM_DEBUG(dbgs() << "\t\tX = " << Xq << ", Y = " << Yq << "\n"); +      if (Xq.slt(0) || Yq.slt(0)) { +        X->setEmpty(); +        ++DeltaSuccesses; +        return true; +      } +      if (const SCEVConstant *CUB = +          collectConstantUpperBound(X->getAssociatedLoop(), Prod1->getType())) { +        const APInt &UpperBound = CUB->getAPInt(); +        LLVM_DEBUG(dbgs() << "\t\tupper bound = " << UpperBound << "\n"); +        if (Xq.sgt(UpperBound) || Yq.sgt(UpperBound)) { +          X->setEmpty(); +          ++DeltaSuccesses; +          return true; +        } +      } +      X->setPoint(SE->getConstant(Xq), +                  SE->getConstant(Yq), +                  X->getAssociatedLoop()); +      ++DeltaSuccesses; +      return true; +    } +    return false; +  } + +  // if (X->isLine() && Y->isPoint()) This case can't 
occur.
+  assert(!(X->isLine() && Y->isPoint()) && "This case should never occur");
+
+  if (X->isPoint() && Y->isLine()) {
+    LLVM_DEBUG(dbgs() << "\t    intersect Point and Line\n");
+    const SCEV *A1X1 = SE->getMulExpr(Y->getA(), X->getX());
+    const SCEV *B1Y1 = SE->getMulExpr(Y->getB(), X->getY());
+    const SCEV *Sum = SE->getAddExpr(A1X1, B1Y1);
+    if (isKnownPredicate(CmpInst::ICMP_EQ, Sum, Y->getC()))
+      return false;
+    if (isKnownPredicate(CmpInst::ICMP_NE, Sum, Y->getC())) {
+      X->setEmpty();
+      ++DeltaSuccesses;
+      return true;
+    }
+    return false;
+  }
+
+  llvm_unreachable("shouldn't reach the end of Constraint intersection");
+  return false;
+}
+
+
+//===----------------------------------------------------------------------===//
+// DependenceInfo methods
+
+// For debugging purposes. Dumps a dependence to OS.
+void Dependence::dump(raw_ostream &OS) const {
+  bool Splitable = false;
+  if (isConfused())
+    OS << "confused";
+  else {
+    if (isConsistent())
+      OS << "consistent ";
+    if (isFlow())
+      OS << "flow";
+    else if (isOutput())
+      OS << "output";
+    else if (isAnti())
+      OS << "anti";
+    else if (isInput())
+      OS << "input";
+    unsigned Levels = getLevels();
+    OS << " [";
+    for (unsigned II = 1; II <= Levels; ++II) {
+      if (isSplitable(II))
+        Splitable = true;
+      if (isPeelFirst(II))
+        OS << 'p';
+      const SCEV *Distance = getDistance(II);
+      if (Distance)
+        OS << *Distance;
+      else if (isScalar(II))
+        OS << "S";
+      else {
+        unsigned Direction = getDirection(II);
+        if (Direction == DVEntry::ALL)
+          OS << "*";
+        else {
+          if (Direction & DVEntry::LT)
+            OS << "<";
+          if (Direction & DVEntry::EQ)
+            OS << "=";
+          if (Direction & DVEntry::GT)
+            OS << ">";
+        }
+      }
+      if (isPeelLast(II))
+        OS << 'p';
+      if (II < Levels)
+        OS << " ";
+    }
+    if (isLoopIndependent())
+      OS << "|<";
+    OS << "]";
+    if (Splitable)
+      OS << " splitable";
+  }
+  OS << "!\n";
+}
+
+// Returns NoAlias/MayAlias/MustAlias for two memory locations based upon their
+// underlying objects. If LocA and LocB are known to not alias (for any reason:
+// tbaa, non-overlapping regions etc), then it is known there is no dependency.
+// Otherwise the underlying objects are checked to see if they point to
+// different identifiable objects.
+static AliasResult underlyingObjectsAlias(AliasAnalysis *AA,
+                                          const DataLayout &DL,
+                                          const MemoryLocation &LocA,
+                                          const MemoryLocation &LocB) {
+  // Check the original locations (minus size) for noalias, which can happen for
+  // tbaa, incompatible underlying object locations, etc.
+  MemoryLocation LocAS(LocA.Ptr, MemoryLocation::UnknownSize, LocA.AATags);
+  MemoryLocation LocBS(LocB.Ptr, MemoryLocation::UnknownSize, LocB.AATags);
+  if (AA->alias(LocAS, LocBS) == NoAlias)
+    return NoAlias;
+
+  // Check the underlying objects are the same
+  const Value *AObj = GetUnderlyingObject(LocA.Ptr, DL);
+  const Value *BObj = GetUnderlyingObject(LocB.Ptr, DL);
+
+  // If the underlying objects are the same, they must alias
+  if (AObj == BObj)
+    return MustAlias;
+
+  // We may have hit the recursion limit for underlying objects, or have
+  // underlying objects where we don't know they will alias.
+  if (!isIdentifiedObject(AObj) || !isIdentifiedObject(BObj))
+    return MayAlias;
+
+  // Otherwise we know the objects are different and both identified objects so
+  // must not alias.
+  return NoAlias;
+}
+
+
+// Returns true if the load or store can be analyzed. Atomic and volatile
+// operations have properties which this analysis does not understand.
+static
+bool isLoadOrStore(const Instruction *I) {
+  if (const LoadInst *LI = dyn_cast<LoadInst>(I))
+    return LI->isUnordered();
+  else if (const StoreInst *SI = dyn_cast<StoreInst>(I))
+    return SI->isUnordered();
+  return false;
+}
+
+
+// Examines the loop nesting of the Src and Dst
+// instructions and establishes their shared loops. Sets the variables
+// CommonLevels, SrcLevels, and MaxLevels.
+// The source and destination instructions needn't be contained in the same
+// loop. The routine establishNestingLevels finds the level of the most deeply
+// nested loop that contains them both, CommonLevels. An instruction that's
+// not contained in a loop is at level = 0. MaxLevels is equal to the level
+// of the source plus the level of the destination, minus CommonLevels.
+// This lets us allocate vectors MaxLevels in length, with room for every
+// distinct loop referenced in both the source and destination subscripts.
+// The variable SrcLevels is the nesting depth of the source instruction.
+// It's used to help calculate distinct loops referenced by the destination.
+// Here's the map from loops to levels:
+//            0 - unused
+//            1 - outermost common loop
+//          ... - other common loops
+// CommonLevels - innermost common loop
+//          ... - loops containing Src but not Dst
+//    SrcLevels - innermost loop containing Src but not Dst
+//          ... - loops containing Dst but not Src
+//    MaxLevels - innermost loops containing Dst but not Src
+// Consider the following code fragment:
+//   for (a = ...) {
+//     for (b = ...) {
+//       for (c = ...) {
+//         for (d = ...) {
+//           A[] = ...;
+//         }
+//       }
+//       for (e = ...) {
+//         for (f = ...) {
+//           for (g = ...) {
+//             ... = A[];
+//           }
+//         }
+//       }
+//     }
+//   }
+// If we're looking at the possibility of a dependence between the store
+// to A (the Src) and the load from A (the Dst), we'll note that they
+// have 2 loops in common, so CommonLevels will equal 2 and the direction
+// vector for Result will have 2 entries. SrcLevels = 4 and MaxLevels = 7.
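+// That is, the load is nested in 5 loops (a, b, e, f, g), so
+// MaxLevels = 4 (source depth) + 5 (destination depth) - 2 (CommonLevels) = 7.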
+// A map from loop names to loop numbers would look like +//     a - 1 +//     b - 2 = CommonLevels +//     c - 3 +//     d - 4 = SrcLevels +//     e - 5 +//     f - 6 +//     g - 7 = MaxLevels +void DependenceInfo::establishNestingLevels(const Instruction *Src, +                                            const Instruction *Dst) { +  const BasicBlock *SrcBlock = Src->getParent(); +  const BasicBlock *DstBlock = Dst->getParent(); +  unsigned SrcLevel = LI->getLoopDepth(SrcBlock); +  unsigned DstLevel = LI->getLoopDepth(DstBlock); +  const Loop *SrcLoop = LI->getLoopFor(SrcBlock); +  const Loop *DstLoop = LI->getLoopFor(DstBlock); +  SrcLevels = SrcLevel; +  MaxLevels = SrcLevel + DstLevel; +  while (SrcLevel > DstLevel) { +    SrcLoop = SrcLoop->getParentLoop(); +    SrcLevel--; +  } +  while (DstLevel > SrcLevel) { +    DstLoop = DstLoop->getParentLoop(); +    DstLevel--; +  } +  while (SrcLoop != DstLoop) { +    SrcLoop = SrcLoop->getParentLoop(); +    DstLoop = DstLoop->getParentLoop(); +    SrcLevel--; +  } +  CommonLevels = SrcLevel; +  MaxLevels -= CommonLevels; +} + + +// Given one of the loops containing the source, return +// its level index in our numbering scheme. +unsigned DependenceInfo::mapSrcLoop(const Loop *SrcLoop) const { +  return SrcLoop->getLoopDepth(); +} + + +// Given one of the loops containing the destination, +// return its level index in our numbering scheme. +unsigned DependenceInfo::mapDstLoop(const Loop *DstLoop) const { +  unsigned D = DstLoop->getLoopDepth(); +  if (D > CommonLevels) +    return D - CommonLevels + SrcLevels; +  else +    return D; +} + + +// Returns true if Expression is loop invariant in LoopNest. +bool DependenceInfo::isLoopInvariant(const SCEV *Expression, +                                     const Loop *LoopNest) const { +  if (!LoopNest) +    return true; +  return SE->isLoopInvariant(Expression, LoopNest) && +    isLoopInvariant(Expression, LoopNest->getParentLoop()); +} + + + +// Finds the set of loops from the LoopNest that +// have a level <= CommonLevels and are referred to by the SCEV Expression. +void DependenceInfo::collectCommonLoops(const SCEV *Expression, +                                        const Loop *LoopNest, +                                        SmallBitVector &Loops) const { +  while (LoopNest) { +    unsigned Level = LoopNest->getLoopDepth(); +    if (Level <= CommonLevels && !SE->isLoopInvariant(Expression, LoopNest)) +      Loops.set(Level); +    LoopNest = LoopNest->getParentLoop(); +  } +} + +void DependenceInfo::unifySubscriptType(ArrayRef<Subscript *> Pairs) { + +  unsigned widestWidthSeen = 0; +  Type *widestType; + +  // Go through each pair and find the widest bit to which we need +  // to extend all of them. 
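+  // (For example, if one pair has an i32 subscript on the Src side and an
+  //  i64 subscript on the Dst side, the widest type seen is i64 and the
+  //  second loop below sign-extends the i32 side to i64.)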
+  for (Subscript *Pair : Pairs) {
+    const SCEV *Src = Pair->Src;
+    const SCEV *Dst = Pair->Dst;
+    IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType());
+    IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType());
+    if (SrcTy == nullptr || DstTy == nullptr) {
+      assert(SrcTy == DstTy && "This function only unifies integer types and "
+             "expects Src and Dst to share the same type "
+             "otherwise.");
+      continue;
+    }
+    if (SrcTy->getBitWidth() > widestWidthSeen) {
+      widestWidthSeen = SrcTy->getBitWidth();
+      widestType = SrcTy;
+    }
+    if (DstTy->getBitWidth() > widestWidthSeen) {
+      widestWidthSeen = DstTy->getBitWidth();
+      widestType = DstTy;
+    }
+  }
+
+
+  assert(widestWidthSeen > 0);
+
+  // Now extend each pair to the widest seen.
+  for (Subscript *Pair : Pairs) {
+    const SCEV *Src = Pair->Src;
+    const SCEV *Dst = Pair->Dst;
+    IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType());
+    IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType());
+    if (SrcTy == nullptr || DstTy == nullptr) {
+      assert(SrcTy == DstTy && "This function only unifies integer types and "
+             "expects Src and Dst to share the same type "
+             "otherwise.");
+      continue;
+    }
+    if (SrcTy->getBitWidth() < widestWidthSeen)
+      // Sign-extend Src to widestType
+      Pair->Src = SE->getSignExtendExpr(Src, widestType);
+    if (DstTy->getBitWidth() < widestWidthSeen) {
+      // Sign-extend Dst to widestType
+      Pair->Dst = SE->getSignExtendExpr(Dst, widestType);
+    }
+  }
+}
+
+// removeMatchingExtensions - Examines a subscript pair.
+// If the source and destination are identically sign (or zero)
+// extended, it strips off the extension in an effort to simplify
+// the actual analysis.
+void DependenceInfo::removeMatchingExtensions(Subscript *Pair) {
+  const SCEV *Src = Pair->Src;
+  const SCEV *Dst = Pair->Dst;
+  if ((isa<SCEVZeroExtendExpr>(Src) && isa<SCEVZeroExtendExpr>(Dst)) ||
+      (isa<SCEVSignExtendExpr>(Src) && isa<SCEVSignExtendExpr>(Dst))) {
+    const SCEVCastExpr *SrcCast = cast<SCEVCastExpr>(Src);
+    const SCEVCastExpr *DstCast = cast<SCEVCastExpr>(Dst);
+    const SCEV *SrcCastOp = SrcCast->getOperand();
+    const SCEV *DstCastOp = DstCast->getOperand();
+    if (SrcCastOp->getType() == DstCastOp->getType()) {
+      Pair->Src = SrcCastOp;
+      Pair->Dst = DstCastOp;
+    }
+  }
+}
+
+
+// Examine the scev and return true iff it's linear.
+// Collect any loops mentioned in the set of "Loops".
+bool DependenceInfo::checkSrcSubscript(const SCEV *Src, const Loop *LoopNest,
+                                       SmallBitVector &Loops) {
+  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Src);
+  if (!AddRec)
+    return isLoopInvariant(Src, LoopNest);
+  const SCEV *Start = AddRec->getStart();
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop());
+  if (!isa<SCEVCouldNotCompute>(UB)) {
+    if (SE->getTypeSizeInBits(Start->getType()) <
+        SE->getTypeSizeInBits(UB->getType())) {
+      if (!AddRec->getNoWrapFlags())
+        return false;
+    }
+  }
+  if (!isLoopInvariant(Step, LoopNest))
+    return false;
+  Loops.set(mapSrcLoop(AddRec->getLoop()));
+  return checkSrcSubscript(Start, LoopNest, Loops);
+}
+
+
+
+// Examine the scev and return true iff it's linear.
+// Collect any loops mentioned in the set of "Loops".
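+// For example, an affine recurrence such as {C,+,4}<%L> (i.e., C + 4*i) is
+// linear provided its start C is itself linear and its step is invariant in
+// the loop nest; a recurrence whose step varies inside the nest is rejected.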
+bool DependenceInfo::checkDstSubscript(const SCEV *Dst, const Loop *LoopNest, +                                       SmallBitVector &Loops) { +  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Dst); +  if (!AddRec) +    return isLoopInvariant(Dst, LoopNest); +  const SCEV *Start = AddRec->getStart(); +  const SCEV *Step = AddRec->getStepRecurrence(*SE); +  const SCEV *UB = SE->getBackedgeTakenCount(AddRec->getLoop()); +  if (!isa<SCEVCouldNotCompute>(UB)) { +    if (SE->getTypeSizeInBits(Start->getType()) < +        SE->getTypeSizeInBits(UB->getType())) { +      if (!AddRec->getNoWrapFlags()) +        return false; +    } +  } +  if (!isLoopInvariant(Step, LoopNest)) +    return false; +  Loops.set(mapDstLoop(AddRec->getLoop())); +  return checkDstSubscript(Start, LoopNest, Loops); +} + + +// Examines the subscript pair (the Src and Dst SCEVs) +// and classifies it as either ZIV, SIV, RDIV, MIV, or Nonlinear. +// Collects the associated loops in a set. +DependenceInfo::Subscript::ClassificationKind +DependenceInfo::classifyPair(const SCEV *Src, const Loop *SrcLoopNest, +                             const SCEV *Dst, const Loop *DstLoopNest, +                             SmallBitVector &Loops) { +  SmallBitVector SrcLoops(MaxLevels + 1); +  SmallBitVector DstLoops(MaxLevels + 1); +  if (!checkSrcSubscript(Src, SrcLoopNest, SrcLoops)) +    return Subscript::NonLinear; +  if (!checkDstSubscript(Dst, DstLoopNest, DstLoops)) +    return Subscript::NonLinear; +  Loops = SrcLoops; +  Loops |= DstLoops; +  unsigned N = Loops.count(); +  if (N == 0) +    return Subscript::ZIV; +  if (N == 1) +    return Subscript::SIV; +  if (N == 2 && (SrcLoops.count() == 0 || +                 DstLoops.count() == 0 || +                 (SrcLoops.count() == 1 && DstLoops.count() == 1))) +    return Subscript::RDIV; +  return Subscript::MIV; +} + + +// A wrapper around SCEV::isKnownPredicate. +// Looks for cases where we're interested in comparing for equality. +// If both X and Y have been identically sign or zero extended, +// it strips off the (confusing) extensions before invoking +// SCEV::isKnownPredicate. Perhaps, someday, the ScalarEvolution package +// will be similarly updated. +// +// If SCEV::isKnownPredicate can't prove the predicate, +// we try simple subtraction, which seems to help in some cases +// involving symbolics. +bool DependenceInfo::isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *X, +                                      const SCEV *Y) const { +  if (Pred == CmpInst::ICMP_EQ || +      Pred == CmpInst::ICMP_NE) { +    if ((isa<SCEVSignExtendExpr>(X) && +         isa<SCEVSignExtendExpr>(Y)) || +        (isa<SCEVZeroExtendExpr>(X) && +         isa<SCEVZeroExtendExpr>(Y))) { +      const SCEVCastExpr *CX = cast<SCEVCastExpr>(X); +      const SCEVCastExpr *CY = cast<SCEVCastExpr>(Y); +      const SCEV *Xop = CX->getOperand(); +      const SCEV *Yop = CY->getOperand(); +      if (Xop->getType() == Yop->getType()) { +        X = Xop; +        Y = Yop; +      } +    } +  } +  if (SE->isKnownPredicate(Pred, X, Y)) +    return true; +  // If SE->isKnownPredicate can't prove the condition, +  // we try the brute-force approach of subtracting +  // and testing the difference. +  // By testing with SE->isKnownPredicate first, we avoid +  // the possibility of overflow when the arguments are constants. 
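+  // For instance, with X = (n + 4) and Y = n, neither side is a known
+  // constant, but Delta folds to 4, so the switch below can still prove
+  // ICMP_NE (and ICMP_SGT).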
+  const SCEV *Delta = SE->getMinusSCEV(X, Y);
+  switch (Pred) {
+  case CmpInst::ICMP_EQ:
+    return Delta->isZero();
+  case CmpInst::ICMP_NE:
+    return SE->isKnownNonZero(Delta);
+  case CmpInst::ICMP_SGE:
+    return SE->isKnownNonNegative(Delta);
+  case CmpInst::ICMP_SLE:
+    return SE->isKnownNonPositive(Delta);
+  case CmpInst::ICMP_SGT:
+    return SE->isKnownPositive(Delta);
+  case CmpInst::ICMP_SLT:
+    return SE->isKnownNegative(Delta);
+  default:
+    llvm_unreachable("unexpected predicate in isKnownPredicate");
+  }
+}
+
+/// Compare to see if S is less than Size, using isKnownNegative(S - max(Size, 1))
+/// with some extra checking if S is an AddRec and we can prove less-than using
+/// the loop bounds.
+bool DependenceInfo::isKnownLessThan(const SCEV *S, const SCEV *Size) const {
+  // First unify to the same type
+  auto *SType = dyn_cast<IntegerType>(S->getType());
+  auto *SizeType = dyn_cast<IntegerType>(Size->getType());
+  if (!SType || !SizeType)
+    return false;
+  Type *MaxType =
+      (SType->getBitWidth() >= SizeType->getBitWidth()) ? SType : SizeType;
+  S = SE->getTruncateOrZeroExtend(S, MaxType);
+  Size = SE->getTruncateOrZeroExtend(Size, MaxType);
+
+  // Special check for addrecs using BE taken count
+  const SCEV *Bound = SE->getMinusSCEV(S, Size);
+  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Bound)) {
+    if (AddRec->isAffine()) {
+      const SCEV *BECount = SE->getBackedgeTakenCount(AddRec->getLoop());
+      if (!isa<SCEVCouldNotCompute>(BECount)) {
+        const SCEV *Limit = AddRec->evaluateAtIteration(BECount, *SE);
+        if (SE->isKnownNegative(Limit))
+          return true;
+      }
+    }
+  }
+
+  // Check using normal isKnownNegative
+  const SCEV *LimitedBound =
+      SE->getMinusSCEV(S, SE->getSMaxExpr(Size, SE->getOne(Size->getType())));
+  return SE->isKnownNegative(LimitedBound);
+}
+
+bool DependenceInfo::isKnownNonNegative(const SCEV *S, const Value *Ptr) const {
+  bool Inbounds = false;
+  if (auto *SrcGEP = dyn_cast<GetElementPtrInst>(Ptr))
+    Inbounds = SrcGEP->isInBounds();
+  if (Inbounds) {
+    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
+      if (AddRec->isAffine()) {
+        // We know S is for Ptr, the operand on a load/store, so doesn't wrap.
+        // If both parts are NonNegative, the end result will be NonNegative
+        if (SE->isKnownNonNegative(AddRec->getStart()) &&
+            SE->isKnownNonNegative(AddRec->getOperand(1)))
+          return true;
+      }
+    }
+  }
+
+  return SE->isKnownNonNegative(S);
+}
+
+// All subscripts are the same type.
+// Loop bound may be smaller (e.g., a char).
+// Should zero extend loop bound, since it's always >= 0.
+// This routine collects the upper bound and extends or truncates it if needed.
+// Truncating is safe when subscripts are known not to wrap. Cases without
+// nowrap flags should have been rejected earlier.
+// Return null if no bound available.
+const SCEV *DependenceInfo::collectUpperBound(const Loop *L, Type *T) const {
+  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
+    const SCEV *UB = SE->getBackedgeTakenCount(L);
+    return SE->getTruncateOrZeroExtend(UB, T);
+  }
+  return nullptr;
+}
+
+
+// Calls collectUpperBound(), then attempts to cast it to SCEVConstant.
+// If the cast fails, returns NULL.
+const SCEVConstant *DependenceInfo::collectConstantUpperBound(const Loop *L, +                                                              Type *T) const { +  if (const SCEV *UB = collectUpperBound(L, T)) +    return dyn_cast<SCEVConstant>(UB); +  return nullptr; +} + + +// testZIV - +// When we have a pair of subscripts of the form [c1] and [c2], +// where c1 and c2 are both loop invariant, we attack it using +// the ZIV test. Basically, we test by comparing the two values, +// but there are actually three possible results: +// 1) the values are equal, so there's a dependence +// 2) the values are different, so there's no dependence +// 3) the values might be equal, so we have to assume a dependence. +// +// Return true if dependence disproved. +bool DependenceInfo::testZIV(const SCEV *Src, const SCEV *Dst, +                             FullDependence &Result) const { +  LLVM_DEBUG(dbgs() << "    src = " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "    dst = " << *Dst << "\n"); +  ++ZIVapplications; +  if (isKnownPredicate(CmpInst::ICMP_EQ, Src, Dst)) { +    LLVM_DEBUG(dbgs() << "    provably dependent\n"); +    return false; // provably dependent +  } +  if (isKnownPredicate(CmpInst::ICMP_NE, Src, Dst)) { +    LLVM_DEBUG(dbgs() << "    provably independent\n"); +    ++ZIVindependence; +    return true; // provably independent +  } +  LLVM_DEBUG(dbgs() << "    possibly dependent\n"); +  Result.Consistent = false; +  return false; // possibly dependent +} + + +// strongSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.1 +// +// When we have a pair of subscripts of the form [c1 + a*i] and [c2 + a*i], +// where i is an induction variable, c1 and c2 are loop invariant, +//  and a is a constant, we can solve it exactly using the Strong SIV test. +// +// Can prove independence. Failing that, can compute distance (and direction). +// In the presence of symbolic terms, we can sometimes make progress. +// +// If there's a dependence, +// +//    c1 + a*i = c2 + a*i' +// +// The dependence distance is +// +//    d = i' - i = (c1 - c2)/a +// +// A dependence only exists if d is an integer and abs(d) <= U, where U is the +// loop's upper bound. If a dependence exists, the dependence direction is +// defined as +// +//                { < if d > 0 +//    direction = { = if d = 0 +//                { > if d < 0 +// +// Return true if dependence disproved. 
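+//
+// For example (illustrative):
+//
+//    for (i = 0; i < N; i++) {
+//      A[i + 2] = ...;
+//      ... = A[i];
+//    }
+//
+// gives a = 1, c1 = 2, c2 = 0, so d = (c1 - c2)/a = 2; for N > 2 the
+// dependence survives with distance 2 and direction <.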
+bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, +                                   const SCEV *DstConst, const Loop *CurLoop, +                                   unsigned Level, FullDependence &Result, +                                   Constraint &NewConstraint) const { +  LLVM_DEBUG(dbgs() << "\tStrong SIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    Coeff = " << *Coeff); +  LLVM_DEBUG(dbgs() << ", " << *Coeff->getType() << "\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcConst = " << *SrcConst); +  LLVM_DEBUG(dbgs() << ", " << *SrcConst->getType() << "\n"); +  LLVM_DEBUG(dbgs() << "\t    DstConst = " << *DstConst); +  LLVM_DEBUG(dbgs() << ", " << *DstConst->getType() << "\n"); +  ++StrongSIVapplications; +  assert(0 < Level && Level <= CommonLevels && "level out of range"); +  Level--; + +  const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst); +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta); +  LLVM_DEBUG(dbgs() << ", " << *Delta->getType() << "\n"); + +  // check that |Delta| < iteration count +  if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { +    LLVM_DEBUG(dbgs() << "\t    UpperBound = " << *UpperBound); +    LLVM_DEBUG(dbgs() << ", " << *UpperBound->getType() << "\n"); +    const SCEV *AbsDelta = +      SE->isKnownNonNegative(Delta) ? Delta : SE->getNegativeSCEV(Delta); +    const SCEV *AbsCoeff = +      SE->isKnownNonNegative(Coeff) ? Coeff : SE->getNegativeSCEV(Coeff); +    const SCEV *Product = SE->getMulExpr(UpperBound, AbsCoeff); +    if (isKnownPredicate(CmpInst::ICMP_SGT, AbsDelta, Product)) { +      // Distance greater than trip count - no dependence +      ++StrongSIVindependence; +      ++StrongSIVsuccesses; +      return true; +    } +  } + +  // Can we compute distance? +  if (isa<SCEVConstant>(Delta) && isa<SCEVConstant>(Coeff)) { +    APInt ConstDelta = cast<SCEVConstant>(Delta)->getAPInt(); +    APInt ConstCoeff = cast<SCEVConstant>(Coeff)->getAPInt(); +    APInt Distance  = ConstDelta; // these need to be initialized +    APInt Remainder = ConstDelta; +    APInt::sdivrem(ConstDelta, ConstCoeff, Distance, Remainder); +    LLVM_DEBUG(dbgs() << "\t    Distance = " << Distance << "\n"); +    LLVM_DEBUG(dbgs() << "\t    Remainder = " << Remainder << "\n"); +    // Make sure Coeff divides Delta exactly +    if (Remainder != 0) { +      // Coeff doesn't divide Distance, no dependence +      ++StrongSIVindependence; +      ++StrongSIVsuccesses; +      return true; +    } +    Result.DV[Level].Distance = SE->getConstant(Distance); +    NewConstraint.setDistance(SE->getConstant(Distance), CurLoop); +    if (Distance.sgt(0)) +      Result.DV[Level].Direction &= Dependence::DVEntry::LT; +    else if (Distance.slt(0)) +      Result.DV[Level].Direction &= Dependence::DVEntry::GT; +    else +      Result.DV[Level].Direction &= Dependence::DVEntry::EQ; +    ++StrongSIVsuccesses; +  } +  else if (Delta->isZero()) { +    // since 0/X == 0 +    Result.DV[Level].Distance = Delta; +    NewConstraint.setDistance(Delta, CurLoop); +    Result.DV[Level].Direction &= Dependence::DVEntry::EQ; +    ++StrongSIVsuccesses; +  } +  else { +    if (Coeff->isOne()) { +      LLVM_DEBUG(dbgs() << "\t    Distance = " << *Delta << "\n"); +      Result.DV[Level].Distance = Delta; // since X/1 == X +      NewConstraint.setDistance(Delta, CurLoop); +    } +    else { +      Result.Consistent = false; +      NewConstraint.setLine(Coeff, +                            SE->getNegativeSCEV(Coeff), +                            SE->getNegativeSCEV(Delta), 
CurLoop); +    } + +    // maybe we can get a useful direction +    bool DeltaMaybeZero     = !SE->isKnownNonZero(Delta); +    bool DeltaMaybePositive = !SE->isKnownNonPositive(Delta); +    bool DeltaMaybeNegative = !SE->isKnownNonNegative(Delta); +    bool CoeffMaybePositive = !SE->isKnownNonPositive(Coeff); +    bool CoeffMaybeNegative = !SE->isKnownNonNegative(Coeff); +    // The double negatives above are confusing. +    // It helps to read !SE->isKnownNonZero(Delta) +    // as "Delta might be Zero" +    unsigned NewDirection = Dependence::DVEntry::NONE; +    if ((DeltaMaybePositive && CoeffMaybePositive) || +        (DeltaMaybeNegative && CoeffMaybeNegative)) +      NewDirection = Dependence::DVEntry::LT; +    if (DeltaMaybeZero) +      NewDirection |= Dependence::DVEntry::EQ; +    if ((DeltaMaybeNegative && CoeffMaybePositive) || +        (DeltaMaybePositive && CoeffMaybeNegative)) +      NewDirection |= Dependence::DVEntry::GT; +    if (NewDirection < Result.DV[Level].Direction) +      ++StrongSIVsuccesses; +    Result.DV[Level].Direction &= NewDirection; +  } +  return false; +} + + +// weakCrossingSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.2 +// +// When we have a pair of subscripts of the form [c1 + a*i] and [c2 - a*i], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the +// Weak-Crossing SIV test. +// +// Given c1 + a*i = c2 - a*i', we can look for the intersection of +// the two lines, where i = i', yielding +// +//    c1 + a*i = c2 - a*i +//    2a*i = c2 - c1 +//    i = (c2 - c1)/2a +// +// If i < 0, there is no dependence. +// If i > upperbound, there is no dependence. +// If i = 0 (i.e., if c1 = c2), there's a dependence with distance = 0. +// If i = upperbound, there's a dependence with distance = 0. +// If i is integral, there's a dependence (all directions). +// If the non-integer part = 1/2, there's a dependence (<> directions). +// Otherwise, there's no dependence. +// +// Can prove independence. Failing that, +// can sometimes refine the directions. +// Can determine iteration for splitting. +// +// Return true if dependence disproved. 
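+//
+// For example (illustrative):
+//
+//    for (i = 0; i <= 10; i++) {
+//      A[i] = ...;
+//      ... = A[6 - i];
+//    }
+//
+// gives a = 1, c1 = 0, c2 = 6, so the crossing point is i = (c2 - c1)/2a = 3;
+// since 3 is integral and within bounds there is a dependence, and 3 is the
+// iteration reported through SplitIter for splitting the loop.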
+bool DependenceInfo::weakCrossingSIVtest( +    const SCEV *Coeff, const SCEV *SrcConst, const SCEV *DstConst, +    const Loop *CurLoop, unsigned Level, FullDependence &Result, +    Constraint &NewConstraint, const SCEV *&SplitIter) const { +  LLVM_DEBUG(dbgs() << "\tWeak-Crossing SIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    Coeff = " << *Coeff << "\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcConst = " << *SrcConst << "\n"); +  LLVM_DEBUG(dbgs() << "\t    DstConst = " << *DstConst << "\n"); +  ++WeakCrossingSIVapplications; +  assert(0 < Level && Level <= CommonLevels && "Level out of range"); +  Level--; +  Result.Consistent = false; +  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta << "\n"); +  NewConstraint.setLine(Coeff, Coeff, Delta, CurLoop); +  if (Delta->isZero()) { +    Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::LT); +    Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::GT); +    ++WeakCrossingSIVsuccesses; +    if (!Result.DV[Level].Direction) { +      ++WeakCrossingSIVindependence; +      return true; +    } +    Result.DV[Level].Distance = Delta; // = 0 +    return false; +  } +  const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(Coeff); +  if (!ConstCoeff) +    return false; + +  Result.DV[Level].Splitable = true; +  if (SE->isKnownNegative(ConstCoeff)) { +    ConstCoeff = dyn_cast<SCEVConstant>(SE->getNegativeSCEV(ConstCoeff)); +    assert(ConstCoeff && +           "dynamic cast of negative of ConstCoeff should yield constant"); +    Delta = SE->getNegativeSCEV(Delta); +  } +  assert(SE->isKnownPositive(ConstCoeff) && "ConstCoeff should be positive"); + +  // compute SplitIter for use by DependenceInfo::getSplitIteration() +  SplitIter = SE->getUDivExpr( +      SE->getSMaxExpr(SE->getZero(Delta->getType()), Delta), +      SE->getMulExpr(SE->getConstant(Delta->getType(), 2), ConstCoeff)); +  LLVM_DEBUG(dbgs() << "\t    Split iter = " << *SplitIter << "\n"); + +  const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta); +  if (!ConstDelta) +    return false; + +  // We're certain that ConstCoeff > 0; therefore, +  // if Delta < 0, then no dependence. +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta << "\n"); +  LLVM_DEBUG(dbgs() << "\t    ConstCoeff = " << *ConstCoeff << "\n"); +  if (SE->isKnownNegative(Delta)) { +    // No dependence, Delta < 0 +    ++WeakCrossingSIVindependence; +    ++WeakCrossingSIVsuccesses; +    return true; +  } + +  // We're certain that Delta > 0 and ConstCoeff > 0. 
+  // Check Delta/(2*ConstCoeff) against upper loop bound
+  if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) {
+    LLVM_DEBUG(dbgs() << "\t    UpperBound = " << *UpperBound << "\n");
+    const SCEV *ConstantTwo = SE->getConstant(UpperBound->getType(), 2);
+    const SCEV *ML = SE->getMulExpr(SE->getMulExpr(ConstCoeff, UpperBound),
+                                    ConstantTwo);
+    LLVM_DEBUG(dbgs() << "\t    ML = " << *ML << "\n");
+    if (isKnownPredicate(CmpInst::ICMP_SGT, Delta, ML)) {
+      // Delta too big, no dependence
+      ++WeakCrossingSIVindependence;
+      ++WeakCrossingSIVsuccesses;
+      return true;
+    }
+    if (isKnownPredicate(CmpInst::ICMP_EQ, Delta, ML)) {
+      // i = i' = UB
+      Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::LT);
+      Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::GT);
+      ++WeakCrossingSIVsuccesses;
+      if (!Result.DV[Level].Direction) {
+        ++WeakCrossingSIVindependence;
+        return true;
+      }
+      Result.DV[Level].Splitable = false;
+      Result.DV[Level].Distance = SE->getZero(Delta->getType());
+      return false;
+    }
+  }
+
+  // check that Coeff divides Delta
+  APInt APDelta = ConstDelta->getAPInt();
+  APInt APCoeff = ConstCoeff->getAPInt();
+  APInt Distance = APDelta; // these need to be initialized
+  APInt Remainder = APDelta;
+  APInt::sdivrem(APDelta, APCoeff, Distance, Remainder);
+  LLVM_DEBUG(dbgs() << "\t    Remainder = " << Remainder << "\n");
+  if (Remainder != 0) {
+    // Coeff doesn't divide Delta, no dependence
+    ++WeakCrossingSIVindependence;
+    ++WeakCrossingSIVsuccesses;
+    return true;
+  }
+  LLVM_DEBUG(dbgs() << "\t    Distance = " << Distance << "\n");
+
+  // if 2*Coeff doesn't divide Delta, then the equal direction isn't possible
+  APInt Two = APInt(Distance.getBitWidth(), 2, true);
+  Remainder = Distance.srem(Two);
+  LLVM_DEBUG(dbgs() << "\t    Remainder = " << Remainder << "\n");
+  if (Remainder != 0) {
+    // Equal direction isn't possible
+    Result.DV[Level].Direction &= unsigned(~Dependence::DVEntry::EQ);
+    ++WeakCrossingSIVsuccesses;
+  }
+  return false;
+}
+
+
+// Kirch's algorithm, from
+//
+//        Optimizing Supercompilers for Supercomputers
+//        Michael Wolfe
+//        MIT Press, 1989
+//
+// Program 2.1, page 29.
+// Computes the GCD of AM and BM.
+// Also finds a solution to the equation ax - by = gcd(a, b).
+// Returns true if dependence disproved; i.e., gcd does not divide Delta.
+static bool findGCD(unsigned Bits, const APInt &AM, const APInt &BM,
+                    const APInt &Delta, APInt &G, APInt &X, APInt &Y) {
+  APInt A0(Bits, 1, true), A1(Bits, 0, true);
+  APInt B0(Bits, 0, true), B1(Bits, 1, true);
+  APInt G0 = AM.abs();
+  APInt G1 = BM.abs();
+  APInt Q = G0; // these need to be initialized
+  APInt R = G0;
+  APInt::sdivrem(G0, G1, Q, R);
+  while (R != 0) {
+    APInt A2 = A0 - Q*A1; A0 = A1; A1 = A2;
+    APInt B2 = B0 - Q*B1; B0 = B1; B1 = B2;
+    G0 = G1; G1 = R;
+    APInt::sdivrem(G0, G1, Q, R);
+  }
+  G = G1;
+  LLVM_DEBUG(dbgs() << "\t    GCD = " << G << "\n");
+  X = AM.slt(0) ? -A1 : A1;
+  Y = BM.slt(0) ? 
B1 : -B1; + +  // make sure gcd divides Delta +  R = Delta.srem(G); +  if (R != 0) +    return true; // gcd doesn't divide Delta, no dependence +  Q = Delta.sdiv(G); +  X *= Q; +  Y *= Q; +  return false; +} + +static APInt floorOfQuotient(const APInt &A, const APInt &B) { +  APInt Q = A; // these need to be initialized +  APInt R = A; +  APInt::sdivrem(A, B, Q, R); +  if (R == 0) +    return Q; +  if ((A.sgt(0) && B.sgt(0)) || +      (A.slt(0) && B.slt(0))) +    return Q; +  else +    return Q - 1; +} + +static APInt ceilingOfQuotient(const APInt &A, const APInt &B) { +  APInt Q = A; // these need to be initialized +  APInt R = A; +  APInt::sdivrem(A, B, Q, R); +  if (R == 0) +    return Q; +  if ((A.sgt(0) && B.sgt(0)) || +      (A.slt(0) && B.slt(0))) +    return Q + 1; +  else +    return Q; +} + + +static +APInt maxAPInt(APInt A, APInt B) { +  return A.sgt(B) ? A : B; +} + + +static +APInt minAPInt(APInt A, APInt B) { +  return A.slt(B) ? A : B; +} + + +// exactSIVtest - +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*i], +// where i is an induction variable, c1 and c2 are loop invariant, and a1 +// and a2 are constant, we can solve it exactly using an algorithm developed +// by Banerjee and Wolfe. See Section 2.5.3 in +// +//        Optimizing Supercompilers for Supercomputers +//        Michael Wolfe +//        MIT Press, 1989 +// +// It's slower than the specialized tests (strong SIV, weak-zero SIV, etc), +// so use them if possible. They're also a bit better with symbolics and, +// in the case of the strong SIV test, can compute Distances. +// +// Return true if dependence disproved. +bool DependenceInfo::exactSIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff, +                                  const SCEV *SrcConst, const SCEV *DstConst, +                                  const Loop *CurLoop, unsigned Level, +                                  FullDependence &Result, +                                  Constraint &NewConstraint) const { +  LLVM_DEBUG(dbgs() << "\tExact SIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcCoeff = " << *SrcCoeff << " = AM\n"); +  LLVM_DEBUG(dbgs() << "\t    DstCoeff = " << *DstCoeff << " = BM\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcConst = " << *SrcConst << "\n"); +  LLVM_DEBUG(dbgs() << "\t    DstConst = " << *DstConst << "\n"); +  ++ExactSIVapplications; +  assert(0 < Level && Level <= CommonLevels && "Level out of range"); +  Level--; +  Result.Consistent = false; +  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta << "\n"); +  NewConstraint.setLine(SrcCoeff, SE->getNegativeSCEV(DstCoeff), +                        Delta, CurLoop); +  const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta); +  const SCEVConstant *ConstSrcCoeff = dyn_cast<SCEVConstant>(SrcCoeff); +  const SCEVConstant *ConstDstCoeff = dyn_cast<SCEVConstant>(DstCoeff); +  if (!ConstDelta || !ConstSrcCoeff || !ConstDstCoeff) +    return false; + +  // find gcd +  APInt G, X, Y; +  APInt AM = ConstSrcCoeff->getAPInt(); +  APInt BM = ConstDstCoeff->getAPInt(); +  unsigned Bits = AM.getBitWidth(); +  if (findGCD(Bits, AM, BM, ConstDelta->getAPInt(), G, X, Y)) { +    // gcd doesn't divide Delta, no dependence +    ++ExactSIVindependence; +    ++ExactSIVsuccesses; +    return true; +  } + +  LLVM_DEBUG(dbgs() << "\t    X = " << X << ", Y = " << Y << "\n"); + +  // since SCEV construction normalizes, LM = 0 +  APInt UM(Bits, 1, true); +  bool UMvalid = false; +  // UM is perhaps unavailable, let's check 
+  if (const SCEVConstant *CUB = +      collectConstantUpperBound(CurLoop, Delta->getType())) { +    UM = CUB->getAPInt(); +    LLVM_DEBUG(dbgs() << "\t    UM = " << UM << "\n"); +    UMvalid = true; +  } + +  APInt TU(APInt::getSignedMaxValue(Bits)); +  APInt TL(APInt::getSignedMinValue(Bits)); + +  // test(BM/G, LM-X) and test(-BM/G, X-UM) +  APInt TMUL = BM.sdiv(G); +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(-X, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    if (UMvalid) { +      TU = minAPInt(TU, floorOfQuotient(UM - X, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    } +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(-X, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    if (UMvalid) { +      TL = maxAPInt(TL, ceilingOfQuotient(UM - X, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    } +  } + +  // test(AM/G, LM-Y) and test(-AM/G, Y-UM) +  TMUL = AM.sdiv(G); +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(-Y, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    if (UMvalid) { +      TU = minAPInt(TU, floorOfQuotient(UM - Y, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    } +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(-Y, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    if (UMvalid) { +      TL = maxAPInt(TL, ceilingOfQuotient(UM - Y, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    } +  } +  if (TL.sgt(TU)) { +    ++ExactSIVindependence; +    ++ExactSIVsuccesses; +    return true; +  } + +  // explore directions +  unsigned NewDirection = Dependence::DVEntry::NONE; + +  // less than +  APInt SaveTU(TU); // save these +  APInt SaveTL(TL); +  LLVM_DEBUG(dbgs() << "\t    exploring LT direction\n"); +  TMUL = AM - BM; +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(X - Y + 1, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TL = " << TL << "\n"); +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(X - Y + 1, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TU = " << TU << "\n"); +  } +  if (TL.sle(TU)) { +    NewDirection |= Dependence::DVEntry::LT; +    ++ExactSIVsuccesses; +  } + +  // equal +  TU = SaveTU; // restore +  TL = SaveTL; +  LLVM_DEBUG(dbgs() << "\t    exploring EQ direction\n"); +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(X - Y, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TL = " << TL << "\n"); +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(X - Y, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TU = " << TU << "\n"); +  } +  TMUL = BM - AM; +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(Y - X, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TL = " << TL << "\n"); +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(Y - X, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TU = " << TU << "\n"); +  } +  if (TL.sle(TU)) { +    NewDirection |= Dependence::DVEntry::EQ; +    ++ExactSIVsuccesses; +  } + +  // greater than +  TU = SaveTU; // restore +  TL = SaveTL; +  LLVM_DEBUG(dbgs() << "\t    exploring GT direction\n"); +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(Y - X + 1, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TL = " << TL << "\n"); +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(Y - X + 1, TMUL)); +    LLVM_DEBUG(dbgs() << "\t\t    TU = " << TU << "\n"); +  } +  if (TL.sle(TU)) { +    NewDirection |= Dependence::DVEntry::GT; +    ++ExactSIVsuccesses; +  } + +  // finished +  Result.DV[Level].Direction &= 
NewDirection; +  if (Result.DV[Level].Direction == Dependence::DVEntry::NONE) +    ++ExactSIVindependence; +  return Result.DV[Level].Direction == Dependence::DVEntry::NONE; +} + + + +// Return true if the divisor evenly divides the dividend. +static +bool isRemainderZero(const SCEVConstant *Dividend, +                     const SCEVConstant *Divisor) { +  const APInt &ConstDividend = Dividend->getAPInt(); +  const APInt &ConstDivisor = Divisor->getAPInt(); +  return ConstDividend.srem(ConstDivisor) == 0; +} + + +// weakZeroSrcSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.2 +// +// When we have a pair of subscripts of the form [c1] and [c2 + a*i], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the +// Weak-Zero SIV test. +// +// Given +// +//    c1 = c2 + a*i +// +// we get +// +//    (c1 - c2)/a = i +// +// If i is not an integer, there's no dependence. +// If i < 0 or > UB, there's no dependence. +// If i = 0, the direction is >= and peeling the +// 1st iteration will break the dependence. +// If i = UB, the direction is <= and peeling the +// last iteration will break the dependence. +// Otherwise, the direction is *. +// +// Can prove independence. Failing that, we can sometimes refine +// the directions. Can sometimes show that first or last +// iteration carries all the dependences (so worth peeling). +// +// (see also weakZeroDstSIVtest) +// +// Return true if dependence disproved. +bool DependenceInfo::weakZeroSrcSIVtest(const SCEV *DstCoeff, +                                        const SCEV *SrcConst, +                                        const SCEV *DstConst, +                                        const Loop *CurLoop, unsigned Level, +                                        FullDependence &Result, +                                        Constraint &NewConstraint) const { +  // For the WeakSIV test, it's possible the loop isn't common to +  // the Src and Dst loops. If it isn't, then there's no need to +  // record a direction. +  LLVM_DEBUG(dbgs() << "\tWeak-Zero (src) SIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    DstCoeff = " << *DstCoeff << "\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcConst = " << *SrcConst << "\n"); +  LLVM_DEBUG(dbgs() << "\t    DstConst = " << *DstConst << "\n"); +  ++WeakZeroSIVapplications; +  assert(0 < Level && Level <= MaxLevels && "Level out of range"); +  Level--; +  Result.Consistent = false; +  const SCEV *Delta = SE->getMinusSCEV(SrcConst, DstConst); +  NewConstraint.setLine(SE->getZero(Delta->getType()), DstCoeff, Delta, +                        CurLoop); +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta << "\n"); +  if (isKnownPredicate(CmpInst::ICMP_EQ, SrcConst, DstConst)) { +    if (Level < CommonLevels) { +      Result.DV[Level].Direction &= Dependence::DVEntry::GE; +      Result.DV[Level].PeelFirst = true; +      ++WeakZeroSIVsuccesses; +    } +    return false; // dependences caused by first iteration +  } +  const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(DstCoeff); +  if (!ConstCoeff) +    return false; +  const SCEV *AbsCoeff = +    SE->isKnownNegative(ConstCoeff) ? +    SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; +  const SCEV *NewDelta = +    SE->isKnownNegative(ConstCoeff) ? 
SE->getNegativeSCEV(Delta) : Delta; + +  // check that Delta/SrcCoeff < iteration count +  // really check NewDelta < count*AbsCoeff +  if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { +    LLVM_DEBUG(dbgs() << "\t    UpperBound = " << *UpperBound << "\n"); +    const SCEV *Product = SE->getMulExpr(AbsCoeff, UpperBound); +    if (isKnownPredicate(CmpInst::ICMP_SGT, NewDelta, Product)) { +      ++WeakZeroSIVindependence; +      ++WeakZeroSIVsuccesses; +      return true; +    } +    if (isKnownPredicate(CmpInst::ICMP_EQ, NewDelta, Product)) { +      // dependences caused by last iteration +      if (Level < CommonLevels) { +        Result.DV[Level].Direction &= Dependence::DVEntry::LE; +        Result.DV[Level].PeelLast = true; +        ++WeakZeroSIVsuccesses; +      } +      return false; +    } +  } + +  // check that Delta/SrcCoeff >= 0 +  // really check that NewDelta >= 0 +  if (SE->isKnownNegative(NewDelta)) { +    // No dependence, newDelta < 0 +    ++WeakZeroSIVindependence; +    ++WeakZeroSIVsuccesses; +    return true; +  } + +  // if SrcCoeff doesn't divide Delta, then no dependence +  if (isa<SCEVConstant>(Delta) && +      !isRemainderZero(cast<SCEVConstant>(Delta), ConstCoeff)) { +    ++WeakZeroSIVindependence; +    ++WeakZeroSIVsuccesses; +    return true; +  } +  return false; +} + + +// weakZeroDstSIVtest - +// From the paper, Practical Dependence Testing, Section 4.2.2 +// +// When we have a pair of subscripts of the form [c1 + a*i] and [c2], +// where i is an induction variable, c1 and c2 are loop invariant, +// and a is a constant, we can solve it exactly using the +// Weak-Zero SIV test. +// +// Given +// +//    c1 + a*i = c2 +// +// we get +// +//    i = (c2 - c1)/a +// +// If i is not an integer, there's no dependence. +// If i < 0 or > UB, there's no dependence. +// If i = 0, the direction is <= and peeling the +// 1st iteration will break the dependence. +// If i = UB, the direction is >= and peeling the +// last iteration will break the dependence. +// Otherwise, the direction is *. +// +// Can prove independence. Failing that, we can sometimes refine +// the directions. Can sometimes show that first or last +// iteration carries all the dependences (so worth peeling). +// +// (see also weakZeroSrcSIVtest) +// +// Return true if dependence disproved. +bool DependenceInfo::weakZeroDstSIVtest(const SCEV *SrcCoeff, +                                        const SCEV *SrcConst, +                                        const SCEV *DstConst, +                                        const Loop *CurLoop, unsigned Level, +                                        FullDependence &Result, +                                        Constraint &NewConstraint) const { +  // For the WeakSIV test, it's possible the loop isn't common to the +  // Src and Dst loops. If it isn't, then there's no need to record a direction. 
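+  // For example (illustrative), with the pair A[2*i] (Src) and A[10] (Dst)
+  // we have c1 = 0, a = 2, c2 = 10, so i = (c2 - c1)/a = 5; the dependence
+  // is disproved when the upper bound is less than 5 or when a does not
+  // divide c2 - c1.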
+  LLVM_DEBUG(dbgs() << "\tWeak-Zero (dst) SIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcCoeff = " << *SrcCoeff << "\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcConst = " << *SrcConst << "\n"); +  LLVM_DEBUG(dbgs() << "\t    DstConst = " << *DstConst << "\n"); +  ++WeakZeroSIVapplications; +  assert(0 < Level && Level <= SrcLevels && "Level out of range"); +  Level--; +  Result.Consistent = false; +  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); +  NewConstraint.setLine(SrcCoeff, SE->getZero(Delta->getType()), Delta, +                        CurLoop); +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta << "\n"); +  if (isKnownPredicate(CmpInst::ICMP_EQ, DstConst, SrcConst)) { +    if (Level < CommonLevels) { +      Result.DV[Level].Direction &= Dependence::DVEntry::LE; +      Result.DV[Level].PeelFirst = true; +      ++WeakZeroSIVsuccesses; +    } +    return false; // dependences caused by first iteration +  } +  const SCEVConstant *ConstCoeff = dyn_cast<SCEVConstant>(SrcCoeff); +  if (!ConstCoeff) +    return false; +  const SCEV *AbsCoeff = +    SE->isKnownNegative(ConstCoeff) ? +    SE->getNegativeSCEV(ConstCoeff) : ConstCoeff; +  const SCEV *NewDelta = +    SE->isKnownNegative(ConstCoeff) ? SE->getNegativeSCEV(Delta) : Delta; + +  // check that Delta/SrcCoeff < iteration count +  // really check NewDelta < count*AbsCoeff +  if (const SCEV *UpperBound = collectUpperBound(CurLoop, Delta->getType())) { +    LLVM_DEBUG(dbgs() << "\t    UpperBound = " << *UpperBound << "\n"); +    const SCEV *Product = SE->getMulExpr(AbsCoeff, UpperBound); +    if (isKnownPredicate(CmpInst::ICMP_SGT, NewDelta, Product)) { +      ++WeakZeroSIVindependence; +      ++WeakZeroSIVsuccesses; +      return true; +    } +    if (isKnownPredicate(CmpInst::ICMP_EQ, NewDelta, Product)) { +      // dependences caused by last iteration +      if (Level < CommonLevels) { +        Result.DV[Level].Direction &= Dependence::DVEntry::GE; +        Result.DV[Level].PeelLast = true; +        ++WeakZeroSIVsuccesses; +      } +      return false; +    } +  } + +  // check that Delta/SrcCoeff >= 0 +  // really check that NewDelta >= 0 +  if (SE->isKnownNegative(NewDelta)) { +    // No dependence, newDelta < 0 +    ++WeakZeroSIVindependence; +    ++WeakZeroSIVsuccesses; +    return true; +  } + +  // if SrcCoeff doesn't divide Delta, then no dependence +  if (isa<SCEVConstant>(Delta) && +      !isRemainderZero(cast<SCEVConstant>(Delta), ConstCoeff)) { +    ++WeakZeroSIVindependence; +    ++WeakZeroSIVsuccesses; +    return true; +  } +  return false; +} + + +// exactRDIVtest - Tests the RDIV subscript pair for dependence. +// Things of the form [c1 + a*i] and [c2 + b*j], +// where i and j are induction variable, c1 and c2 are loop invariant, +// and a and b are constants. +// Returns true if any possible dependence is disproved. +// Marks the result as inconsistent. +// Works in some cases that symbolicRDIVtest doesn't, and vice versa. 
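+//
+// For example (illustrative), the pair A[2*i] and A[2*j + 1] gives
+// a = b = 2 and c2 - c1 = 1; since gcd(2, 2) = 2 does not divide 1, the
+// dependence is disproved no matter what the loop bounds are.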
+bool DependenceInfo::exactRDIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff, +                                   const SCEV *SrcConst, const SCEV *DstConst, +                                   const Loop *SrcLoop, const Loop *DstLoop, +                                   FullDependence &Result) const { +  LLVM_DEBUG(dbgs() << "\tExact RDIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcCoeff = " << *SrcCoeff << " = AM\n"); +  LLVM_DEBUG(dbgs() << "\t    DstCoeff = " << *DstCoeff << " = BM\n"); +  LLVM_DEBUG(dbgs() << "\t    SrcConst = " << *SrcConst << "\n"); +  LLVM_DEBUG(dbgs() << "\t    DstConst = " << *DstConst << "\n"); +  ++ExactRDIVapplications; +  Result.Consistent = false; +  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); +  LLVM_DEBUG(dbgs() << "\t    Delta = " << *Delta << "\n"); +  const SCEVConstant *ConstDelta = dyn_cast<SCEVConstant>(Delta); +  const SCEVConstant *ConstSrcCoeff = dyn_cast<SCEVConstant>(SrcCoeff); +  const SCEVConstant *ConstDstCoeff = dyn_cast<SCEVConstant>(DstCoeff); +  if (!ConstDelta || !ConstSrcCoeff || !ConstDstCoeff) +    return false; + +  // find gcd +  APInt G, X, Y; +  APInt AM = ConstSrcCoeff->getAPInt(); +  APInt BM = ConstDstCoeff->getAPInt(); +  unsigned Bits = AM.getBitWidth(); +  if (findGCD(Bits, AM, BM, ConstDelta->getAPInt(), G, X, Y)) { +    // gcd doesn't divide Delta, no dependence +    ++ExactRDIVindependence; +    return true; +  } + +  LLVM_DEBUG(dbgs() << "\t    X = " << X << ", Y = " << Y << "\n"); + +  // since SCEV construction seems to normalize, LM = 0 +  APInt SrcUM(Bits, 1, true); +  bool SrcUMvalid = false; +  // SrcUM is perhaps unavailable, let's check +  if (const SCEVConstant *UpperBound = +      collectConstantUpperBound(SrcLoop, Delta->getType())) { +    SrcUM = UpperBound->getAPInt(); +    LLVM_DEBUG(dbgs() << "\t    SrcUM = " << SrcUM << "\n"); +    SrcUMvalid = true; +  } + +  APInt DstUM(Bits, 1, true); +  bool DstUMvalid = false; +  // UM is perhaps unavailable, let's check +  if (const SCEVConstant *UpperBound = +      collectConstantUpperBound(DstLoop, Delta->getType())) { +    DstUM = UpperBound->getAPInt(); +    LLVM_DEBUG(dbgs() << "\t    DstUM = " << DstUM << "\n"); +    DstUMvalid = true; +  } + +  APInt TU(APInt::getSignedMaxValue(Bits)); +  APInt TL(APInt::getSignedMinValue(Bits)); + +  // test(BM/G, LM-X) and test(-BM/G, X-UM) +  APInt TMUL = BM.sdiv(G); +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(-X, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    if (SrcUMvalid) { +      TU = minAPInt(TU, floorOfQuotient(SrcUM - X, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    } +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(-X, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    if (SrcUMvalid) { +      TL = maxAPInt(TL, ceilingOfQuotient(SrcUM - X, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    } +  } + +  // test(AM/G, LM-Y) and test(-AM/G, Y-UM) +  TMUL = AM.sdiv(G); +  if (TMUL.sgt(0)) { +    TL = maxAPInt(TL, ceilingOfQuotient(-Y, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TL = " << TL << "\n"); +    if (DstUMvalid) { +      TU = minAPInt(TU, floorOfQuotient(DstUM - Y, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    } +  } +  else { +    TU = minAPInt(TU, floorOfQuotient(-Y, TMUL)); +    LLVM_DEBUG(dbgs() << "\t    TU = " << TU << "\n"); +    if (DstUMvalid) { +      TL = maxAPInt(TL, ceilingOfQuotient(DstUM - Y, TMUL)); +      LLVM_DEBUG(dbgs() << "\t    TL = " << TL << 
"\n"); +    } +  } +  if (TL.sgt(TU)) +    ++ExactRDIVindependence; +  return TL.sgt(TU); +} + + +// symbolicRDIVtest - +// In Section 4.5 of the Practical Dependence Testing paper,the authors +// introduce a special case of Banerjee's Inequalities (also called the +// Extreme-Value Test) that can handle some of the SIV and RDIV cases, +// particularly cases with symbolics. Since it's only able to disprove +// dependence (not compute distances or directions), we'll use it as a +// fall back for the other tests. +// +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*j] +// where i and j are induction variables and c1 and c2 are loop invariants, +// we can use the symbolic tests to disprove some dependences, serving as a +// backup for the RDIV test. Note that i and j can be the same variable, +// letting this test serve as a backup for the various SIV tests. +// +// For a dependence to exist, c1 + a1*i must equal c2 + a2*j for some +//  0 <= i <= N1 and some 0 <= j <= N2, where N1 and N2 are the (normalized) +// loop bounds for the i and j loops, respectively. So, ... +// +// c1 + a1*i = c2 + a2*j +// a1*i - a2*j = c2 - c1 +// +// To test for a dependence, we compute c2 - c1 and make sure it's in the +// range of the maximum and minimum possible values of a1*i - a2*j. +// Considering the signs of a1 and a2, we have 4 possible cases: +// +// 1) If a1 >= 0 and a2 >= 0, then +//        a1*0 - a2*N2 <= c2 - c1 <= a1*N1 - a2*0 +//              -a2*N2 <= c2 - c1 <= a1*N1 +// +// 2) If a1 >= 0 and a2 <= 0, then +//        a1*0 - a2*0 <= c2 - c1 <= a1*N1 - a2*N2 +//                  0 <= c2 - c1 <= a1*N1 - a2*N2 +// +// 3) If a1 <= 0 and a2 >= 0, then +//        a1*N1 - a2*N2 <= c2 - c1 <= a1*0 - a2*0 +//        a1*N1 - a2*N2 <= c2 - c1 <= 0 +// +// 4) If a1 <= 0 and a2 <= 0, then +//        a1*N1 - a2*0  <= c2 - c1 <= a1*0 - a2*N2 +//        a1*N1         <= c2 - c1 <=       -a2*N2 +// +// return true if dependence disproved +bool DependenceInfo::symbolicRDIVtest(const SCEV *A1, const SCEV *A2, +                                      const SCEV *C1, const SCEV *C2, +                                      const Loop *Loop1, +                                      const Loop *Loop2) const { +  ++SymbolicRDIVapplications; +  LLVM_DEBUG(dbgs() << "\ttry symbolic RDIV test\n"); +  LLVM_DEBUG(dbgs() << "\t    A1 = " << *A1); +  LLVM_DEBUG(dbgs() << ", type = " << *A1->getType() << "\n"); +  LLVM_DEBUG(dbgs() << "\t    A2 = " << *A2 << "\n"); +  LLVM_DEBUG(dbgs() << "\t    C1 = " << *C1 << "\n"); +  LLVM_DEBUG(dbgs() << "\t    C2 = " << *C2 << "\n"); +  const SCEV *N1 = collectUpperBound(Loop1, A1->getType()); +  const SCEV *N2 = collectUpperBound(Loop2, A1->getType()); +  LLVM_DEBUG(if (N1) dbgs() << "\t    N1 = " << *N1 << "\n"); +  LLVM_DEBUG(if (N2) dbgs() << "\t    N2 = " << *N2 << "\n"); +  const SCEV *C2_C1 = SE->getMinusSCEV(C2, C1); +  const SCEV *C1_C2 = SE->getMinusSCEV(C1, C2); +  LLVM_DEBUG(dbgs() << "\t    C2 - C1 = " << *C2_C1 << "\n"); +  LLVM_DEBUG(dbgs() << "\t    C1 - C2 = " << *C1_C2 << "\n"); +  if (SE->isKnownNonNegative(A1)) { +    if (SE->isKnownNonNegative(A2)) { +      // A1 >= 0 && A2 >= 0 +      if (N1) { +        // make sure that c2 - c1 <= a1*N1 +        const SCEV *A1N1 = SE->getMulExpr(A1, N1); +        LLVM_DEBUG(dbgs() << "\t    A1*N1 = " << *A1N1 << "\n"); +        if (isKnownPredicate(CmpInst::ICMP_SGT, C2_C1, A1N1)) { +          ++SymbolicRDIVindependence; +          return true; +        } +      } +      if (N2) { +        // make sure that 
-a2*N2 <= c2 - c1, or a2*N2 >= c1 - c2 +        const SCEV *A2N2 = SE->getMulExpr(A2, N2); +        LLVM_DEBUG(dbgs() << "\t    A2*N2 = " << *A2N2 << "\n"); +        if (isKnownPredicate(CmpInst::ICMP_SLT, A2N2, C1_C2)) { +          ++SymbolicRDIVindependence; +          return true; +        } +      } +    } +    else if (SE->isKnownNonPositive(A2)) { +      // a1 >= 0 && a2 <= 0 +      if (N1 && N2) { +        // make sure that c2 - c1 <= a1*N1 - a2*N2 +        const SCEV *A1N1 = SE->getMulExpr(A1, N1); +        const SCEV *A2N2 = SE->getMulExpr(A2, N2); +        const SCEV *A1N1_A2N2 = SE->getMinusSCEV(A1N1, A2N2); +        LLVM_DEBUG(dbgs() << "\t    A1*N1 - A2*N2 = " << *A1N1_A2N2 << "\n"); +        if (isKnownPredicate(CmpInst::ICMP_SGT, C2_C1, A1N1_A2N2)) { +          ++SymbolicRDIVindependence; +          return true; +        } +      } +      // make sure that 0 <= c2 - c1 +      if (SE->isKnownNegative(C2_C1)) { +        ++SymbolicRDIVindependence; +        return true; +      } +    } +  } +  else if (SE->isKnownNonPositive(A1)) { +    if (SE->isKnownNonNegative(A2)) { +      // a1 <= 0 && a2 >= 0 +      if (N1 && N2) { +        // make sure that a1*N1 - a2*N2 <= c2 - c1 +        const SCEV *A1N1 = SE->getMulExpr(A1, N1); +        const SCEV *A2N2 = SE->getMulExpr(A2, N2); +        const SCEV *A1N1_A2N2 = SE->getMinusSCEV(A1N1, A2N2); +        LLVM_DEBUG(dbgs() << "\t    A1*N1 - A2*N2 = " << *A1N1_A2N2 << "\n"); +        if (isKnownPredicate(CmpInst::ICMP_SGT, A1N1_A2N2, C2_C1)) { +          ++SymbolicRDIVindependence; +          return true; +        } +      } +      // make sure that c2 - c1 <= 0 +      if (SE->isKnownPositive(C2_C1)) { +        ++SymbolicRDIVindependence; +        return true; +      } +    } +    else if (SE->isKnownNonPositive(A2)) { +      // a1 <= 0 && a2 <= 0 +      if (N1) { +        // make sure that a1*N1 <= c2 - c1 +        const SCEV *A1N1 = SE->getMulExpr(A1, N1); +        LLVM_DEBUG(dbgs() << "\t    A1*N1 = " << *A1N1 << "\n"); +        if (isKnownPredicate(CmpInst::ICMP_SGT, A1N1, C2_C1)) { +          ++SymbolicRDIVindependence; +          return true; +        } +      } +      if (N2) { +        // make sure that c2 - c1 <= -a2*N2, or c1 - c2 >= a2*N2 +        const SCEV *A2N2 = SE->getMulExpr(A2, N2); +        LLVM_DEBUG(dbgs() << "\t    A2*N2 = " << *A2N2 << "\n"); +        if (isKnownPredicate(CmpInst::ICMP_SLT, C1_C2, A2N2)) { +          ++SymbolicRDIVindependence; +          return true; +        } +      } +    } +  } +  return false; +} + + +// testSIV - +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 - a2*i] +// where i is an induction variable, c1 and c2 are loop invariant, and a1 and +// a2 are constant, we attack it with an SIV test. While they can all be +// solved with the Exact SIV test, it's worthwhile to use simpler tests when +// they apply; they're cheaper and sometimes more precise. +// +// Return true if dependence disproved. 
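When the source and destination coefficients are equal, testSIV below dispatches to the Strong SIV test. As a minimal, freestanding sketch of the arithmetic that case reduces to when the coefficient, both constants, and the loop's normalized upper bound are known integers (the name and signature here are invented for illustration, not code from this file):

#include <cstdlib>

// Subscripts c1 + a*i (source) and c2 + a*i' (destination), with a != 0 and
// 0 <= i, i' <= UB. A dependence needs a*(i' - i) == c1 - c2, so the distance
// (c1 - c2)/a must be an integer no larger in magnitude than UB.
static bool strongSIVDisproved(long a, long c1, long c2, long UB) {
  long Delta = c1 - c2;
  if (Delta % a != 0)
    return true;                     // a does not divide the difference
  return std::labs(Delta / a) > UB;  // distance exceeds the iteration space
}

For example, A[2*i + 1] versus A[2*i + 4] gives Delta = -3, which 2 does not divide, so the pair is independent.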
+bool DependenceInfo::testSIV(const SCEV *Src, const SCEV *Dst, unsigned &Level, +                             FullDependence &Result, Constraint &NewConstraint, +                             const SCEV *&SplitIter) const { +  LLVM_DEBUG(dbgs() << "    src = " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "    dst = " << *Dst << "\n"); +  const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src); +  const SCEVAddRecExpr *DstAddRec = dyn_cast<SCEVAddRecExpr>(Dst); +  if (SrcAddRec && DstAddRec) { +    const SCEV *SrcConst = SrcAddRec->getStart(); +    const SCEV *DstConst = DstAddRec->getStart(); +    const SCEV *SrcCoeff = SrcAddRec->getStepRecurrence(*SE); +    const SCEV *DstCoeff = DstAddRec->getStepRecurrence(*SE); +    const Loop *CurLoop = SrcAddRec->getLoop(); +    assert(CurLoop == DstAddRec->getLoop() && +           "both loops in SIV should be same"); +    Level = mapSrcLoop(CurLoop); +    bool disproven; +    if (SrcCoeff == DstCoeff) +      disproven = strongSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop, +                                Level, Result, NewConstraint); +    else if (SrcCoeff == SE->getNegativeSCEV(DstCoeff)) +      disproven = weakCrossingSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop, +                                      Level, Result, NewConstraint, SplitIter); +    else +      disproven = exactSIVtest(SrcCoeff, DstCoeff, SrcConst, DstConst, CurLoop, +                               Level, Result, NewConstraint); +    return disproven || +      gcdMIVtest(Src, Dst, Result) || +      symbolicRDIVtest(SrcCoeff, DstCoeff, SrcConst, DstConst, CurLoop, CurLoop); +  } +  if (SrcAddRec) { +    const SCEV *SrcConst = SrcAddRec->getStart(); +    const SCEV *SrcCoeff = SrcAddRec->getStepRecurrence(*SE); +    const SCEV *DstConst = Dst; +    const Loop *CurLoop = SrcAddRec->getLoop(); +    Level = mapSrcLoop(CurLoop); +    return weakZeroDstSIVtest(SrcCoeff, SrcConst, DstConst, CurLoop, +                              Level, Result, NewConstraint) || +      gcdMIVtest(Src, Dst, Result); +  } +  if (DstAddRec) { +    const SCEV *DstConst = DstAddRec->getStart(); +    const SCEV *DstCoeff = DstAddRec->getStepRecurrence(*SE); +    const SCEV *SrcConst = Src; +    const Loop *CurLoop = DstAddRec->getLoop(); +    Level = mapDstLoop(CurLoop); +    return weakZeroSrcSIVtest(DstCoeff, SrcConst, DstConst, +                              CurLoop, Level, Result, NewConstraint) || +      gcdMIVtest(Src, Dst, Result); +  } +  llvm_unreachable("SIV test expected at least one AddRec"); +  return false; +} + + +// testRDIV - +// When we have a pair of subscripts of the form [c1 + a1*i] and [c2 + a2*j] +// where i and j are induction variables, c1 and c2 are loop invariant, +// and a1 and a2 are constant, we can solve it exactly with an easy adaptation +// of the Exact SIV test, the Restricted Double Index Variable (RDIV) test. +// It doesn't make sense to talk about distance or direction in this case, +// so there's no point in making special versions of the Strong SIV test or +// the Weak-crossing SIV test. +// +// With minor algebra, this test can also be used for things like +// [c1 + a1*i + a2*j][c2]. +// +// Return true if dependence disproved. 
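The exact RDIV test above is built around one plain integer fact that is easy to lose in the SCEV plumbing: ignoring the loop bounds, a1*i - a2*j = c2 - c1 has integer solutions only if gcd(|a1|, |a2|) divides c2 - c1. A freestanding sketch (names invented here, not part of this file):

#include <cstdlib>
#include <numeric>

// Subscripts c1 + a1*i and c2 + a2*j with i and j in different loops.
// Independence is proved outright when the coefficient gcd does not divide
// the constant difference; the loop bounds only matter once this check passes.
static bool rdivGCDDisproved(long a1, long a2, long c1, long c2) {
  long G = std::gcd(std::labs(a1), std::labs(a2));
  return G != 0 && (c2 - c1) % G != 0;
}

exactRDIVtest goes further: findGCD also yields a particular solution (X, Y), which is then clipped against the two loops' upper bounds.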
+bool DependenceInfo::testRDIV(const SCEV *Src, const SCEV *Dst, +                              FullDependence &Result) const { +  // we have 3 possible situations here: +  //   1) [a*i + b] and [c*j + d] +  //   2) [a*i + c*j + b] and [d] +  //   3) [b] and [a*i + c*j + d] +  // We need to find what we've got and get organized + +  const SCEV *SrcConst, *DstConst; +  const SCEV *SrcCoeff, *DstCoeff; +  const Loop *SrcLoop, *DstLoop; + +  LLVM_DEBUG(dbgs() << "    src = " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "    dst = " << *Dst << "\n"); +  const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src); +  const SCEVAddRecExpr *DstAddRec = dyn_cast<SCEVAddRecExpr>(Dst); +  if (SrcAddRec && DstAddRec) { +    SrcConst = SrcAddRec->getStart(); +    SrcCoeff = SrcAddRec->getStepRecurrence(*SE); +    SrcLoop = SrcAddRec->getLoop(); +    DstConst = DstAddRec->getStart(); +    DstCoeff = DstAddRec->getStepRecurrence(*SE); +    DstLoop = DstAddRec->getLoop(); +  } +  else if (SrcAddRec) { +    if (const SCEVAddRecExpr *tmpAddRec = +        dyn_cast<SCEVAddRecExpr>(SrcAddRec->getStart())) { +      SrcConst = tmpAddRec->getStart(); +      SrcCoeff = tmpAddRec->getStepRecurrence(*SE); +      SrcLoop = tmpAddRec->getLoop(); +      DstConst = Dst; +      DstCoeff = SE->getNegativeSCEV(SrcAddRec->getStepRecurrence(*SE)); +      DstLoop = SrcAddRec->getLoop(); +    } +    else +      llvm_unreachable("RDIV reached by surprising SCEVs"); +  } +  else if (DstAddRec) { +    if (const SCEVAddRecExpr *tmpAddRec = +        dyn_cast<SCEVAddRecExpr>(DstAddRec->getStart())) { +      DstConst = tmpAddRec->getStart(); +      DstCoeff = tmpAddRec->getStepRecurrence(*SE); +      DstLoop = tmpAddRec->getLoop(); +      SrcConst = Src; +      SrcCoeff = SE->getNegativeSCEV(DstAddRec->getStepRecurrence(*SE)); +      SrcLoop = DstAddRec->getLoop(); +    } +    else +      llvm_unreachable("RDIV reached by surprising SCEVs"); +  } +  else +    llvm_unreachable("RDIV expected at least one AddRec"); +  return exactRDIVtest(SrcCoeff, DstCoeff, +                       SrcConst, DstConst, +                       SrcLoop, DstLoop, +                       Result) || +    gcdMIVtest(Src, Dst, Result) || +    symbolicRDIVtest(SrcCoeff, DstCoeff, +                     SrcConst, DstConst, +                     SrcLoop, DstLoop); +} + + +// Tests the single-subscript MIV pair (Src and Dst) for dependence. +// Return true if dependence disproved. +// Can sometimes refine direction vectors. +bool DependenceInfo::testMIV(const SCEV *Src, const SCEV *Dst, +                             const SmallBitVector &Loops, +                             FullDependence &Result) const { +  LLVM_DEBUG(dbgs() << "    src = " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "    dst = " << *Dst << "\n"); +  Result.Consistent = false; +  return gcdMIVtest(Src, Dst, Result) || +    banerjeeMIVtest(Src, Dst, Loops, Result); +} + + +// Given a product, e.g., 10*X*Y, returns the first constant operand, +// in this case 10. If there is no constant part, returns NULL. +static +const SCEVConstant *getConstantPart(const SCEV *Expr) { +  if (const auto *Constant = dyn_cast<SCEVConstant>(Expr)) +    return Constant; +  else if (const auto *Product = dyn_cast<SCEVMulExpr>(Expr)) +    if (const auto *Constant = dyn_cast<SCEVConstant>(Product->getOperand(0))) +      return Constant; +  return nullptr; +} + + +//===----------------------------------------------------------------------===// +// gcdMIVtest - +// Tests an MIV subscript pair for dependence. 
+// Returns true if any possible dependence is disproved. +// Marks the result as inconsistent. +// Can sometimes disprove the equal direction for 1 or more loops, +// as discussed in Michael Wolfe's book, +// High Performance Compilers for Parallel Computing, page 235. +// +// We spend some effort (code!) to handle cases like +// [10*i + 5*N*j + 15*M + 6], where i and j are induction variables, +// but M and N are just loop-invariant variables. +// This should help us handle linearized subscripts; +// also makes this test a useful backup to the various SIV tests. +// +// It occurs to me that the presence of loop-invariant variables +// changes the nature of the test from "greatest common divisor" +// to "a common divisor". +bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst, +                                FullDependence &Result) const { +  LLVM_DEBUG(dbgs() << "starting gcd\n"); +  ++GCDapplications; +  unsigned BitWidth = SE->getTypeSizeInBits(Src->getType()); +  APInt RunningGCD = APInt::getNullValue(BitWidth); + +  // Examine Src coefficients. +  // Compute running GCD and record source constant. +  // Because we're looking for the constant at the end of the chain, +  // we can't quit the loop just because the GCD == 1. +  const SCEV *Coefficients = Src; +  while (const SCEVAddRecExpr *AddRec = +         dyn_cast<SCEVAddRecExpr>(Coefficients)) { +    const SCEV *Coeff = AddRec->getStepRecurrence(*SE); +    // If the coefficient is the product of a constant and other stuff, +    // we can use the constant in the GCD computation. +    const auto *Constant = getConstantPart(Coeff); +    if (!Constant) +      return false; +    APInt ConstCoeff = Constant->getAPInt(); +    RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); +    Coefficients = AddRec->getStart(); +  } +  const SCEV *SrcConst = Coefficients; + +  // Examine Dst coefficients. +  // Compute running GCD and record destination constant. +  // Because we're looking for the constant at the end of the chain, +  // we can't quit the loop just because the GCD == 1. +  Coefficients = Dst; +  while (const SCEVAddRecExpr *AddRec = +         dyn_cast<SCEVAddRecExpr>(Coefficients)) { +    const SCEV *Coeff = AddRec->getStepRecurrence(*SE); +    // If the coefficient is the product of a constant and other stuff, +    // we can use the constant in the GCD computation. +    const auto *Constant = getConstantPart(Coeff); +    if (!Constant) +      return false; +    APInt ConstCoeff = Constant->getAPInt(); +    RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); +    Coefficients = AddRec->getStart(); +  } +  const SCEV *DstConst = Coefficients; + +  APInt ExtraGCD = APInt::getNullValue(BitWidth); +  const SCEV *Delta = SE->getMinusSCEV(DstConst, SrcConst); +  LLVM_DEBUG(dbgs() << "    Delta = " << *Delta << "\n"); +  const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Delta); +  if (const SCEVAddExpr *Sum = dyn_cast<SCEVAddExpr>(Delta)) { +    // If Delta is a sum of products, we may be able to make further progress. 
+    for (unsigned Op = 0, Ops = Sum->getNumOperands(); Op < Ops; Op++) {
+      const SCEV *Operand = Sum->getOperand(Op);
+      if (isa<SCEVConstant>(Operand)) {
+        assert(!Constant && "Surprised to find multiple constants");
+        Constant = cast<SCEVConstant>(Operand);
+      }
+      else if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Operand)) {
+        // Search for constant operand to participate in GCD;
+        // If none found, return false.
+        const SCEVConstant *ConstOp = getConstantPart(Product);
+        if (!ConstOp)
+          return false;
+        APInt ConstOpValue = ConstOp->getAPInt();
+        ExtraGCD = APIntOps::GreatestCommonDivisor(ExtraGCD,
+                                                   ConstOpValue.abs());
+      }
+      else
+        return false;
+    }
+  }
+  if (!Constant)
+    return false;
+  APInt ConstDelta = cast<SCEVConstant>(Constant)->getAPInt();
+  LLVM_DEBUG(dbgs() << "    ConstDelta = " << ConstDelta << "\n");
+  if (ConstDelta == 0)
+    return false;
+  RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ExtraGCD);
+  LLVM_DEBUG(dbgs() << "    RunningGCD = " << RunningGCD << "\n");
+  APInt Remainder = ConstDelta.srem(RunningGCD);
+  if (Remainder != 0) {
+    ++GCDindependence;
+    return true;
+  }
+
+  // Try to disprove equal directions.
+  // For example, given a subscript pair [3*i + 2*j] and [i' + 2*j' - 1],
+  // the code above can't disprove the dependence because the GCD = 1.
+  // So we consider what happens if i = i' and what happens if j = j'.
+  // If i = i', we can simplify the subscript to [2*i + 2*j] and [2*j' - 1],
+  // which is infeasible, so we can disallow the = direction for the i level.
+  // Setting j = j' doesn't help matters, so we end up with a direction vector
+  // of [<>, *]
+  //
+  // Given A[5*i + 10*j*M + 9*M*N] and A[15*i + 20*j*M - 21*N*M + 5],
+  // we need to remember that the constant part is 5 and the RunningGCD should
+  // be initialized to ExtraGCD = 30.
+  LLVM_DEBUG(dbgs() << "    ExtraGCD = " << ExtraGCD << '\n');
+
+  bool Improved = false;
+  Coefficients = Src;
+  while (const SCEVAddRecExpr *AddRec =
+         dyn_cast<SCEVAddRecExpr>(Coefficients)) {
+    Coefficients = AddRec->getStart();
+    const Loop *CurLoop = AddRec->getLoop();
+    RunningGCD = ExtraGCD;
+    const SCEV *SrcCoeff = AddRec->getStepRecurrence(*SE);
+    const SCEV *DstCoeff = SE->getMinusSCEV(SrcCoeff, SrcCoeff);
+    const SCEV *Inner = Src;
+    while (RunningGCD != 1 && isa<SCEVAddRecExpr>(Inner)) {
+      AddRec = cast<SCEVAddRecExpr>(Inner);
+      const SCEV *Coeff = AddRec->getStepRecurrence(*SE);
+      if (CurLoop == AddRec->getLoop())
+        ; // SrcCoeff == Coeff
+      else {
+        // If the coefficient is the product of a constant and other stuff,
+        // we can use the constant in the GCD computation. 
+        Constant = getConstantPart(Coeff); +        if (!Constant) +          return false; +        APInt ConstCoeff = Constant->getAPInt(); +        RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); +      } +      Inner = AddRec->getStart(); +    } +    Inner = Dst; +    while (RunningGCD != 1 && isa<SCEVAddRecExpr>(Inner)) { +      AddRec = cast<SCEVAddRecExpr>(Inner); +      const SCEV *Coeff = AddRec->getStepRecurrence(*SE); +      if (CurLoop == AddRec->getLoop()) +        DstCoeff = Coeff; +      else { +        // If the coefficient is the product of a constant and other stuff, +        // we can use the constant in the GCD computation. +        Constant = getConstantPart(Coeff); +        if (!Constant) +          return false; +        APInt ConstCoeff = Constant->getAPInt(); +        RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); +      } +      Inner = AddRec->getStart(); +    } +    Delta = SE->getMinusSCEV(SrcCoeff, DstCoeff); +    // If the coefficient is the product of a constant and other stuff, +    // we can use the constant in the GCD computation. +    Constant = getConstantPart(Delta); +    if (!Constant) +      // The difference of the two coefficients might not be a product +      // or constant, in which case we give up on this direction. +      continue; +    APInt ConstCoeff = Constant->getAPInt(); +    RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); +    LLVM_DEBUG(dbgs() << "\tRunningGCD = " << RunningGCD << "\n"); +    if (RunningGCD != 0) { +      Remainder = ConstDelta.srem(RunningGCD); +      LLVM_DEBUG(dbgs() << "\tRemainder = " << Remainder << "\n"); +      if (Remainder != 0) { +        unsigned Level = mapSrcLoop(CurLoop); +        Result.DV[Level - 1].Direction &= unsigned(~Dependence::DVEntry::EQ); +        Improved = true; +      } +    } +  } +  if (Improved) +    ++GCDsuccesses; +  LLVM_DEBUG(dbgs() << "all done\n"); +  return false; +} + + +//===----------------------------------------------------------------------===// +// banerjeeMIVtest - +// Use Banerjee's Inequalities to test an MIV subscript pair. +// (Wolfe, in the race-car book, calls this the Extreme Value Test.) +// Generally follows the discussion in Section 2.5.2 of +// +//    Optimizing Supercompilers for Supercomputers +//    Michael Wolfe +// +// The inequalities given on page 25 are simplified in that loops are +// normalized so that the lower bound is always 0 and the stride is always 1. +// For example, Wolfe gives +// +//     LB^<_k = (A^-_k - B_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k +// +// where A_k is the coefficient of the kth index in the source subscript, +// B_k is the coefficient of the kth index in the destination subscript, +// U_k is the upper bound of the kth index, L_k is the lower bound of the Kth +// index, and N_k is the stride of the kth index. Since all loops are normalized +// by the SCEV package, N_k = 1 and L_k = 0, allowing us to simplify the +// equation to +// +//     LB^<_k = (A^-_k - B_k)^- (U_k - 0 - 1) + (A_k - B_k)0 - B_k 1 +//            = (A^-_k - B_k)^- (U_k - 1)  - B_k +// +// Similar simplifications are possible for the other equations. +// +// When we can't determine the number of iterations for a loop, +// we use NULL as an indicator for the worst case, infinity. +// When computing the upper bound, NULL denotes +inf; +// for the lower bound, NULL denotes -inf. +// +// Return true if dependence disproved. 
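Stripped of SCEVs, the ±infinity convention, and the per-direction refinement, the core feasibility check performed below looks like this (types and names invented for illustration): sum the per-level bounds and ask whether the constant difference Delta = B0 - A0 can fall inside that range.

#include <vector>

struct LevelBound { long Lower, Upper; }; // one loop level, one direction

// Banerjee-style plausibility: if Delta cannot lie inside [LB, UB], every
// dependence with the chosen directions is disproved.
static bool banerjeePlausible(const std::vector<LevelBound> &Bounds, long Delta) {
  long LB = 0, UB = 0;
  for (const LevelBound &B : Bounds) {
    LB += B.Lower;
    UB += B.Upper;
  }
  return LB <= Delta && Delta <= UB;
}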
+bool DependenceInfo::banerjeeMIVtest(const SCEV *Src, const SCEV *Dst, +                                     const SmallBitVector &Loops, +                                     FullDependence &Result) const { +  LLVM_DEBUG(dbgs() << "starting Banerjee\n"); +  ++BanerjeeApplications; +  LLVM_DEBUG(dbgs() << "    Src = " << *Src << '\n'); +  const SCEV *A0; +  CoefficientInfo *A = collectCoeffInfo(Src, true, A0); +  LLVM_DEBUG(dbgs() << "    Dst = " << *Dst << '\n'); +  const SCEV *B0; +  CoefficientInfo *B = collectCoeffInfo(Dst, false, B0); +  BoundInfo *Bound = new BoundInfo[MaxLevels + 1]; +  const SCEV *Delta = SE->getMinusSCEV(B0, A0); +  LLVM_DEBUG(dbgs() << "\tDelta = " << *Delta << '\n'); + +  // Compute bounds for all the * directions. +  LLVM_DEBUG(dbgs() << "\tBounds[*]\n"); +  for (unsigned K = 1; K <= MaxLevels; ++K) { +    Bound[K].Iterations = A[K].Iterations ? A[K].Iterations : B[K].Iterations; +    Bound[K].Direction = Dependence::DVEntry::ALL; +    Bound[K].DirSet = Dependence::DVEntry::NONE; +    findBoundsALL(A, B, Bound, K); +#ifndef NDEBUG +    LLVM_DEBUG(dbgs() << "\t    " << K << '\t'); +    if (Bound[K].Lower[Dependence::DVEntry::ALL]) +      LLVM_DEBUG(dbgs() << *Bound[K].Lower[Dependence::DVEntry::ALL] << '\t'); +    else +      LLVM_DEBUG(dbgs() << "-inf\t"); +    if (Bound[K].Upper[Dependence::DVEntry::ALL]) +      LLVM_DEBUG(dbgs() << *Bound[K].Upper[Dependence::DVEntry::ALL] << '\n'); +    else +      LLVM_DEBUG(dbgs() << "+inf\n"); +#endif +  } + +  // Test the *, *, *, ... case. +  bool Disproved = false; +  if (testBounds(Dependence::DVEntry::ALL, 0, Bound, Delta)) { +    // Explore the direction vector hierarchy. +    unsigned DepthExpanded = 0; +    unsigned NewDeps = exploreDirections(1, A, B, Bound, +                                         Loops, DepthExpanded, Delta); +    if (NewDeps > 0) { +      bool Improved = false; +      for (unsigned K = 1; K <= CommonLevels; ++K) { +        if (Loops[K]) { +          unsigned Old = Result.DV[K - 1].Direction; +          Result.DV[K - 1].Direction = Old & Bound[K].DirSet; +          Improved |= Old != Result.DV[K - 1].Direction; +          if (!Result.DV[K - 1].Direction) { +            Improved = false; +            Disproved = true; +            break; +          } +        } +      } +      if (Improved) +        ++BanerjeeSuccesses; +    } +    else { +      ++BanerjeeIndependence; +      Disproved = true; +    } +  } +  else { +    ++BanerjeeIndependence; +    Disproved = true; +  } +  delete [] Bound; +  delete [] A; +  delete [] B; +  return Disproved; +} + + +// Hierarchically expands the direction vector +// search space, combining the directions of discovered dependences +// in the DirSet field of Bound. Returns the number of distinct +// dependences discovered. If the dependence is disproved, +// it will return 0. 
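The shape of the recursion below can be seen in a small standalone sketch (the plausibility predicate and everything else here are stand-ins for the real BoundInfo/testBounds machinery): each level tries <, =, and > in turn, recurses into the next level, and only complete direction vectors that stay plausible at every level are counted.

#include <functional>
#include <initializer_list>

enum Dir { LT, EQ, GT };

// Count direction vectors of length `Levels` whose every prefix passes the
// per-level plausibility check. A count of zero means the dependence is
// disproved for all directions.
static unsigned countFeasible(unsigned Level, unsigned Levels,
                              const std::function<bool(unsigned, Dir)> &Plausible) {
  if (Level == Levels)
    return 1; // reached a complete, plausible direction vector
  unsigned N = 0;
  for (Dir D : {LT, EQ, GT})
    if (Plausible(Level, D))
      N += countFeasible(Level + 1, Levels, Plausible);
  return N;
}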
+unsigned DependenceInfo::exploreDirections(unsigned Level, CoefficientInfo *A, +                                           CoefficientInfo *B, BoundInfo *Bound, +                                           const SmallBitVector &Loops, +                                           unsigned &DepthExpanded, +                                           const SCEV *Delta) const { +  if (Level > CommonLevels) { +    // record result +    LLVM_DEBUG(dbgs() << "\t["); +    for (unsigned K = 1; K <= CommonLevels; ++K) { +      if (Loops[K]) { +        Bound[K].DirSet |= Bound[K].Direction; +#ifndef NDEBUG +        switch (Bound[K].Direction) { +        case Dependence::DVEntry::LT: +          LLVM_DEBUG(dbgs() << " <"); +          break; +        case Dependence::DVEntry::EQ: +          LLVM_DEBUG(dbgs() << " ="); +          break; +        case Dependence::DVEntry::GT: +          LLVM_DEBUG(dbgs() << " >"); +          break; +        case Dependence::DVEntry::ALL: +          LLVM_DEBUG(dbgs() << " *"); +          break; +        default: +          llvm_unreachable("unexpected Bound[K].Direction"); +        } +#endif +      } +    } +    LLVM_DEBUG(dbgs() << " ]\n"); +    return 1; +  } +  if (Loops[Level]) { +    if (Level > DepthExpanded) { +      DepthExpanded = Level; +      // compute bounds for <, =, > at current level +      findBoundsLT(A, B, Bound, Level); +      findBoundsGT(A, B, Bound, Level); +      findBoundsEQ(A, B, Bound, Level); +#ifndef NDEBUG +      LLVM_DEBUG(dbgs() << "\tBound for level = " << Level << '\n'); +      LLVM_DEBUG(dbgs() << "\t    <\t"); +      if (Bound[Level].Lower[Dependence::DVEntry::LT]) +        LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::LT] +                          << '\t'); +      else +        LLVM_DEBUG(dbgs() << "-inf\t"); +      if (Bound[Level].Upper[Dependence::DVEntry::LT]) +        LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::LT] +                          << '\n'); +      else +        LLVM_DEBUG(dbgs() << "+inf\n"); +      LLVM_DEBUG(dbgs() << "\t    =\t"); +      if (Bound[Level].Lower[Dependence::DVEntry::EQ]) +        LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::EQ] +                          << '\t'); +      else +        LLVM_DEBUG(dbgs() << "-inf\t"); +      if (Bound[Level].Upper[Dependence::DVEntry::EQ]) +        LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::EQ] +                          << '\n'); +      else +        LLVM_DEBUG(dbgs() << "+inf\n"); +      LLVM_DEBUG(dbgs() << "\t    >\t"); +      if (Bound[Level].Lower[Dependence::DVEntry::GT]) +        LLVM_DEBUG(dbgs() << *Bound[Level].Lower[Dependence::DVEntry::GT] +                          << '\t'); +      else +        LLVM_DEBUG(dbgs() << "-inf\t"); +      if (Bound[Level].Upper[Dependence::DVEntry::GT]) +        LLVM_DEBUG(dbgs() << *Bound[Level].Upper[Dependence::DVEntry::GT] +                          << '\n'); +      else +        LLVM_DEBUG(dbgs() << "+inf\n"); +#endif +    } + +    unsigned NewDeps = 0; + +    // test bounds for <, *, *, ... +    if (testBounds(Dependence::DVEntry::LT, Level, Bound, Delta)) +      NewDeps += exploreDirections(Level + 1, A, B, Bound, +                                   Loops, DepthExpanded, Delta); + +    // Test bounds for =, *, *, ... +    if (testBounds(Dependence::DVEntry::EQ, Level, Bound, Delta)) +      NewDeps += exploreDirections(Level + 1, A, B, Bound, +                                   Loops, DepthExpanded, Delta); + +    // test bounds for >, *, *, ... 
+    if (testBounds(Dependence::DVEntry::GT, Level, Bound, Delta)) +      NewDeps += exploreDirections(Level + 1, A, B, Bound, +                                   Loops, DepthExpanded, Delta); + +    Bound[Level].Direction = Dependence::DVEntry::ALL; +    return NewDeps; +  } +  else +    return exploreDirections(Level + 1, A, B, Bound, Loops, DepthExpanded, Delta); +} + + +// Returns true iff the current bounds are plausible. +bool DependenceInfo::testBounds(unsigned char DirKind, unsigned Level, +                                BoundInfo *Bound, const SCEV *Delta) const { +  Bound[Level].Direction = DirKind; +  if (const SCEV *LowerBound = getLowerBound(Bound)) +    if (isKnownPredicate(CmpInst::ICMP_SGT, LowerBound, Delta)) +      return false; +  if (const SCEV *UpperBound = getUpperBound(Bound)) +    if (isKnownPredicate(CmpInst::ICMP_SGT, Delta, UpperBound)) +      return false; +  return true; +} + + +// Computes the upper and lower bounds for level K +// using the * direction. Records them in Bound. +// Wolfe gives the equations +// +//    LB^*_k = (A^-_k - B^+_k)(U_k - L_k) + (A_k - B_k)L_k +//    UB^*_k = (A^+_k - B^-_k)(U_k - L_k) + (A_k - B_k)L_k +// +// Since we normalize loops, we can simplify these equations to +// +//    LB^*_k = (A^-_k - B^+_k)U_k +//    UB^*_k = (A^+_k - B^-_k)U_k +// +// We must be careful to handle the case where the upper bound is unknown. +// Note that the lower bound is always <= 0 +// and the upper bound is always >= 0. +void DependenceInfo::findBoundsALL(CoefficientInfo *A, CoefficientInfo *B, +                                   BoundInfo *Bound, unsigned K) const { +  Bound[K].Lower[Dependence::DVEntry::ALL] = nullptr; // Default value = -infinity. +  Bound[K].Upper[Dependence::DVEntry::ALL] = nullptr; // Default value = +infinity. +  if (Bound[K].Iterations) { +    Bound[K].Lower[Dependence::DVEntry::ALL] = +      SE->getMulExpr(SE->getMinusSCEV(A[K].NegPart, B[K].PosPart), +                     Bound[K].Iterations); +    Bound[K].Upper[Dependence::DVEntry::ALL] = +      SE->getMulExpr(SE->getMinusSCEV(A[K].PosPart, B[K].NegPart), +                     Bound[K].Iterations); +  } +  else { +    // If the difference is 0, we won't need to know the number of iterations. +    if (isKnownPredicate(CmpInst::ICMP_EQ, A[K].NegPart, B[K].PosPart)) +      Bound[K].Lower[Dependence::DVEntry::ALL] = +          SE->getZero(A[K].Coeff->getType()); +    if (isKnownPredicate(CmpInst::ICMP_EQ, A[K].PosPart, B[K].NegPart)) +      Bound[K].Upper[Dependence::DVEntry::ALL] = +          SE->getZero(A[K].Coeff->getType()); +  } +} + + +// Computes the upper and lower bounds for level K +// using the = direction. Records them in Bound. +// Wolfe gives the equations +// +//    LB^=_k = (A_k - B_k)^- (U_k - L_k) + (A_k - B_k)L_k +//    UB^=_k = (A_k - B_k)^+ (U_k - L_k) + (A_k - B_k)L_k +// +// Since we normalize loops, we can simplify these equations to +// +//    LB^=_k = (A_k - B_k)^- U_k +//    UB^=_k = (A_k - B_k)^+ U_k +// +// We must be careful to handle the case where the upper bound is unknown. +// Note that the lower bound is always <= 0 +// and the upper bound is always >= 0. +void DependenceInfo::findBoundsEQ(CoefficientInfo *A, CoefficientInfo *B, +                                  BoundInfo *Bound, unsigned K) const { +  Bound[K].Lower[Dependence::DVEntry::EQ] = nullptr; // Default value = -infinity. +  Bound[K].Upper[Dependence::DVEntry::EQ] = nullptr; // Default value = +infinity. 
+  if (Bound[K].Iterations) { +    const SCEV *Delta = SE->getMinusSCEV(A[K].Coeff, B[K].Coeff); +    const SCEV *NegativePart = getNegativePart(Delta); +    Bound[K].Lower[Dependence::DVEntry::EQ] = +      SE->getMulExpr(NegativePart, Bound[K].Iterations); +    const SCEV *PositivePart = getPositivePart(Delta); +    Bound[K].Upper[Dependence::DVEntry::EQ] = +      SE->getMulExpr(PositivePart, Bound[K].Iterations); +  } +  else { +    // If the positive/negative part of the difference is 0, +    // we won't need to know the number of iterations. +    const SCEV *Delta = SE->getMinusSCEV(A[K].Coeff, B[K].Coeff); +    const SCEV *NegativePart = getNegativePart(Delta); +    if (NegativePart->isZero()) +      Bound[K].Lower[Dependence::DVEntry::EQ] = NegativePart; // Zero +    const SCEV *PositivePart = getPositivePart(Delta); +    if (PositivePart->isZero()) +      Bound[K].Upper[Dependence::DVEntry::EQ] = PositivePart; // Zero +  } +} + + +// Computes the upper and lower bounds for level K +// using the < direction. Records them in Bound. +// Wolfe gives the equations +// +//    LB^<_k = (A^-_k - B_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k +//    UB^<_k = (A^+_k - B_k)^+ (U_k - L_k - N_k) + (A_k - B_k)L_k - B_k N_k +// +// Since we normalize loops, we can simplify these equations to +// +//    LB^<_k = (A^-_k - B_k)^- (U_k - 1) - B_k +//    UB^<_k = (A^+_k - B_k)^+ (U_k - 1) - B_k +// +// We must be careful to handle the case where the upper bound is unknown. +void DependenceInfo::findBoundsLT(CoefficientInfo *A, CoefficientInfo *B, +                                  BoundInfo *Bound, unsigned K) const { +  Bound[K].Lower[Dependence::DVEntry::LT] = nullptr; // Default value = -infinity. +  Bound[K].Upper[Dependence::DVEntry::LT] = nullptr; // Default value = +infinity. +  if (Bound[K].Iterations) { +    const SCEV *Iter_1 = SE->getMinusSCEV( +        Bound[K].Iterations, SE->getOne(Bound[K].Iterations->getType())); +    const SCEV *NegPart = +      getNegativePart(SE->getMinusSCEV(A[K].NegPart, B[K].Coeff)); +    Bound[K].Lower[Dependence::DVEntry::LT] = +      SE->getMinusSCEV(SE->getMulExpr(NegPart, Iter_1), B[K].Coeff); +    const SCEV *PosPart = +      getPositivePart(SE->getMinusSCEV(A[K].PosPart, B[K].Coeff)); +    Bound[K].Upper[Dependence::DVEntry::LT] = +      SE->getMinusSCEV(SE->getMulExpr(PosPart, Iter_1), B[K].Coeff); +  } +  else { +    // If the positive/negative part of the difference is 0, +    // we won't need to know the number of iterations. +    const SCEV *NegPart = +      getNegativePart(SE->getMinusSCEV(A[K].NegPart, B[K].Coeff)); +    if (NegPart->isZero()) +      Bound[K].Lower[Dependence::DVEntry::LT] = SE->getNegativeSCEV(B[K].Coeff); +    const SCEV *PosPart = +      getPositivePart(SE->getMinusSCEV(A[K].PosPart, B[K].Coeff)); +    if (PosPart->isZero()) +      Bound[K].Upper[Dependence::DVEntry::LT] = SE->getNegativeSCEV(B[K].Coeff); +  } +} + + +// Computes the upper and lower bounds for level K +// using the > direction. Records them in Bound. +// Wolfe gives the equations +// +//    LB^>_k = (A_k - B^+_k)^- (U_k - L_k - N_k) + (A_k - B_k)L_k + A_k N_k +//    UB^>_k = (A_k - B^-_k)^+ (U_k - L_k - N_k) + (A_k - B_k)L_k + A_k N_k +// +// Since we normalize loops, we can simplify these equations to +// +//    LB^>_k = (A_k - B^+_k)^- (U_k - 1) + A_k +//    UB^>_k = (A_k - B^-_k)^+ (U_k - 1) + A_k +// +// We must be careful to handle the case where the upper bound is unknown. 
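In scalar form (plain integers, invented helper names, and x^+ = max(x, 0), x^- = min(x, 0) as implemented by getPositivePart/getNegativePart further below), the two simplified formulas above work out to:

#include <algorithm>
#include <utility>

static long pos(long X) { return std::max(X, 0L); }
static long neg(long X) { return std::min(X, 0L); }

// Bounds for the '>' direction at one level, with coefficients A and B and
// normalized upper bound U:
//   LB = (A - B^+)^- * (U - 1) + A
//   UB = (A - B^-)^+ * (U - 1) + A
static std::pair<long, long> boundsGT(long A, long B, long U) {
  long LB = neg(A - pos(B)) * (U - 1) + A;
  long UB = pos(A - neg(B)) * (U - 1) + A;
  return {LB, UB};
}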
+void DependenceInfo::findBoundsGT(CoefficientInfo *A, CoefficientInfo *B, +                                  BoundInfo *Bound, unsigned K) const { +  Bound[K].Lower[Dependence::DVEntry::GT] = nullptr; // Default value = -infinity. +  Bound[K].Upper[Dependence::DVEntry::GT] = nullptr; // Default value = +infinity. +  if (Bound[K].Iterations) { +    const SCEV *Iter_1 = SE->getMinusSCEV( +        Bound[K].Iterations, SE->getOne(Bound[K].Iterations->getType())); +    const SCEV *NegPart = +      getNegativePart(SE->getMinusSCEV(A[K].Coeff, B[K].PosPart)); +    Bound[K].Lower[Dependence::DVEntry::GT] = +      SE->getAddExpr(SE->getMulExpr(NegPart, Iter_1), A[K].Coeff); +    const SCEV *PosPart = +      getPositivePart(SE->getMinusSCEV(A[K].Coeff, B[K].NegPart)); +    Bound[K].Upper[Dependence::DVEntry::GT] = +      SE->getAddExpr(SE->getMulExpr(PosPart, Iter_1), A[K].Coeff); +  } +  else { +    // If the positive/negative part of the difference is 0, +    // we won't need to know the number of iterations. +    const SCEV *NegPart = getNegativePart(SE->getMinusSCEV(A[K].Coeff, B[K].PosPart)); +    if (NegPart->isZero()) +      Bound[K].Lower[Dependence::DVEntry::GT] = A[K].Coeff; +    const SCEV *PosPart = getPositivePart(SE->getMinusSCEV(A[K].Coeff, B[K].NegPart)); +    if (PosPart->isZero()) +      Bound[K].Upper[Dependence::DVEntry::GT] = A[K].Coeff; +  } +} + + +// X^+ = max(X, 0) +const SCEV *DependenceInfo::getPositivePart(const SCEV *X) const { +  return SE->getSMaxExpr(X, SE->getZero(X->getType())); +} + + +// X^- = min(X, 0) +const SCEV *DependenceInfo::getNegativePart(const SCEV *X) const { +  return SE->getSMinExpr(X, SE->getZero(X->getType())); +} + + +// Walks through the subscript, +// collecting each coefficient, the associated loop bounds, +// and recording its positive and negative parts for later use. +DependenceInfo::CoefficientInfo * +DependenceInfo::collectCoeffInfo(const SCEV *Subscript, bool SrcFlag, +                                 const SCEV *&Constant) const { +  const SCEV *Zero = SE->getZero(Subscript->getType()); +  CoefficientInfo *CI = new CoefficientInfo[MaxLevels + 1]; +  for (unsigned K = 1; K <= MaxLevels; ++K) { +    CI[K].Coeff = Zero; +    CI[K].PosPart = Zero; +    CI[K].NegPart = Zero; +    CI[K].Iterations = nullptr; +  } +  while (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Subscript)) { +    const Loop *L = AddRec->getLoop(); +    unsigned K = SrcFlag ? mapSrcLoop(L) : mapDstLoop(L); +    CI[K].Coeff = AddRec->getStepRecurrence(*SE); +    CI[K].PosPart = getPositivePart(CI[K].Coeff); +    CI[K].NegPart = getNegativePart(CI[K].Coeff); +    CI[K].Iterations = collectUpperBound(L, Subscript->getType()); +    Subscript = AddRec->getStart(); +  } +  Constant = Subscript; +#ifndef NDEBUG +  LLVM_DEBUG(dbgs() << "\tCoefficient Info\n"); +  for (unsigned K = 1; K <= MaxLevels; ++K) { +    LLVM_DEBUG(dbgs() << "\t    " << K << "\t" << *CI[K].Coeff); +    LLVM_DEBUG(dbgs() << "\tPos Part = "); +    LLVM_DEBUG(dbgs() << *CI[K].PosPart); +    LLVM_DEBUG(dbgs() << "\tNeg Part = "); +    LLVM_DEBUG(dbgs() << *CI[K].NegPart); +    LLVM_DEBUG(dbgs() << "\tUpper Bound = "); +    if (CI[K].Iterations) +      LLVM_DEBUG(dbgs() << *CI[K].Iterations); +    else +      LLVM_DEBUG(dbgs() << "+inf"); +    LLVM_DEBUG(dbgs() << '\n'); +  } +  LLVM_DEBUG(dbgs() << "\t    Constant = " << *Subscript << '\n'); +#endif +  return CI; +} + + +// Looks through all the bounds info and +// computes the lower bound given the current direction settings +// at each level. 
If the lower bound for any level is -inf, +// the result is -inf. +const SCEV *DependenceInfo::getLowerBound(BoundInfo *Bound) const { +  const SCEV *Sum = Bound[1].Lower[Bound[1].Direction]; +  for (unsigned K = 2; Sum && K <= MaxLevels; ++K) { +    if (Bound[K].Lower[Bound[K].Direction]) +      Sum = SE->getAddExpr(Sum, Bound[K].Lower[Bound[K].Direction]); +    else +      Sum = nullptr; +  } +  return Sum; +} + + +// Looks through all the bounds info and +// computes the upper bound given the current direction settings +// at each level. If the upper bound at any level is +inf, +// the result is +inf. +const SCEV *DependenceInfo::getUpperBound(BoundInfo *Bound) const { +  const SCEV *Sum = Bound[1].Upper[Bound[1].Direction]; +  for (unsigned K = 2; Sum && K <= MaxLevels; ++K) { +    if (Bound[K].Upper[Bound[K].Direction]) +      Sum = SE->getAddExpr(Sum, Bound[K].Upper[Bound[K].Direction]); +    else +      Sum = nullptr; +  } +  return Sum; +} + + +//===----------------------------------------------------------------------===// +// Constraint manipulation for Delta test. + +// Given a linear SCEV, +// return the coefficient (the step) +// corresponding to the specified loop. +// If there isn't one, return 0. +// For example, given a*i + b*j + c*k, finding the coefficient +// corresponding to the j loop would yield b. +const SCEV *DependenceInfo::findCoefficient(const SCEV *Expr, +                                            const Loop *TargetLoop) const { +  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); +  if (!AddRec) +    return SE->getZero(Expr->getType()); +  if (AddRec->getLoop() == TargetLoop) +    return AddRec->getStepRecurrence(*SE); +  return findCoefficient(AddRec->getStart(), TargetLoop); +} + + +// Given a linear SCEV, +// return the SCEV given by zeroing out the coefficient +// corresponding to the specified loop. +// For example, given a*i + b*j + c*k, zeroing the coefficient +// corresponding to the j loop would yield a*i + c*k. +const SCEV *DependenceInfo::zeroCoefficient(const SCEV *Expr, +                                            const Loop *TargetLoop) const { +  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); +  if (!AddRec) +    return Expr; // ignore +  if (AddRec->getLoop() == TargetLoop) +    return AddRec->getStart(); +  return SE->getAddRecExpr(zeroCoefficient(AddRec->getStart(), TargetLoop), +                           AddRec->getStepRecurrence(*SE), +                           AddRec->getLoop(), +                           AddRec->getNoWrapFlags()); +} + + +// Given a linear SCEV Expr, +// return the SCEV given by adding some Value to the +// coefficient corresponding to the specified TargetLoop. +// For example, given a*i + b*j + c*k, adding 1 to the coefficient +// corresponding to the j loop would yield a*i + (b+1)*j + c*k. +const SCEV *DependenceInfo::addToCoefficient(const SCEV *Expr, +                                             const Loop *TargetLoop, +                                             const SCEV *Value) const { +  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); +  if (!AddRec) // create a new addRec +    return SE->getAddRecExpr(Expr, +                             Value, +                             TargetLoop, +                             SCEV::FlagAnyWrap); // Worst case, with no info. 
+  if (AddRec->getLoop() == TargetLoop) { +    const SCEV *Sum = SE->getAddExpr(AddRec->getStepRecurrence(*SE), Value); +    if (Sum->isZero()) +      return AddRec->getStart(); +    return SE->getAddRecExpr(AddRec->getStart(), +                             Sum, +                             AddRec->getLoop(), +                             AddRec->getNoWrapFlags()); +  } +  if (SE->isLoopInvariant(AddRec, TargetLoop)) +    return SE->getAddRecExpr(AddRec, Value, TargetLoop, SCEV::FlagAnyWrap); +  return SE->getAddRecExpr( +      addToCoefficient(AddRec->getStart(), TargetLoop, Value), +      AddRec->getStepRecurrence(*SE), AddRec->getLoop(), +      AddRec->getNoWrapFlags()); +} + + +// Review the constraints, looking for opportunities +// to simplify a subscript pair (Src and Dst). +// Return true if some simplification occurs. +// If the simplification isn't exact (that is, if it is conservative +// in terms of dependence), set consistent to false. +// Corresponds to Figure 5 from the paper +// +//            Practical Dependence Testing +//            Goff, Kennedy, Tseng +//            PLDI 1991 +bool DependenceInfo::propagate(const SCEV *&Src, const SCEV *&Dst, +                               SmallBitVector &Loops, +                               SmallVectorImpl<Constraint> &Constraints, +                               bool &Consistent) { +  bool Result = false; +  for (unsigned LI : Loops.set_bits()) { +    LLVM_DEBUG(dbgs() << "\t    Constraint[" << LI << "] is"); +    LLVM_DEBUG(Constraints[LI].dump(dbgs())); +    if (Constraints[LI].isDistance()) +      Result |= propagateDistance(Src, Dst, Constraints[LI], Consistent); +    else if (Constraints[LI].isLine()) +      Result |= propagateLine(Src, Dst, Constraints[LI], Consistent); +    else if (Constraints[LI].isPoint()) +      Result |= propagatePoint(Src, Dst, Constraints[LI]); +  } +  return Result; +} + + +// Attempt to propagate a distance +// constraint into a subscript pair (Src and Dst). +// Return true if some simplification occurs. +// If the simplification isn't exact (that is, if it is conservative +// in terms of dependence), set consistent to false. +bool DependenceInfo::propagateDistance(const SCEV *&Src, const SCEV *&Dst, +                                       Constraint &CurConstraint, +                                       bool &Consistent) { +  const Loop *CurLoop = CurConstraint.getAssociatedLoop(); +  LLVM_DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n"); +  const SCEV *A_K = findCoefficient(Src, CurLoop); +  if (A_K->isZero()) +    return false; +  const SCEV *DA_K = SE->getMulExpr(A_K, CurConstraint.getD()); +  Src = SE->getMinusSCEV(Src, DA_K); +  Src = zeroCoefficient(Src, CurLoop); +  LLVM_DEBUG(dbgs() << "\t\tnew Src is " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "\t\tDst is " << *Dst << "\n"); +  Dst = addToCoefficient(Dst, CurLoop, SE->getNegativeSCEV(A_K)); +  LLVM_DEBUG(dbgs() << "\t\tnew Dst is " << *Dst << "\n"); +  if (!findCoefficient(Dst, CurLoop)->isZero()) +    Consistent = false; +  return true; +} + + +// Attempt to propagate a line +// constraint into a subscript pair (Src and Dst). +// Return true if some simplification occurs. +// If the simplification isn't exact (that is, if it is conservative +// in terms of dependence), set consistent to false. 
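For a concrete (made-up) instance of the simplest branch handled below: if the line constraint at some level is 0*i + 2*i' = 6, the destination index is pinned to i' = 3, so a destination term like 4*i' contributes the constant 12, which can be folded to the other side while the level's coefficient is zeroed out. With plain integers and invented names:

// Line constraint A*i + B*i' = C with A == 0 pins the destination index to
// i' = C/B (assumed here: B != 0 and B divides C, mirroring the assert in the
// real code). The destination term Coeff*i' then collapses to a constant.
static long pinnedDstContribution(long B, long C, long Coeff) {
  long IPrime = C / B;   // the only destination index the constraint allows
  return Coeff * IPrime; // constant contribution of that level's Dst term
}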
+bool DependenceInfo::propagateLine(const SCEV *&Src, const SCEV *&Dst, +                                   Constraint &CurConstraint, +                                   bool &Consistent) { +  const Loop *CurLoop = CurConstraint.getAssociatedLoop(); +  const SCEV *A = CurConstraint.getA(); +  const SCEV *B = CurConstraint.getB(); +  const SCEV *C = CurConstraint.getC(); +  LLVM_DEBUG(dbgs() << "\t\tA = " << *A << ", B = " << *B << ", C = " << *C +                    << "\n"); +  LLVM_DEBUG(dbgs() << "\t\tSrc = " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "\t\tDst = " << *Dst << "\n"); +  if (A->isZero()) { +    const SCEVConstant *Bconst = dyn_cast<SCEVConstant>(B); +    const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C); +    if (!Bconst || !Cconst) return false; +    APInt Beta = Bconst->getAPInt(); +    APInt Charlie = Cconst->getAPInt(); +    APInt CdivB = Charlie.sdiv(Beta); +    assert(Charlie.srem(Beta) == 0 && "C should be evenly divisible by B"); +    const SCEV *AP_K = findCoefficient(Dst, CurLoop); +    //    Src = SE->getAddExpr(Src, SE->getMulExpr(AP_K, SE->getConstant(CdivB))); +    Src = SE->getMinusSCEV(Src, SE->getMulExpr(AP_K, SE->getConstant(CdivB))); +    Dst = zeroCoefficient(Dst, CurLoop); +    if (!findCoefficient(Src, CurLoop)->isZero()) +      Consistent = false; +  } +  else if (B->isZero()) { +    const SCEVConstant *Aconst = dyn_cast<SCEVConstant>(A); +    const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C); +    if (!Aconst || !Cconst) return false; +    APInt Alpha = Aconst->getAPInt(); +    APInt Charlie = Cconst->getAPInt(); +    APInt CdivA = Charlie.sdiv(Alpha); +    assert(Charlie.srem(Alpha) == 0 && "C should be evenly divisible by A"); +    const SCEV *A_K = findCoefficient(Src, CurLoop); +    Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, SE->getConstant(CdivA))); +    Src = zeroCoefficient(Src, CurLoop); +    if (!findCoefficient(Dst, CurLoop)->isZero()) +      Consistent = false; +  } +  else if (isKnownPredicate(CmpInst::ICMP_EQ, A, B)) { +    const SCEVConstant *Aconst = dyn_cast<SCEVConstant>(A); +    const SCEVConstant *Cconst = dyn_cast<SCEVConstant>(C); +    if (!Aconst || !Cconst) return false; +    APInt Alpha = Aconst->getAPInt(); +    APInt Charlie = Cconst->getAPInt(); +    APInt CdivA = Charlie.sdiv(Alpha); +    assert(Charlie.srem(Alpha) == 0 && "C should be evenly divisible by A"); +    const SCEV *A_K = findCoefficient(Src, CurLoop); +    Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, SE->getConstant(CdivA))); +    Src = zeroCoefficient(Src, CurLoop); +    Dst = addToCoefficient(Dst, CurLoop, A_K); +    if (!findCoefficient(Dst, CurLoop)->isZero()) +      Consistent = false; +  } +  else { +    // paper is incorrect here, or perhaps just misleading +    const SCEV *A_K = findCoefficient(Src, CurLoop); +    Src = SE->getMulExpr(Src, A); +    Dst = SE->getMulExpr(Dst, A); +    Src = SE->getAddExpr(Src, SE->getMulExpr(A_K, C)); +    Src = zeroCoefficient(Src, CurLoop); +    Dst = addToCoefficient(Dst, CurLoop, SE->getMulExpr(A_K, B)); +    if (!findCoefficient(Dst, CurLoop)->isZero()) +      Consistent = false; +  } +  LLVM_DEBUG(dbgs() << "\t\tnew Src = " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "\t\tnew Dst = " << *Dst << "\n"); +  return true; +} + + +// Attempt to propagate a point +// constraint into a subscript pair (Src and Dst). +// Return true if some simplification occurs. 
+bool DependenceInfo::propagatePoint(const SCEV *&Src, const SCEV *&Dst, +                                    Constraint &CurConstraint) { +  const Loop *CurLoop = CurConstraint.getAssociatedLoop(); +  const SCEV *A_K = findCoefficient(Src, CurLoop); +  const SCEV *AP_K = findCoefficient(Dst, CurLoop); +  const SCEV *XA_K = SE->getMulExpr(A_K, CurConstraint.getX()); +  const SCEV *YAP_K = SE->getMulExpr(AP_K, CurConstraint.getY()); +  LLVM_DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n"); +  Src = SE->getAddExpr(Src, SE->getMinusSCEV(XA_K, YAP_K)); +  Src = zeroCoefficient(Src, CurLoop); +  LLVM_DEBUG(dbgs() << "\t\tnew Src is " << *Src << "\n"); +  LLVM_DEBUG(dbgs() << "\t\tDst is " << *Dst << "\n"); +  Dst = zeroCoefficient(Dst, CurLoop); +  LLVM_DEBUG(dbgs() << "\t\tnew Dst is " << *Dst << "\n"); +  return true; +} + + +// Update direction vector entry based on the current constraint. +void DependenceInfo::updateDirection(Dependence::DVEntry &Level, +                                     const Constraint &CurConstraint) const { +  LLVM_DEBUG(dbgs() << "\tUpdate direction, constraint ="); +  LLVM_DEBUG(CurConstraint.dump(dbgs())); +  if (CurConstraint.isAny()) +    ; // use defaults +  else if (CurConstraint.isDistance()) { +    // this one is consistent, the others aren't +    Level.Scalar = false; +    Level.Distance = CurConstraint.getD(); +    unsigned NewDirection = Dependence::DVEntry::NONE; +    if (!SE->isKnownNonZero(Level.Distance)) // if may be zero +      NewDirection = Dependence::DVEntry::EQ; +    if (!SE->isKnownNonPositive(Level.Distance)) // if may be positive +      NewDirection |= Dependence::DVEntry::LT; +    if (!SE->isKnownNonNegative(Level.Distance)) // if may be negative +      NewDirection |= Dependence::DVEntry::GT; +    Level.Direction &= NewDirection; +  } +  else if (CurConstraint.isLine()) { +    Level.Scalar = false; +    Level.Distance = nullptr; +    // direction should be accurate +  } +  else if (CurConstraint.isPoint()) { +    Level.Scalar = false; +    Level.Distance = nullptr; +    unsigned NewDirection = Dependence::DVEntry::NONE; +    if (!isKnownPredicate(CmpInst::ICMP_NE, +                          CurConstraint.getY(), +                          CurConstraint.getX())) +      // if X may be = Y +      NewDirection |= Dependence::DVEntry::EQ; +    if (!isKnownPredicate(CmpInst::ICMP_SLE, +                          CurConstraint.getY(), +                          CurConstraint.getX())) +      // if Y may be > X +      NewDirection |= Dependence::DVEntry::LT; +    if (!isKnownPredicate(CmpInst::ICMP_SGE, +                          CurConstraint.getY(), +                          CurConstraint.getX())) +      // if Y may be < X +      NewDirection |= Dependence::DVEntry::GT; +    Level.Direction &= NewDirection; +  } +  else +    llvm_unreachable("constraint has unexpected kind"); +} + +/// Check if we can delinearize the subscripts. If the SCEVs representing the +/// source and destination array references are recurrences on a nested loop, +/// this function flattens the nested recurrences into separate recurrences +/// for each loop level. 
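As a toy illustration of what delinearization recovers (tryDelinearize below works symbolically on SCEVs; the integer version and its names here are invented): a flattened access A[i*N + j] over rows of length N corresponds to the two subscripts i and j, which can be read back off the flat offset.

#include <utility>

// Given a flat offset Offset = i*N + j with N > 0, Offset >= 0 and 0 <= j < N,
// recover the two subscripts {i, j} that the symbolic delinearization below
// reconstructs (and then range-checks) for the real access functions.
static std::pair<long, long> splitFlatOffset(long Offset, long N) {
  return {Offset / N, Offset % N}; // {row subscript, column subscript}
}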
+bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst, +                                    SmallVectorImpl<Subscript> &Pair) { +  assert(isLoadOrStore(Src) && "instruction is not load or store"); +  assert(isLoadOrStore(Dst) && "instruction is not load or store"); +  Value *SrcPtr = getLoadStorePointerOperand(Src); +  Value *DstPtr = getLoadStorePointerOperand(Dst); + +  Loop *SrcLoop = LI->getLoopFor(Src->getParent()); +  Loop *DstLoop = LI->getLoopFor(Dst->getParent()); + +  // Below code mimics the code in Delinearization.cpp +  const SCEV *SrcAccessFn = +    SE->getSCEVAtScope(SrcPtr, SrcLoop); +  const SCEV *DstAccessFn = +    SE->getSCEVAtScope(DstPtr, DstLoop); + +  const SCEVUnknown *SrcBase = +      dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcAccessFn)); +  const SCEVUnknown *DstBase = +      dyn_cast<SCEVUnknown>(SE->getPointerBase(DstAccessFn)); + +  if (!SrcBase || !DstBase || SrcBase != DstBase) +    return false; + +  const SCEV *ElementSize = SE->getElementSize(Src); +  if (ElementSize != SE->getElementSize(Dst)) +    return false; + +  const SCEV *SrcSCEV = SE->getMinusSCEV(SrcAccessFn, SrcBase); +  const SCEV *DstSCEV = SE->getMinusSCEV(DstAccessFn, DstBase); + +  const SCEVAddRecExpr *SrcAR = dyn_cast<SCEVAddRecExpr>(SrcSCEV); +  const SCEVAddRecExpr *DstAR = dyn_cast<SCEVAddRecExpr>(DstSCEV); +  if (!SrcAR || !DstAR || !SrcAR->isAffine() || !DstAR->isAffine()) +    return false; + +  // First step: collect parametric terms in both array references. +  SmallVector<const SCEV *, 4> Terms; +  SE->collectParametricTerms(SrcAR, Terms); +  SE->collectParametricTerms(DstAR, Terms); + +  // Second step: find subscript sizes. +  SmallVector<const SCEV *, 4> Sizes; +  SE->findArrayDimensions(Terms, Sizes, ElementSize); + +  // Third step: compute the access functions for each subscript. +  SmallVector<const SCEV *, 4> SrcSubscripts, DstSubscripts; +  SE->computeAccessFunctions(SrcAR, SrcSubscripts, Sizes); +  SE->computeAccessFunctions(DstAR, DstSubscripts, Sizes); + +  // Fail when there is only a subscript: that's a linearized access function. +  if (SrcSubscripts.size() < 2 || DstSubscripts.size() < 2 || +      SrcSubscripts.size() != DstSubscripts.size()) +    return false; + +  int size = SrcSubscripts.size(); + +  // Statically check that the array bounds are in-range. The first subscript we +  // don't have a size for and it cannot overflow into another subscript, so is +  // always safe. The others need to be 0 <= subscript[i] < bound, for both src +  // and dst. +  // FIXME: It may be better to record these sizes and add them as constraints +  // to the dependency checks. +  for (int i = 1; i < size; ++i) { +    if (!isKnownNonNegative(SrcSubscripts[i], SrcPtr)) +      return false; + +    if (!isKnownLessThan(SrcSubscripts[i], Sizes[i - 1])) +      return false; + +    if (!isKnownNonNegative(DstSubscripts[i], DstPtr)) +      return false; + +    if (!isKnownLessThan(DstSubscripts[i], Sizes[i - 1])) +      return false; +  } + +  LLVM_DEBUG({ +    dbgs() << "\nSrcSubscripts: "; +    for (int i = 0; i < size; i++) +      dbgs() << *SrcSubscripts[i]; +    dbgs() << "\nDstSubscripts: "; +    for (int i = 0; i < size; i++) +      dbgs() << *DstSubscripts[i]; +  }); + +  // The delinearization transforms a single-subscript MIV dependence test into +  // a multi-subscript SIV dependence test that is easier to compute. 
So we +  // resize Pair to contain as many pairs of subscripts as the delinearization +  // has found, and then initialize the pairs following the delinearization. +  Pair.resize(size); +  for (int i = 0; i < size; ++i) { +    Pair[i].Src = SrcSubscripts[i]; +    Pair[i].Dst = DstSubscripts[i]; +    unifySubscriptType(&Pair[i]); +  } + +  return true; +} + +//===----------------------------------------------------------------------===// + +#ifndef NDEBUG +// For debugging purposes, dump a small bit vector to dbgs(). +static void dumpSmallBitVector(SmallBitVector &BV) { +  dbgs() << "{"; +  for (unsigned VI : BV.set_bits()) { +    dbgs() << VI; +    if (BV.find_next(VI) >= 0) +      dbgs() << ' '; +  } +  dbgs() << "}\n"; +} +#endif + +// depends - +// Returns NULL if there is no dependence. +// Otherwise, return a Dependence with as many details as possible. +// Corresponds to Section 3.1 in the paper +// +//            Practical Dependence Testing +//            Goff, Kennedy, Tseng +//            PLDI 1991 +// +// Care is required to keep the routine below, getSplitIteration(), +// up to date with respect to this routine. +std::unique_ptr<Dependence> +DependenceInfo::depends(Instruction *Src, Instruction *Dst, +                        bool PossiblyLoopIndependent) { +  if (Src == Dst) +    PossiblyLoopIndependent = false; + +  if ((!Src->mayReadFromMemory() && !Src->mayWriteToMemory()) || +      (!Dst->mayReadFromMemory() && !Dst->mayWriteToMemory())) +    // if both instructions don't reference memory, there's no dependence +    return nullptr; + +  if (!isLoadOrStore(Src) || !isLoadOrStore(Dst)) { +    // can only analyze simple loads and stores, i.e., no calls, invokes, etc. +    LLVM_DEBUG(dbgs() << "can only handle simple loads and stores\n"); +    return make_unique<Dependence>(Src, Dst); +  } + +  assert(isLoadOrStore(Src) && "instruction is not load or store"); +  assert(isLoadOrStore(Dst) && "instruction is not load or store"); +  Value *SrcPtr = getLoadStorePointerOperand(Src); +  Value *DstPtr = getLoadStorePointerOperand(Dst); + +  switch (underlyingObjectsAlias(AA, F->getParent()->getDataLayout(), +                                 MemoryLocation::get(Dst), +                                 MemoryLocation::get(Src))) { +  case MayAlias: +  case PartialAlias: +    // cannot analyse objects if we don't understand their aliasing. +    LLVM_DEBUG(dbgs() << "can't analyze may or partial alias\n"); +    return make_unique<Dependence>(Src, Dst); +  case NoAlias: +    // If the objects noalias, they are distinct, accesses are independent. +    LLVM_DEBUG(dbgs() << "no alias\n"); +    return nullptr; +  case MustAlias: +    break; // The underlying objects alias; test accesses for dependence. 
+  } + +  // establish loop nesting levels +  establishNestingLevels(Src, Dst); +  LLVM_DEBUG(dbgs() << "    common nesting levels = " << CommonLevels << "\n"); +  LLVM_DEBUG(dbgs() << "    maximum nesting levels = " << MaxLevels << "\n"); + +  FullDependence Result(Src, Dst, PossiblyLoopIndependent, CommonLevels); +  ++TotalArrayPairs; + +  unsigned Pairs = 1; +  SmallVector<Subscript, 2> Pair(Pairs); +  const SCEV *SrcSCEV = SE->getSCEV(SrcPtr); +  const SCEV *DstSCEV = SE->getSCEV(DstPtr); +  LLVM_DEBUG(dbgs() << "    SrcSCEV = " << *SrcSCEV << "\n"); +  LLVM_DEBUG(dbgs() << "    DstSCEV = " << *DstSCEV << "\n"); +  Pair[0].Src = SrcSCEV; +  Pair[0].Dst = DstSCEV; + +  if (Delinearize) { +    if (tryDelinearize(Src, Dst, Pair)) { +      LLVM_DEBUG(dbgs() << "    delinearized\n"); +      Pairs = Pair.size(); +    } +  } + +  for (unsigned P = 0; P < Pairs; ++P) { +    Pair[P].Loops.resize(MaxLevels + 1); +    Pair[P].GroupLoops.resize(MaxLevels + 1); +    Pair[P].Group.resize(Pairs); +    removeMatchingExtensions(&Pair[P]); +    Pair[P].Classification = +      classifyPair(Pair[P].Src, LI->getLoopFor(Src->getParent()), +                   Pair[P].Dst, LI->getLoopFor(Dst->getParent()), +                   Pair[P].Loops); +    Pair[P].GroupLoops = Pair[P].Loops; +    Pair[P].Group.set(P); +    LLVM_DEBUG(dbgs() << "    subscript " << P << "\n"); +    LLVM_DEBUG(dbgs() << "\tsrc = " << *Pair[P].Src << "\n"); +    LLVM_DEBUG(dbgs() << "\tdst = " << *Pair[P].Dst << "\n"); +    LLVM_DEBUG(dbgs() << "\tclass = " << Pair[P].Classification << "\n"); +    LLVM_DEBUG(dbgs() << "\tloops = "); +    LLVM_DEBUG(dumpSmallBitVector(Pair[P].Loops)); +  } + +  SmallBitVector Separable(Pairs); +  SmallBitVector Coupled(Pairs); + +  // Partition subscripts into separable and minimally-coupled groups +  // Algorithm in paper is algorithmically better; +  // this may be faster in practice. Check someday. +  // +  // Here's an example of how it works. Consider this code: +  // +  //   for (i = ...) { +  //     for (j = ...) { +  //       for (k = ...) { +  //         for (l = ...) { +  //           for (m = ...) { +  //             A[i][j][k][m] = ...; +  //             ... = A[0][j][l][i + j]; +  //           } +  //         } +  //       } +  //     } +  //   } +  // +  // There are 4 subscripts here: +  //    0 [i] and [0] +  //    1 [j] and [j] +  //    2 [k] and [l] +  //    3 [m] and [i + j] +  // +  // We've already classified each subscript pair as ZIV, SIV, etc., +  // and collected all the loops mentioned by pair P in Pair[P].Loops. +  // In addition, we've initialized Pair[P].GroupLoops to Pair[P].Loops +  // and set Pair[P].Group = {P}. +  // +  //      Src Dst    Classification Loops  GroupLoops Group +  //    0 [i] [0]         SIV       {1}      {1}        {0} +  //    1 [j] [j]         SIV       {2}      {2}        {1} +  //    2 [k] [l]         RDIV      {3,4}    {3,4}      {2} +  //    3 [m] [i + j]     MIV       {1,2,5}  {1,2,5}    {3} +  // +  // For each subscript SI 0 .. 3, we consider each remaining subscript, SJ. +  // So, 0 is compared against 1, 2, and 3; 1 is compared against 2 and 3, etc. +  // +  // We begin by comparing 0 and 1. The intersection of the GroupLoops is empty. +  // Next, 0 and 2. Again, the intersection of their GroupLoops is empty. +  // Next 0 and 3. The intersection of their GroupLoop = {1}, not empty, +  // so Pair[3].Group = {0,3} and Done = false (that is, 0 will not be added +  // to either Separable or Coupled). +  // +  // Next, we consider 1 and 2. 
The intersection of the GroupLoops is empty. +  // Next, 1 and 3. The intersectionof their GroupLoops = {2}, not empty, +  // so Pair[3].Group = {0, 1, 3} and Done = false. +  // +  // Next, we compare 2 against 3. The intersection of the GroupLoops is empty. +  // Since Done remains true, we add 2 to the set of Separable pairs. +  // +  // Finally, we consider 3. There's nothing to compare it with, +  // so Done remains true and we add it to the Coupled set. +  // Pair[3].Group = {0, 1, 3} and GroupLoops = {1, 2, 5}. +  // +  // In the end, we've got 1 separable subscript and 1 coupled group. +  for (unsigned SI = 0; SI < Pairs; ++SI) { +    if (Pair[SI].Classification == Subscript::NonLinear) { +      // ignore these, but collect loops for later +      ++NonlinearSubscriptPairs; +      collectCommonLoops(Pair[SI].Src, +                         LI->getLoopFor(Src->getParent()), +                         Pair[SI].Loops); +      collectCommonLoops(Pair[SI].Dst, +                         LI->getLoopFor(Dst->getParent()), +                         Pair[SI].Loops); +      Result.Consistent = false; +    } else if (Pair[SI].Classification == Subscript::ZIV) { +      // always separable +      Separable.set(SI); +    } +    else { +      // SIV, RDIV, or MIV, so check for coupled group +      bool Done = true; +      for (unsigned SJ = SI + 1; SJ < Pairs; ++SJ) { +        SmallBitVector Intersection = Pair[SI].GroupLoops; +        Intersection &= Pair[SJ].GroupLoops; +        if (Intersection.any()) { +          // accumulate set of all the loops in group +          Pair[SJ].GroupLoops |= Pair[SI].GroupLoops; +          // accumulate set of all subscripts in group +          Pair[SJ].Group |= Pair[SI].Group; +          Done = false; +        } +      } +      if (Done) { +        if (Pair[SI].Group.count() == 1) { +          Separable.set(SI); +          ++SeparableSubscriptPairs; +        } +        else { +          Coupled.set(SI); +          ++CoupledSubscriptPairs; +        } +      } +    } +  } + +  LLVM_DEBUG(dbgs() << "    Separable = "); +  LLVM_DEBUG(dumpSmallBitVector(Separable)); +  LLVM_DEBUG(dbgs() << "    Coupled = "); +  LLVM_DEBUG(dumpSmallBitVector(Coupled)); + +  Constraint NewConstraint; +  NewConstraint.setAny(SE); + +  // test separable subscripts +  for (unsigned SI : Separable.set_bits()) { +    LLVM_DEBUG(dbgs() << "testing subscript " << SI); +    switch (Pair[SI].Classification) { +    case Subscript::ZIV: +      LLVM_DEBUG(dbgs() << ", ZIV\n"); +      if (testZIV(Pair[SI].Src, Pair[SI].Dst, Result)) +        return nullptr; +      break; +    case Subscript::SIV: { +      LLVM_DEBUG(dbgs() << ", SIV\n"); +      unsigned Level; +      const SCEV *SplitIter = nullptr; +      if (testSIV(Pair[SI].Src, Pair[SI].Dst, Level, Result, NewConstraint, +                  SplitIter)) +        return nullptr; +      break; +    } +    case Subscript::RDIV: +      LLVM_DEBUG(dbgs() << ", RDIV\n"); +      if (testRDIV(Pair[SI].Src, Pair[SI].Dst, Result)) +        return nullptr; +      break; +    case Subscript::MIV: +      LLVM_DEBUG(dbgs() << ", MIV\n"); +      if (testMIV(Pair[SI].Src, Pair[SI].Dst, Pair[SI].Loops, Result)) +        return nullptr; +      break; +    default: +      llvm_unreachable("subscript has unexpected classification"); +    } +  } + +  if (Coupled.count()) { +    // test coupled subscript groups +    LLVM_DEBUG(dbgs() << "starting on coupled subscripts\n"); +    LLVM_DEBUG(dbgs() << "MaxLevels + 1 = " << MaxLevels + 1 << "\n"); +    
SmallVector<Constraint, 4> Constraints(MaxLevels + 1); +    for (unsigned II = 0; II <= MaxLevels; ++II) +      Constraints[II].setAny(SE); +    for (unsigned SI : Coupled.set_bits()) { +      LLVM_DEBUG(dbgs() << "testing subscript group " << SI << " { "); +      SmallBitVector Group(Pair[SI].Group); +      SmallBitVector Sivs(Pairs); +      SmallBitVector Mivs(Pairs); +      SmallBitVector ConstrainedLevels(MaxLevels + 1); +      SmallVector<Subscript *, 4> PairsInGroup; +      for (unsigned SJ : Group.set_bits()) { +        LLVM_DEBUG(dbgs() << SJ << " "); +        if (Pair[SJ].Classification == Subscript::SIV) +          Sivs.set(SJ); +        else +          Mivs.set(SJ); +        PairsInGroup.push_back(&Pair[SJ]); +      } +      unifySubscriptType(PairsInGroup); +      LLVM_DEBUG(dbgs() << "}\n"); +      while (Sivs.any()) { +        bool Changed = false; +        for (unsigned SJ : Sivs.set_bits()) { +          LLVM_DEBUG(dbgs() << "testing subscript " << SJ << ", SIV\n"); +          // SJ is an SIV subscript that's part of the current coupled group +          unsigned Level; +          const SCEV *SplitIter = nullptr; +          LLVM_DEBUG(dbgs() << "SIV\n"); +          if (testSIV(Pair[SJ].Src, Pair[SJ].Dst, Level, Result, NewConstraint, +                      SplitIter)) +            return nullptr; +          ConstrainedLevels.set(Level); +          if (intersectConstraints(&Constraints[Level], &NewConstraint)) { +            if (Constraints[Level].isEmpty()) { +              ++DeltaIndependence; +              return nullptr; +            } +            Changed = true; +          } +          Sivs.reset(SJ); +        } +        if (Changed) { +          // propagate, possibly creating new SIVs and ZIVs +          LLVM_DEBUG(dbgs() << "    propagating\n"); +          LLVM_DEBUG(dbgs() << "\tMivs = "); +          LLVM_DEBUG(dumpSmallBitVector(Mivs)); +          for (unsigned SJ : Mivs.set_bits()) { +            // SJ is an MIV subscript that's part of the current coupled group +            LLVM_DEBUG(dbgs() << "\tSJ = " << SJ << "\n"); +            if (propagate(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, +                          Constraints, Result.Consistent)) { +              LLVM_DEBUG(dbgs() << "\t    Changed\n"); +              ++DeltaPropagations; +              Pair[SJ].Classification = +                classifyPair(Pair[SJ].Src, LI->getLoopFor(Src->getParent()), +                             Pair[SJ].Dst, LI->getLoopFor(Dst->getParent()), +                             Pair[SJ].Loops); +              switch (Pair[SJ].Classification) { +              case Subscript::ZIV: +                LLVM_DEBUG(dbgs() << "ZIV\n"); +                if (testZIV(Pair[SJ].Src, Pair[SJ].Dst, Result)) +                  return nullptr; +                Mivs.reset(SJ); +                break; +              case Subscript::SIV: +                Sivs.set(SJ); +                Mivs.reset(SJ); +                break; +              case Subscript::RDIV: +              case Subscript::MIV: +                break; +              default: +                llvm_unreachable("bad subscript classification"); +              } +            } +          } +        } +      } + +      // test & propagate remaining RDIVs +      for (unsigned SJ : Mivs.set_bits()) { +        if (Pair[SJ].Classification == Subscript::RDIV) { +          LLVM_DEBUG(dbgs() << "RDIV test\n"); +          if (testRDIV(Pair[SJ].Src, Pair[SJ].Dst, Result)) +            return nullptr; +          // I don't yet understand how to 
propagate RDIV results +          Mivs.reset(SJ); +        } +      } + +      // test remaining MIVs +      // This code is temporary. +      // Better to somehow test all remaining subscripts simultaneously. +      for (unsigned SJ : Mivs.set_bits()) { +        if (Pair[SJ].Classification == Subscript::MIV) { +          LLVM_DEBUG(dbgs() << "MIV test\n"); +          if (testMIV(Pair[SJ].Src, Pair[SJ].Dst, Pair[SJ].Loops, Result)) +            return nullptr; +        } +        else +          llvm_unreachable("expected only MIV subscripts at this point"); +      } + +      // update Result.DV from constraint vector +      LLVM_DEBUG(dbgs() << "    updating\n"); +      for (unsigned SJ : ConstrainedLevels.set_bits()) { +        if (SJ > CommonLevels) +          break; +        updateDirection(Result.DV[SJ - 1], Constraints[SJ]); +        if (Result.DV[SJ - 1].Direction == Dependence::DVEntry::NONE) +          return nullptr; +      } +    } +  } + +  // Make sure the Scalar flags are set correctly. +  SmallBitVector CompleteLoops(MaxLevels + 1); +  for (unsigned SI = 0; SI < Pairs; ++SI) +    CompleteLoops |= Pair[SI].Loops; +  for (unsigned II = 1; II <= CommonLevels; ++II) +    if (CompleteLoops[II]) +      Result.DV[II - 1].Scalar = false; + +  if (PossiblyLoopIndependent) { +    // Make sure the LoopIndependent flag is set correctly. +    // All directions must include equal, otherwise no +    // loop-independent dependence is possible. +    for (unsigned II = 1; II <= CommonLevels; ++II) { +      if (!(Result.getDirection(II) & Dependence::DVEntry::EQ)) { +        Result.LoopIndependent = false; +        break; +      } +    } +  } +  else { +    // On the other hand, if all directions are equal and there's no +    // loop-independent dependence possible, then no dependence exists. +    bool AllEqual = true; +    for (unsigned II = 1; II <= CommonLevels; ++II) { +      if (Result.getDirection(II) != Dependence::DVEntry::EQ) { +        AllEqual = false; +        break; +      } +    } +    if (AllEqual) +      return nullptr; +  } + +  return make_unique<FullDependence>(std::move(Result)); +} + + + +//===----------------------------------------------------------------------===// +// getSplitIteration - +// Rather than spend rarely-used space recording the splitting iteration +// during the Weak-Crossing SIV test, we re-compute it on demand. +// The re-computation is basically a repeat of the entire dependence test, +// though simplified since we know that the dependence exists. +// It's tedious, since we must go through all propagations, etc. +// +// Care is required to keep this code up to date with respect to the routine +// above, depends(). +// +// Generally, the dependence analyzer will be used to build +// a dependence graph for a function (basically a map from instructions +// to dependences). Looking for cycles in the graph shows us loops +// that cannot be trivially vectorized/parallelized. +// +// We can try to improve the situation by examining all the dependences +// that make up the cycle, looking for ones we can break. +// Sometimes, peeling the first or last iteration of a loop will break +// dependences, and we've got flags for those possibilities. +// Sometimes, splitting a loop at some other iteration will do the trick, +// and we've got a flag for that case. Rather than waste the space to +// record the exact iteration (since we rarely know), we provide +// a method that calculates the iteration. 
It's a drag that it must work +// from scratch, but wonderful in that it's possible. +// +// Here's an example: +// +//    for (i = 0; i < 10; i++) +//        A[i] = ... +//        ... = A[11 - i] +// +// There's a loop-carried flow dependence from the store to the load, +// found by the weak-crossing SIV test. The dependence will have a flag, +// indicating that the dependence can be broken by splitting the loop. +// Calling getSplitIteration will return 5. +// Splitting the loop breaks the dependence, like so: +// +//    for (i = 0; i <= 5; i++) +//        A[i] = ... +//        ... = A[11 - i] +//    for (i = 6; i < 10; i++) +//        A[i] = ... +//        ... = A[11 - i] +// +// breaks the dependence and allows us to vectorize/parallelize +// both loops. +const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, +                                              unsigned SplitLevel) { +  assert(Dep.isSplitable(SplitLevel) && +         "Dep should be splitable at SplitLevel"); +  Instruction *Src = Dep.getSrc(); +  Instruction *Dst = Dep.getDst(); +  assert(Src->mayReadFromMemory() || Src->mayWriteToMemory()); +  assert(Dst->mayReadFromMemory() || Dst->mayWriteToMemory()); +  assert(isLoadOrStore(Src)); +  assert(isLoadOrStore(Dst)); +  Value *SrcPtr = getLoadStorePointerOperand(Src); +  Value *DstPtr = getLoadStorePointerOperand(Dst); +  assert(underlyingObjectsAlias(AA, F->getParent()->getDataLayout(), +                                MemoryLocation::get(Dst), +                                MemoryLocation::get(Src)) == MustAlias); + +  // establish loop nesting levels +  establishNestingLevels(Src, Dst); + +  FullDependence Result(Src, Dst, false, CommonLevels); + +  unsigned Pairs = 1; +  SmallVector<Subscript, 2> Pair(Pairs); +  const SCEV *SrcSCEV = SE->getSCEV(SrcPtr); +  const SCEV *DstSCEV = SE->getSCEV(DstPtr); +  Pair[0].Src = SrcSCEV; +  Pair[0].Dst = DstSCEV; + +  if (Delinearize) { +    if (tryDelinearize(Src, Dst, Pair)) { +      LLVM_DEBUG(dbgs() << "    delinearized\n"); +      Pairs = Pair.size(); +    } +  } + +  for (unsigned P = 0; P < Pairs; ++P) { +    Pair[P].Loops.resize(MaxLevels + 1); +    Pair[P].GroupLoops.resize(MaxLevels + 1); +    Pair[P].Group.resize(Pairs); +    removeMatchingExtensions(&Pair[P]); +    Pair[P].Classification = +      classifyPair(Pair[P].Src, LI->getLoopFor(Src->getParent()), +                   Pair[P].Dst, LI->getLoopFor(Dst->getParent()), +                   Pair[P].Loops); +    Pair[P].GroupLoops = Pair[P].Loops; +    Pair[P].Group.set(P); +  } + +  SmallBitVector Separable(Pairs); +  SmallBitVector Coupled(Pairs); + +  // partition subscripts into separable and minimally-coupled groups +  for (unsigned SI = 0; SI < Pairs; ++SI) { +    if (Pair[SI].Classification == Subscript::NonLinear) { +      // ignore these, but collect loops for later +      collectCommonLoops(Pair[SI].Src, +                         LI->getLoopFor(Src->getParent()), +                         Pair[SI].Loops); +      collectCommonLoops(Pair[SI].Dst, +                         LI->getLoopFor(Dst->getParent()), +                         Pair[SI].Loops); +      Result.Consistent = false; +    } +    else if (Pair[SI].Classification == Subscript::ZIV) +      Separable.set(SI); +    else { +      // SIV, RDIV, or MIV, so check for coupled group +      bool Done = true; +      for (unsigned SJ = SI + 1; SJ < Pairs; ++SJ) { +        SmallBitVector Intersection = Pair[SI].GroupLoops; +        Intersection &= Pair[SJ].GroupLoops; +        if (Intersection.any()) { 
+          // accumulate set of all the loops in group +          Pair[SJ].GroupLoops |= Pair[SI].GroupLoops; +          // accumulate set of all subscripts in group +          Pair[SJ].Group |= Pair[SI].Group; +          Done = false; +        } +      } +      if (Done) { +        if (Pair[SI].Group.count() == 1) +          Separable.set(SI); +        else +          Coupled.set(SI); +      } +    } +  } + +  Constraint NewConstraint; +  NewConstraint.setAny(SE); + +  // test separable subscripts +  for (unsigned SI : Separable.set_bits()) { +    switch (Pair[SI].Classification) { +    case Subscript::SIV: { +      unsigned Level; +      const SCEV *SplitIter = nullptr; +      (void) testSIV(Pair[SI].Src, Pair[SI].Dst, Level, +                     Result, NewConstraint, SplitIter); +      if (Level == SplitLevel) { +        assert(SplitIter != nullptr); +        return SplitIter; +      } +      break; +    } +    case Subscript::ZIV: +    case Subscript::RDIV: +    case Subscript::MIV: +      break; +    default: +      llvm_unreachable("subscript has unexpected classification"); +    } +  } + +  if (Coupled.count()) { +    // test coupled subscript groups +    SmallVector<Constraint, 4> Constraints(MaxLevels + 1); +    for (unsigned II = 0; II <= MaxLevels; ++II) +      Constraints[II].setAny(SE); +    for (unsigned SI : Coupled.set_bits()) { +      SmallBitVector Group(Pair[SI].Group); +      SmallBitVector Sivs(Pairs); +      SmallBitVector Mivs(Pairs); +      SmallBitVector ConstrainedLevels(MaxLevels + 1); +      for (unsigned SJ : Group.set_bits()) { +        if (Pair[SJ].Classification == Subscript::SIV) +          Sivs.set(SJ); +        else +          Mivs.set(SJ); +      } +      while (Sivs.any()) { +        bool Changed = false; +        for (unsigned SJ : Sivs.set_bits()) { +          // SJ is an SIV subscript that's part of the current coupled group +          unsigned Level; +          const SCEV *SplitIter = nullptr; +          (void) testSIV(Pair[SJ].Src, Pair[SJ].Dst, Level, +                         Result, NewConstraint, SplitIter); +          if (Level == SplitLevel && SplitIter) +            return SplitIter; +          ConstrainedLevels.set(Level); +          if (intersectConstraints(&Constraints[Level], &NewConstraint)) +            Changed = true; +          Sivs.reset(SJ); +        } +        if (Changed) { +          // propagate, possibly creating new SIVs and ZIVs +          for (unsigned SJ : Mivs.set_bits()) { +            // SJ is an MIV subscript that's part of the current coupled group +            if (propagate(Pair[SJ].Src, Pair[SJ].Dst, +                          Pair[SJ].Loops, Constraints, Result.Consistent)) { +              Pair[SJ].Classification = +                classifyPair(Pair[SJ].Src, LI->getLoopFor(Src->getParent()), +                             Pair[SJ].Dst, LI->getLoopFor(Dst->getParent()), +                             Pair[SJ].Loops); +              switch (Pair[SJ].Classification) { +              case Subscript::ZIV: +                Mivs.reset(SJ); +                break; +              case Subscript::SIV: +                Sivs.set(SJ); +                Mivs.reset(SJ); +                break; +              case Subscript::RDIV: +              case Subscript::MIV: +                break; +              default: +                llvm_unreachable("bad subscript classification"); +              } +            } +          } +        } +      } +    } +  } +  llvm_unreachable("somehow reached end of routine"); +  return nullptr; +} 
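+
+// A minimal, illustrative sketch (not code from this file) of how a client
+// pass might drive the two entry points above; "Src", "Dst", and the
+// surrounding pass boilerplate are assumptions of the sketch:
+//
+//   DependenceInfo &DI = getAnalysis<DependenceAnalysisWrapperPass>().getDI();
+//   if (auto D = DI.depends(Src, Dst, /*PossiblyLoopIndependent=*/true)) {
+//     for (unsigned Level = 1; Level <= D->getLevels(); ++Level)
+//       if (D->isSplitable(Level)) {
+//         const SCEV *Iter = DI.getSplitIteration(*D, Level);
+//         // ... split the loop at Level around Iter to break the dependence
+//       }
+//   }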
diff --git a/contrib/llvm/lib/Analysis/DivergenceAnalysis.cpp b/contrib/llvm/lib/Analysis/DivergenceAnalysis.cpp
new file mode 100644
index 000000000000..f5f1874c9303
--- /dev/null
+++ b/contrib/llvm/lib/Analysis/DivergenceAnalysis.cpp
@@ -0,0 +1,340 @@
+//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements divergence analysis, which determines whether a branch
+// in a GPU program is divergent. It can help branch optimizations such as jump
+// threading and loop unswitching to make better decisions.
+//
+// GPU programs typically use the SIMD execution model, where multiple threads
+// in the same execution group have to execute in lock-step. Therefore, if the
+// code contains divergent branches (i.e., threads in a group do not agree on
+// which path of the branch to take), the group of threads has to execute all
+// the paths from that branch with different subsets of threads enabled until
+// they converge at the immediately post-dominating BB of the paths.
+//
+// Due to this execution model, some optimizations such as jump
+// threading and loop unswitching can unfortunately be harmful when performed
+// on divergent branches. Therefore, an analysis that computes which branches
+// in a GPU program are divergent can help the compiler to selectively run
+// these optimizations.
+//
+// This file defines divergence analysis which computes a conservative but
+// non-trivial approximation of all divergent branches in a GPU program. It
+// partially implements the approach described in
+//
+//   Divergence Analysis
+//   Sampaio, Souza, Collange, Pereira
+//   TOPLAS '13
+//
+// The divergence analysis identifies the sources of divergence (e.g., special
+// variables that hold the thread ID), and recursively marks variables that are
+// data or sync dependent on a source of divergence as divergent.
+//
+// While data dependency is a well-known concept, the notion of sync dependency
+// is worth more explanation. Sync dependence characterizes the control flow
+// aspect of the propagation of branch divergence. For example,
+//
+//   %cond = icmp slt i32 %tid, 10
+//   br i1 %cond, label %then, label %else
+// then:
+//   br label %merge
+// else:
+//   br label %merge
+// merge:
+//   %a = phi i32 [ 0, %then ], [ 1, %else ]
+//
+// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid
+// because %tid is not on its use-def chains, %a is sync dependent on %tid
+// because the branch "br i1 %cond" depends on %tid and affects which value %a
+// is assigned to.
+//
+// The current implementation has the following limitations:
+// 1. intra-procedural. It conservatively considers the arguments of a
+//    non-kernel-entry function and the return value of a function call as
+//    divergent.
+// 2. memory as black box. It conservatively considers values loaded from
+//    generic or local addresses as divergent. This can be improved by
+//    leveraging pointer analysis.
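+//
+// An illustrative query from a client transform using the legacy pass defined
+// below ("BB" and the surrounding pass boilerplate are assumptions of the
+// sketch):
+//
+//   DivergenceAnalysis &DA = getAnalysis<DivergenceAnalysis>();
+//   auto *BI = dyn_cast<BranchInst>(BB.getTerminator());
+//   if (BI && BI->isConditional() && DA.isDivergent(BI->getCondition())) {
+//     // Treat the branch as divergent, e.g. skip jump threading on it.
+//   }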
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <vector> +using namespace llvm; + +#define DEBUG_TYPE "divergence" + +namespace { + +class DivergencePropagator { +public: +  DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT, +                       PostDominatorTree &PDT, DenseSet<const Value *> &DV) +      : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {} +  void populateWithSourcesOfDivergence(); +  void propagate(); + +private: +  // A helper function that explores data dependents of V. +  void exploreDataDependency(Value *V); +  // A helper function that explores sync dependents of TI. +  void exploreSyncDependency(TerminatorInst *TI); +  // Computes the influence region from Start to End. This region includes all +  // basic blocks on any simple path from Start to End. +  void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End, +                              DenseSet<BasicBlock *> &InfluenceRegion); +  // Finds all users of I that are outside the influence region, and add these +  // users to Worklist. +  void findUsersOutsideInfluenceRegion( +      Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion); + +  Function &F; +  TargetTransformInfo &TTI; +  DominatorTree &DT; +  PostDominatorTree &PDT; +  std::vector<Value *> Worklist; // Stack for DFS. +  DenseSet<const Value *> &DV;   // Stores all divergent values. +}; + +void DivergencePropagator::populateWithSourcesOfDivergence() { +  Worklist.clear(); +  DV.clear(); +  for (auto &I : instructions(F)) { +    if (TTI.isSourceOfDivergence(&I)) { +      Worklist.push_back(&I); +      DV.insert(&I); +    } +  } +  for (auto &Arg : F.args()) { +    if (TTI.isSourceOfDivergence(&Arg)) { +      Worklist.push_back(&Arg); +      DV.insert(&Arg); +    } +  } +} + +void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) { +  // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's +  // immediate post dominator are divergent. This rule handles if-then-else +  // patterns. For example, +  // +  // if (tid < 5) +  //   a1 = 1; +  // else +  //   a2 = 2; +  // a = phi(a1, a2); // sync dependent on (tid < 5) +  BasicBlock *ThisBB = TI->getParent(); + +  // Unreachable blocks may not be in the dominator tree. +  if (!DT.isReachableFromEntry(ThisBB)) +    return; + +  // If the function has no exit blocks or doesn't reach any exit blocks, the +  // post dominator may be null. +  DomTreeNode *ThisNode = PDT.getNode(ThisBB); +  if (!ThisNode) +    return; + +  BasicBlock *IPostDom = ThisNode->getIDom()->getBlock(); +  if (IPostDom == nullptr) +    return; + +  for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) { +    // A PHINode is uniform if it returns the same value no matter which path is +    // taken. +    if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second) +      Worklist.push_back(&*I); +  } + +  // Propagation rule 2: if a value defined in a loop is used outside, the user +  // is sync dependent on the condition of the loop exits that dominate the +  // user. 
For example, +  // +  // int i = 0; +  // do { +  //   i++; +  //   if (foo(i)) ... // uniform +  // } while (i < tid); +  // if (bar(i)) ...   // divergent +  // +  // A program may contain unstructured loops. Therefore, we cannot leverage +  // LoopInfo, which only recognizes natural loops. +  // +  // The algorithm used here handles both natural and unstructured loops.  Given +  // a branch TI, we first compute its influence region, the union of all simple +  // paths from TI to its immediate post dominator (IPostDom). Then, we search +  // for all the values defined in the influence region but used outside. All +  // these users are sync dependent on TI. +  DenseSet<BasicBlock *> InfluenceRegion; +  computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion); +  // An insight that can speed up the search process is that all the in-region +  // values that are used outside must dominate TI. Therefore, instead of +  // searching every basic blocks in the influence region, we search all the +  // dominators of TI until it is outside the influence region. +  BasicBlock *InfluencedBB = ThisBB; +  while (InfluenceRegion.count(InfluencedBB)) { +    for (auto &I : *InfluencedBB) +      findUsersOutsideInfluenceRegion(I, InfluenceRegion); +    DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom(); +    if (IDomNode == nullptr) +      break; +    InfluencedBB = IDomNode->getBlock(); +  } +} + +void DivergencePropagator::findUsersOutsideInfluenceRegion( +    Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) { +  for (User *U : I.users()) { +    Instruction *UserInst = cast<Instruction>(U); +    if (!InfluenceRegion.count(UserInst->getParent())) { +      if (DV.insert(UserInst).second) +        Worklist.push_back(UserInst); +    } +  } +} + +// A helper function for computeInfluenceRegion that adds successors of "ThisBB" +// to the influence region. +static void +addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End, +                               DenseSet<BasicBlock *> &InfluenceRegion, +                               std::vector<BasicBlock *> &InfluenceStack) { +  for (BasicBlock *Succ : successors(ThisBB)) { +    if (Succ != End && InfluenceRegion.insert(Succ).second) +      InfluenceStack.push_back(Succ); +  } +} + +void DivergencePropagator::computeInfluenceRegion( +    BasicBlock *Start, BasicBlock *End, +    DenseSet<BasicBlock *> &InfluenceRegion) { +  assert(PDT.properlyDominates(End, Start) && +         "End does not properly dominate Start"); + +  // The influence region starts from the end of "Start" to the beginning of +  // "End". Therefore, "Start" should not be in the region unless "Start" is in +  // a loop that doesn't contain "End". +  std::vector<BasicBlock *> InfluenceStack; +  addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack); +  while (!InfluenceStack.empty()) { +    BasicBlock *BB = InfluenceStack.back(); +    InfluenceStack.pop_back(); +    addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack); +  } +} + +void DivergencePropagator::exploreDataDependency(Value *V) { +  // Follow def-use chains of V. +  for (User *U : V->users()) { +    Instruction *UserInst = cast<Instruction>(U); +    if (!TTI.isAlwaysUniform(U) && DV.insert(UserInst).second) +      Worklist.push_back(UserInst); +  } +} + +void DivergencePropagator::propagate() { +  // Traverse the dependency graph using DFS. 
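+  // Invariant: every value on the worklist is already recorded in DV. For
+  // each popped value, sync-dependent users are explored when the value is a
+  // terminator with more than one successor, and data-dependent users are
+  // always explored; newly discovered divergent values are pushed until a
+  // fixed point is reached.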
+  while (!Worklist.empty()) { +    Value *V = Worklist.back(); +    Worklist.pop_back(); +    if (TerminatorInst *TI = dyn_cast<TerminatorInst>(V)) { +      // Terminators with less than two successors won't introduce sync +      // dependency. Ignore them. +      if (TI->getNumSuccessors() > 1) +        exploreSyncDependency(TI); +    } +    exploreDataDependency(V); +  } +} + +} /// end namespace anonymous + +// Register this pass. +char DivergenceAnalysis::ID = 0; +INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis", +                      false, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis", +                    false, true) + +FunctionPass *llvm::createDivergenceAnalysisPass() { +  return new DivergenceAnalysis(); +} + +void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<PostDominatorTreeWrapperPass>(); +  AU.setPreservesAll(); +} + +bool DivergenceAnalysis::runOnFunction(Function &F) { +  auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); +  if (TTIWP == nullptr) +    return false; + +  TargetTransformInfo &TTI = TTIWP->getTTI(F); +  // Fast path: if the target does not have branch divergence, we do not mark +  // any branch as divergent. +  if (!TTI.hasBranchDivergence()) +    return false; + +  DivergentValues.clear(); +  auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); +  DivergencePropagator DP(F, TTI, +                          getAnalysis<DominatorTreeWrapperPass>().getDomTree(), +                          PDT, DivergentValues); +  DP.populateWithSourcesOfDivergence(); +  DP.propagate(); +  LLVM_DEBUG( +    dbgs() << "\nAfter divergence analysis on " << F.getName() << ":\n"; +    print(dbgs(), F.getParent()) +  ); +  return false; +} + +void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const { +  if (DivergentValues.empty()) +    return; +  const Value *FirstDivergentValue = *DivergentValues.begin(); +  const Function *F; +  if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) { +    F = Arg->getParent(); +  } else if (const Instruction *I = +                 dyn_cast<Instruction>(FirstDivergentValue)) { +    F = I->getParent()->getParent(); +  } else { +    llvm_unreachable("Only arguments and instructions can be divergent"); +  } + +  // Dumps all divergent values in F, arguments and then instructions. +  for (auto &Arg : F->args()) { +    OS << (DivergentValues.count(&Arg) ? "DIVERGENT: " : "           "); +    OS << Arg << "\n"; +  } +  // Iterate instructions using instructions() to ensure a deterministic order. +  for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) { +    auto &BB = *BI; +    OS << "\n           " << BB.getName() << ":\n"; +    for (auto &I : BB.instructionsWithoutDebug()) { +      OS << (DivergentValues.count(&I) ? 
"DIVERGENT:     " : "               "); +      OS << I << "\n"; +    } +  } +  OS << "\n"; +} diff --git a/contrib/llvm/lib/Analysis/DomPrinter.cpp b/contrib/llvm/lib/Analysis/DomPrinter.cpp new file mode 100644 index 000000000000..8abc0e7d0df9 --- /dev/null +++ b/contrib/llvm/lib/Analysis/DomPrinter.cpp @@ -0,0 +1,298 @@ +//===- DomPrinter.cpp - DOT printer for the dominance trees    ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines '-dot-dom' and '-dot-postdom' analysis passes, which emit +// a dom.<fnname>.dot or postdom.<fnname>.dot file for each function in the +// program, with a graph of the dominance/postdominance tree of that +// function. +// +// There are also passes available to directly call dotty ('-view-dom' or +// '-view-postdom'). By appending '-only' like '-dot-dom-only' only the +// names of the bbs are printed, but the content is hidden. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DomPrinter.h" +#include "llvm/Analysis/DOTGraphTraitsPass.h" +#include "llvm/Analysis/PostDominators.h" + +using namespace llvm; + +namespace llvm { +template<> +struct DOTGraphTraits<DomTreeNode*> : public DefaultDOTGraphTraits { + +  DOTGraphTraits (bool isSimple=false) +    : DefaultDOTGraphTraits(isSimple) {} + +  std::string getNodeLabel(DomTreeNode *Node, DomTreeNode *Graph) { + +    BasicBlock *BB = Node->getBlock(); + +    if (!BB) +      return "Post dominance root node"; + + +    if (isSimple()) +      return DOTGraphTraits<const Function*> +        ::getSimpleNodeLabel(BB, BB->getParent()); +    else +      return DOTGraphTraits<const Function*> +        ::getCompleteNodeLabel(BB, BB->getParent()); +  } +}; + +template<> +struct DOTGraphTraits<DominatorTree*> : public DOTGraphTraits<DomTreeNode*> { + +  DOTGraphTraits (bool isSimple=false) +    : DOTGraphTraits<DomTreeNode*>(isSimple) {} + +  static std::string getGraphName(DominatorTree *DT) { +    return "Dominator tree"; +  } + +  std::string getNodeLabel(DomTreeNode *Node, DominatorTree *G) { +    return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode()); +  } +}; + +template<> +struct DOTGraphTraits<PostDominatorTree*> +  : public DOTGraphTraits<DomTreeNode*> { + +  DOTGraphTraits (bool isSimple=false) +    : DOTGraphTraits<DomTreeNode*>(isSimple) {} + +  static std::string getGraphName(PostDominatorTree *DT) { +    return "Post dominator tree"; +  } + +  std::string getNodeLabel(DomTreeNode *Node, PostDominatorTree *G ) { +    return DOTGraphTraits<DomTreeNode*>::getNodeLabel(Node, G->getRootNode()); +  } +}; +} + +void DominatorTree::viewGraph(const Twine &Name, const Twine &Title) { +#ifndef NDEBUG +  ViewGraph(this, Name, false, Title); +#else +  errs() << "DomTree dump not available, build with DEBUG\n"; +#endif  // NDEBUG +} + +void DominatorTree::viewGraph() { +#ifndef NDEBUG +  this->viewGraph("domtree", "Dominator Tree for function"); +#else +  errs() << "DomTree dump not available, build with DEBUG\n"; +#endif  // NDEBUG +} + +namespace { +struct DominatorTreeWrapperPassAnalysisGraphTraits { +  static DominatorTree *getGraph(DominatorTreeWrapperPass *DTWP) { +    return &DTWP->getDomTree(); +  } +}; + +struct DomViewer : public DOTGraphTraitsViewer< +                       
DominatorTreeWrapperPass, false, DominatorTree *, +                       DominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  DomViewer() +      : DOTGraphTraitsViewer<DominatorTreeWrapperPass, false, DominatorTree *, +                             DominatorTreeWrapperPassAnalysisGraphTraits>( +            "dom", ID) { +    initializeDomViewerPass(*PassRegistry::getPassRegistry()); +  } +}; + +struct DomOnlyViewer : public DOTGraphTraitsViewer< +                           DominatorTreeWrapperPass, true, DominatorTree *, +                           DominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  DomOnlyViewer() +      : DOTGraphTraitsViewer<DominatorTreeWrapperPass, true, DominatorTree *, +                             DominatorTreeWrapperPassAnalysisGraphTraits>( +            "domonly", ID) { +    initializeDomOnlyViewerPass(*PassRegistry::getPassRegistry()); +  } +}; + +struct PostDominatorTreeWrapperPassAnalysisGraphTraits { +  static PostDominatorTree *getGraph(PostDominatorTreeWrapperPass *PDTWP) { +    return &PDTWP->getPostDomTree(); +  } +}; + +struct PostDomViewer : public DOTGraphTraitsViewer< +                          PostDominatorTreeWrapperPass, false, +                          PostDominatorTree *, +                          PostDominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  PostDomViewer() : +    DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, false, +                         PostDominatorTree *, +                         PostDominatorTreeWrapperPassAnalysisGraphTraits>( +        "postdom", ID){ +      initializePostDomViewerPass(*PassRegistry::getPassRegistry()); +    } +}; + +struct PostDomOnlyViewer : public DOTGraphTraitsViewer< +                            PostDominatorTreeWrapperPass, true, +                            PostDominatorTree *, +                            PostDominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  PostDomOnlyViewer() : +    DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, true, +                         PostDominatorTree *, +                         PostDominatorTreeWrapperPassAnalysisGraphTraits>( +        "postdomonly", ID){ +      initializePostDomOnlyViewerPass(*PassRegistry::getPassRegistry()); +    } +}; +} // end anonymous namespace + +char DomViewer::ID = 0; +INITIALIZE_PASS(DomViewer, "view-dom", +                "View dominance tree of function", false, false) + +char DomOnlyViewer::ID = 0; +INITIALIZE_PASS(DomOnlyViewer, "view-dom-only", +                "View dominance tree of function (with no function bodies)", +                false, false) + +char PostDomViewer::ID = 0; +INITIALIZE_PASS(PostDomViewer, "view-postdom", +                "View postdominance tree of function", false, false) + +char PostDomOnlyViewer::ID = 0; +INITIALIZE_PASS(PostDomOnlyViewer, "view-postdom-only", +                "View postdominance tree of function " +                "(with no function bodies)", +                false, false) + +namespace { +struct DomPrinter : public DOTGraphTraitsPrinter< +                        DominatorTreeWrapperPass, false, DominatorTree *, +                        DominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  DomPrinter() +      : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, false, DominatorTree *, +                              DominatorTreeWrapperPassAnalysisGraphTraits>( +            "dom", ID) { +    initializeDomPrinterPass(*PassRegistry::getPassRegistry()); +  } +}; + +struct DomOnlyPrinter : public 
DOTGraphTraitsPrinter< +                            DominatorTreeWrapperPass, true, DominatorTree *, +                            DominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  DomOnlyPrinter() +      : DOTGraphTraitsPrinter<DominatorTreeWrapperPass, true, DominatorTree *, +                              DominatorTreeWrapperPassAnalysisGraphTraits>( +            "domonly", ID) { +    initializeDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); +  } +}; + +struct PostDomPrinter +  : public DOTGraphTraitsPrinter< +                            PostDominatorTreeWrapperPass, false, +                            PostDominatorTree *, +                            PostDominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  PostDomPrinter() : +    DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, false, +                          PostDominatorTree *, +                          PostDominatorTreeWrapperPassAnalysisGraphTraits>( +        "postdom", ID) { +      initializePostDomPrinterPass(*PassRegistry::getPassRegistry()); +    } +}; + +struct PostDomOnlyPrinter +  : public DOTGraphTraitsPrinter< +                            PostDominatorTreeWrapperPass, true, +                            PostDominatorTree *, +                            PostDominatorTreeWrapperPassAnalysisGraphTraits> { +  static char ID; +  PostDomOnlyPrinter() : +    DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, true, +                          PostDominatorTree *, +                          PostDominatorTreeWrapperPassAnalysisGraphTraits>( +        "postdomonly", ID) { +      initializePostDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); +    } +}; +} // end anonymous namespace + + + +char DomPrinter::ID = 0; +INITIALIZE_PASS(DomPrinter, "dot-dom", +                "Print dominance tree of function to 'dot' file", +                false, false) + +char DomOnlyPrinter::ID = 0; +INITIALIZE_PASS(DomOnlyPrinter, "dot-dom-only", +                "Print dominance tree of function to 'dot' file " +                "(with no function bodies)", +                false, false) + +char PostDomPrinter::ID = 0; +INITIALIZE_PASS(PostDomPrinter, "dot-postdom", +                "Print postdominance tree of function to 'dot' file", +                false, false) + +char PostDomOnlyPrinter::ID = 0; +INITIALIZE_PASS(PostDomOnlyPrinter, "dot-postdom-only", +                "Print postdominance tree of function to 'dot' file " +                "(with no function bodies)", +                false, false) + +// Create methods available outside of this file, to use them +// "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by +// the link time optimization. 
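+//
+// For reference, the printer/viewer passes registered above can be invoked by
+// their argument names, e.g. (illustrative):
+//
+//   opt -dot-dom -disable-output input.ll          // writes dom.<fnname>.dot
+//   opt -dot-postdom-only -disable-output input.ll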
+ +FunctionPass *llvm::createDomPrinterPass() { +  return new DomPrinter(); +} + +FunctionPass *llvm::createDomOnlyPrinterPass() { +  return new DomOnlyPrinter(); +} + +FunctionPass *llvm::createDomViewerPass() { +  return new DomViewer(); +} + +FunctionPass *llvm::createDomOnlyViewerPass() { +  return new DomOnlyViewer(); +} + +FunctionPass *llvm::createPostDomPrinterPass() { +  return new PostDomPrinter(); +} + +FunctionPass *llvm::createPostDomOnlyPrinterPass() { +  return new PostDomOnlyPrinter(); +} + +FunctionPass *llvm::createPostDomViewerPass() { +  return new PostDomViewer(); +} + +FunctionPass *llvm::createPostDomOnlyViewerPass() { +  return new PostDomOnlyViewer(); +} diff --git a/contrib/llvm/lib/Analysis/DominanceFrontier.cpp b/contrib/llvm/lib/Analysis/DominanceFrontier.cpp new file mode 100644 index 000000000000..de7f62cf4ecd --- /dev/null +++ b/contrib/llvm/lib/Analysis/DominanceFrontier.cpp @@ -0,0 +1,97 @@ +//===- DominanceFrontier.cpp - Dominance Frontier Calculation -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/DominanceFrontier.h" +#include "llvm/Analysis/DominanceFrontierImpl.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +namespace llvm { + +template class DominanceFrontierBase<BasicBlock, false>; +template class DominanceFrontierBase<BasicBlock, true>; +template class ForwardDominanceFrontierBase<BasicBlock>; + +} // end namespace llvm + +char DominanceFrontierWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(DominanceFrontierWrapperPass, "domfrontier", +                "Dominance Frontier Construction", true, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(DominanceFrontierWrapperPass, "domfrontier", +                "Dominance Frontier Construction", true, true) + +DominanceFrontierWrapperPass::DominanceFrontierWrapperPass() +    : FunctionPass(ID), DF() { +  initializeDominanceFrontierWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void DominanceFrontierWrapperPass::releaseMemory() { +  DF.releaseMemory(); +} + +bool DominanceFrontierWrapperPass::runOnFunction(Function &) { +  releaseMemory(); +  DF.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); +  return false; +} + +void DominanceFrontierWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<DominatorTreeWrapperPass>(); +} + +void DominanceFrontierWrapperPass::print(raw_ostream &OS, const Module *) const { +  DF.print(OS); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void DominanceFrontierWrapperPass::dump() const { +  print(dbgs()); +} +#endif + +/// Handle invalidation explicitly. +bool DominanceFrontier::invalidate(Function &F, const PreservedAnalyses &PA, +                                   FunctionAnalysisManager::Invalidator &) { +  // Check whether the analysis, all analyses on functions, or the function's +  // CFG have been preserved. 
+  auto PAC = PA.getChecker<DominanceFrontierAnalysis>(); +  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || +           PAC.preservedSet<CFGAnalyses>()); +} + +AnalysisKey DominanceFrontierAnalysis::Key; + +DominanceFrontier DominanceFrontierAnalysis::run(Function &F, +                                                 FunctionAnalysisManager &AM) { +  DominanceFrontier DF; +  DF.analyze(AM.getResult<DominatorTreeAnalysis>(F)); +  return DF; +} + +DominanceFrontierPrinterPass::DominanceFrontierPrinterPass(raw_ostream &OS) +  : OS(OS) {} + +PreservedAnalyses +DominanceFrontierPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { +  OS << "DominanceFrontier for function: " << F.getName() << "\n"; +  AM.getResult<DominanceFrontierAnalysis>(F).print(OS); + +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/EHPersonalities.cpp b/contrib/llvm/lib/Analysis/EHPersonalities.cpp new file mode 100644 index 000000000000..2d35a3fa9118 --- /dev/null +++ b/contrib/llvm/lib/Analysis/EHPersonalities.cpp @@ -0,0 +1,136 @@ +//===- EHPersonalities.cpp - Compute EH-related information ---------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/EHPersonalities.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +/// See if the given exception handling personality function is one that we +/// understand.  If so, return a description of it; otherwise return Unknown. +EHPersonality llvm::classifyEHPersonality(const Value *Pers) { +  const Function *F = +      Pers ? 
dyn_cast<Function>(Pers->stripPointerCasts()) : nullptr; +  if (!F) +    return EHPersonality::Unknown; +  return StringSwitch<EHPersonality>(F->getName()) +    .Case("__gnat_eh_personality",     EHPersonality::GNU_Ada) +    .Case("__gxx_personality_v0",      EHPersonality::GNU_CXX) +    .Case("__gxx_personality_seh0",    EHPersonality::GNU_CXX) +    .Case("__gxx_personality_sj0",     EHPersonality::GNU_CXX_SjLj) +    .Case("__gcc_personality_v0",      EHPersonality::GNU_C) +    .Case("__gcc_personality_seh0",    EHPersonality::GNU_C) +    .Case("__gcc_personality_sj0",     EHPersonality::GNU_C_SjLj) +    .Case("__objc_personality_v0",     EHPersonality::GNU_ObjC) +    .Case("_except_handler3",          EHPersonality::MSVC_X86SEH) +    .Case("_except_handler4",          EHPersonality::MSVC_X86SEH) +    .Case("__C_specific_handler",      EHPersonality::MSVC_Win64SEH) +    .Case("__CxxFrameHandler3",        EHPersonality::MSVC_CXX) +    .Case("ProcessCLRException",       EHPersonality::CoreCLR) +    .Case("rust_eh_personality",       EHPersonality::Rust) +    .Case("__gxx_wasm_personality_v0", EHPersonality::Wasm_CXX) +    .Default(EHPersonality::Unknown); +} + +StringRef llvm::getEHPersonalityName(EHPersonality Pers) { +  switch (Pers) { +  case EHPersonality::GNU_Ada:       return "__gnat_eh_personality"; +  case EHPersonality::GNU_CXX:       return "__gxx_personality_v0"; +  case EHPersonality::GNU_CXX_SjLj:  return "__gxx_personality_sj0"; +  case EHPersonality::GNU_C:         return "__gcc_personality_v0"; +  case EHPersonality::GNU_C_SjLj:    return "__gcc_personality_sj0"; +  case EHPersonality::GNU_ObjC:      return "__objc_personality_v0"; +  case EHPersonality::MSVC_X86SEH:   return "_except_handler3"; +  case EHPersonality::MSVC_Win64SEH: return "__C_specific_handler"; +  case EHPersonality::MSVC_CXX:      return "__CxxFrameHandler3"; +  case EHPersonality::CoreCLR:       return "ProcessCLRException"; +  case EHPersonality::Rust:          return "rust_eh_personality"; +  case EHPersonality::Wasm_CXX:      return "__gxx_wasm_personality_v0"; +  case EHPersonality::Unknown:       llvm_unreachable("Unknown EHPersonality!"); +  } + +  llvm_unreachable("Invalid EHPersonality!"); +} + +EHPersonality llvm::getDefaultEHPersonality(const Triple &T) { +  return EHPersonality::GNU_C; +} + +bool llvm::canSimplifyInvokeNoUnwind(const Function *F) { +  EHPersonality Personality = classifyEHPersonality(F->getPersonalityFn()); +  // We can't simplify any invokes to nounwind functions if the personality +  // function wants to catch asynch exceptions.  The nounwind attribute only +  // implies that the function does not throw synchronous exceptions. +  return !isAsynchronousEHPersonality(Personality); +} + +DenseMap<BasicBlock *, ColorVector> llvm::colorEHFunclets(Function &F) { +  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 16> Worklist; +  BasicBlock *EntryBlock = &F.getEntryBlock(); +  DenseMap<BasicBlock *, ColorVector> BlockColors; + +  // Build up the color map, which maps each block to its set of 'colors'. +  // For any block B the "colors" of B are the set of funclets F (possibly +  // including a root "funclet" representing the main function) such that +  // F will need to directly contain B or a copy of B (where the term "directly +  // contain" is used to distinguish from being "transitively contained" in +  // a nested funclet). 
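+  //
+  // For example (illustrative), for MSVC-style code such as
+  //
+  //   void f() {
+  //     __try { g(); } __finally { h(); }
+  //   }
+  //
+  // the blocks lowered for the __finally body are colored by the cleanuppad
+  // that starts them, while the blocks on the normal path keep the entry
+  // block's ("root funclet") color.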
+  // +  // Note: Despite not being a funclet in the truest sense, a catchswitch is +  // considered to belong to its own funclet for the purposes of coloring. + +  DEBUG_WITH_TYPE("winehprepare-coloring", dbgs() << "\nColoring funclets for " +                                                  << F.getName() << "\n"); + +  Worklist.push_back({EntryBlock, EntryBlock}); + +  while (!Worklist.empty()) { +    BasicBlock *Visiting; +    BasicBlock *Color; +    std::tie(Visiting, Color) = Worklist.pop_back_val(); +    DEBUG_WITH_TYPE("winehprepare-coloring", +                    dbgs() << "Visiting " << Visiting->getName() << ", " +                           << Color->getName() << "\n"); +    Instruction *VisitingHead = Visiting->getFirstNonPHI(); +    if (VisitingHead->isEHPad()) { +      // Mark this funclet head as a member of itself. +      Color = Visiting; +    } +    // Note that this is a member of the given color. +    ColorVector &Colors = BlockColors[Visiting]; +    if (!is_contained(Colors, Color)) +      Colors.push_back(Color); +    else +      continue; + +    DEBUG_WITH_TYPE("winehprepare-coloring", +                    dbgs() << "  Assigned color \'" << Color->getName() +                           << "\' to block \'" << Visiting->getName() +                           << "\'.\n"); + +    BasicBlock *SuccColor = Color; +    TerminatorInst *Terminator = Visiting->getTerminator(); +    if (auto *CatchRet = dyn_cast<CatchReturnInst>(Terminator)) { +      Value *ParentPad = CatchRet->getCatchSwitchParentPad(); +      if (isa<ConstantTokenNone>(ParentPad)) +        SuccColor = EntryBlock; +      else +        SuccColor = cast<Instruction>(ParentPad)->getParent(); +    } + +    for (BasicBlock *Succ : successors(Visiting)) +      Worklist.push_back({Succ, SuccColor}); +  } +  return BlockColors; +} diff --git a/contrib/llvm/lib/Analysis/GlobalsModRef.cpp b/contrib/llvm/lib/Analysis/GlobalsModRef.cpp new file mode 100644 index 000000000000..2c503609d96b --- /dev/null +++ b/contrib/llvm/lib/Analysis/GlobalsModRef.cpp @@ -0,0 +1,1014 @@ +//===- GlobalsModRef.cpp - Simple Mod/Ref Analysis for Globals ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This simple pass provides alias and mod/ref information for global values +// that do not have their address taken, and keeps track of whether functions +// read or write memory (are "pure").  For this simple (but very common) case, +// we can provide pretty accurate and useful information. 
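+//
+// For example (an illustrative sketch, not code from this file):
+//
+//   static int Counter;              // internal, address never taken
+//   void bump()       { ++Counter; }
+//   int  peek(int *P) { return *P; }
+//
+// Because Counter's address never escapes, a pointer such as P cannot point
+// at it, so the load in peek() does not alias Counter, and a function that
+// never mentions Counter is known not to read or write it.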
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +using namespace llvm; + +#define DEBUG_TYPE "globalsmodref-aa" + +STATISTIC(NumNonAddrTakenGlobalVars, +          "Number of global vars without address taken"); +STATISTIC(NumNonAddrTakenFunctions,"Number of functions without address taken"); +STATISTIC(NumNoMemFunctions, "Number of functions that do not access memory"); +STATISTIC(NumReadMemFunctions, "Number of functions that only read memory"); +STATISTIC(NumIndirectGlobalVars, "Number of indirect global objects"); + +// An option to enable unsafe alias results from the GlobalsModRef analysis. +// When enabled, GlobalsModRef will provide no-alias results which in extremely +// rare cases may not be conservatively correct. In particular, in the face of +// transforms which cause assymetry between how effective GetUnderlyingObject +// is for two pointers, it may produce incorrect results. +// +// These unsafe results have been returned by GMR for many years without +// causing significant issues in the wild and so we provide a mechanism to +// re-enable them for users of LLVM that have a particular performance +// sensitivity and no known issues. The option also makes it easy to evaluate +// the performance impact of these results. +static cl::opt<bool> EnableUnsafeGlobalsModRefAliasResults( +    "enable-unsafe-globalsmodref-alias-results", cl::init(false), cl::Hidden); + +/// The mod/ref information collected for a particular function. +/// +/// We collect information about mod/ref behavior of a function here, both in +/// general and as pertains to specific globals. We only have this detailed +/// information when we know *something* useful about the behavior. If we +/// saturate to fully general mod/ref, we remove the info for the function. +class GlobalsAAResult::FunctionInfo { +  typedef SmallDenseMap<const GlobalValue *, ModRefInfo, 16> GlobalInfoMapType; + +  /// Build a wrapper struct that has 8-byte alignment. All heap allocations +  /// should provide this much alignment at least, but this makes it clear we +  /// specifically rely on this amount of alignment. +  struct alignas(8) AlignedMap { +    AlignedMap() {} +    AlignedMap(const AlignedMap &Arg) : Map(Arg.Map) {} +    GlobalInfoMapType Map; +  }; + +  /// Pointer traits for our aligned map. +  struct AlignedMapPointerTraits { +    static inline void *getAsVoidPointer(AlignedMap *P) { return P; } +    static inline AlignedMap *getFromVoidPointer(void *P) { +      return (AlignedMap *)P; +    } +    enum { NumLowBitsAvailable = 3 }; +    static_assert(alignof(AlignedMap) >= (1 << NumLowBitsAvailable), +                  "AlignedMap insufficiently aligned to have enough low bits."); +  }; + +  /// The bit that flags that this function may read any global. This is +  /// chosen to mix together with ModRefInfo bits. +  /// FIXME: This assumes ModRefInfo lattice will remain 4 bits! +  /// It overlaps with ModRefInfo::Must bit! 
+  /// FunctionInfo.getModRefInfo() masks out everything except ModRef so +  /// this remains correct, but the Must info is lost. +  enum { MayReadAnyGlobal = 4 }; + +  /// Checks to document the invariants of the bit packing here. +  static_assert((MayReadAnyGlobal & static_cast<int>(ModRefInfo::MustModRef)) == +                    0, +                "ModRef and the MayReadAnyGlobal flag bits overlap."); +  static_assert(((MayReadAnyGlobal | +                  static_cast<int>(ModRefInfo::MustModRef)) >> +                 AlignedMapPointerTraits::NumLowBitsAvailable) == 0, +                "Insufficient low bits to store our flag and ModRef info."); + +public: +  FunctionInfo() : Info() {} +  ~FunctionInfo() { +    delete Info.getPointer(); +  } +  // Spell out the copy ond move constructors and assignment operators to get +  // deep copy semantics and correct move semantics in the face of the +  // pointer-int pair. +  FunctionInfo(const FunctionInfo &Arg) +      : Info(nullptr, Arg.Info.getInt()) { +    if (const auto *ArgPtr = Arg.Info.getPointer()) +      Info.setPointer(new AlignedMap(*ArgPtr)); +  } +  FunctionInfo(FunctionInfo &&Arg) +      : Info(Arg.Info.getPointer(), Arg.Info.getInt()) { +    Arg.Info.setPointerAndInt(nullptr, 0); +  } +  FunctionInfo &operator=(const FunctionInfo &RHS) { +    delete Info.getPointer(); +    Info.setPointerAndInt(nullptr, RHS.Info.getInt()); +    if (const auto *RHSPtr = RHS.Info.getPointer()) +      Info.setPointer(new AlignedMap(*RHSPtr)); +    return *this; +  } +  FunctionInfo &operator=(FunctionInfo &&RHS) { +    delete Info.getPointer(); +    Info.setPointerAndInt(RHS.Info.getPointer(), RHS.Info.getInt()); +    RHS.Info.setPointerAndInt(nullptr, 0); +    return *this; +  } + +  /// This method clears MayReadAnyGlobal bit added by GlobalsAAResult to return +  /// the corresponding ModRefInfo. It must align in functionality with +  /// clearMust(). +  ModRefInfo globalClearMayReadAnyGlobal(int I) const { +    return ModRefInfo((I & static_cast<int>(ModRefInfo::ModRef)) | +                      static_cast<int>(ModRefInfo::NoModRef)); +  } + +  /// Returns the \c ModRefInfo info for this function. +  ModRefInfo getModRefInfo() const { +    return globalClearMayReadAnyGlobal(Info.getInt()); +  } + +  /// Adds new \c ModRefInfo for this function to its state. +  void addModRefInfo(ModRefInfo NewMRI) { +    Info.setInt(Info.getInt() | static_cast<int>(setMust(NewMRI))); +  } + +  /// Returns whether this function may read any global variable, and we don't +  /// know which global. +  bool mayReadAnyGlobal() const { return Info.getInt() & MayReadAnyGlobal; } + +  /// Sets this function as potentially reading from any global. +  void setMayReadAnyGlobal() { Info.setInt(Info.getInt() | MayReadAnyGlobal); } + +  /// Returns the \c ModRefInfo info for this function w.r.t. a particular +  /// global, which may be more precise than the general information above. +  ModRefInfo getModRefInfoForGlobal(const GlobalValue &GV) const { +    ModRefInfo GlobalMRI = +        mayReadAnyGlobal() ? ModRefInfo::Ref : ModRefInfo::NoModRef; +    if (AlignedMap *P = Info.getPointer()) { +      auto I = P->Map.find(&GV); +      if (I != P->Map.end()) +        GlobalMRI = unionModRef(GlobalMRI, I->second); +    } +    return GlobalMRI; +  } + +  /// Add mod/ref info from another function into ours, saturating towards +  /// ModRef. 
+  void addFunctionInfo(const FunctionInfo &FI) { +    addModRefInfo(FI.getModRefInfo()); + +    if (FI.mayReadAnyGlobal()) +      setMayReadAnyGlobal(); + +    if (AlignedMap *P = FI.Info.getPointer()) +      for (const auto &G : P->Map) +        addModRefInfoForGlobal(*G.first, G.second); +  } + +  void addModRefInfoForGlobal(const GlobalValue &GV, ModRefInfo NewMRI) { +    AlignedMap *P = Info.getPointer(); +    if (!P) { +      P = new AlignedMap(); +      Info.setPointer(P); +    } +    auto &GlobalMRI = P->Map[&GV]; +    GlobalMRI = unionModRef(GlobalMRI, NewMRI); +  } + +  /// Clear a global's ModRef info. Should be used when a global is being +  /// deleted. +  void eraseModRefInfoForGlobal(const GlobalValue &GV) { +    if (AlignedMap *P = Info.getPointer()) +      P->Map.erase(&GV); +  } + +private: +  /// All of the information is encoded into a single pointer, with a three bit +  /// integer in the low three bits. The high bit provides a flag for when this +  /// function may read any global. The low two bits are the ModRefInfo. And +  /// the pointer, when non-null, points to a map from GlobalValue to +  /// ModRefInfo specific to that GlobalValue. +  PointerIntPair<AlignedMap *, 3, unsigned, AlignedMapPointerTraits> Info; +}; + +void GlobalsAAResult::DeletionCallbackHandle::deleted() { +  Value *V = getValPtr(); +  if (auto *F = dyn_cast<Function>(V)) +    GAR->FunctionInfos.erase(F); + +  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { +    if (GAR->NonAddressTakenGlobals.erase(GV)) { +      // This global might be an indirect global.  If so, remove it and +      // remove any AllocRelatedValues for it. +      if (GAR->IndirectGlobals.erase(GV)) { +        // Remove any entries in AllocsForIndirectGlobals for this global. +        for (auto I = GAR->AllocsForIndirectGlobals.begin(), +                  E = GAR->AllocsForIndirectGlobals.end(); +             I != E; ++I) +          if (I->second == GV) +            GAR->AllocsForIndirectGlobals.erase(I); +      } + +      // Scan the function info we have collected and remove this global +      // from all of them. +      for (auto &FIPair : GAR->FunctionInfos) +        FIPair.second.eraseModRefInfoForGlobal(*GV); +    } +  } + +  // If this is an allocation related to an indirect global, remove it. +  GAR->AllocsForIndirectGlobals.erase(V); + +  // And clear out the handle. +  setValPtr(nullptr); +  GAR->Handles.erase(I); +  // This object is now destroyed! 
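+  // Note: this handle is an instance of LLVM's CallbackVH pattern, a value
+  // handle whose virtual hooks run when the watched Value is destroyed or has
+  // all its uses replaced.  A minimal, hypothetical sketch of that pattern
+  // (not code used by this analysis):
+#if 0
+  struct WatchingHandle final : CallbackVH {
+    explicit WatchingHandle(Value *V) : CallbackVH(V) {}
+    void deleted() override {
+      // Drop any cached state keyed on the value, then clear the handle so it
+      // no longer points at freed IR.
+      setValPtr(nullptr);
+    }
+  };
+#endif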
+} + +FunctionModRefBehavior GlobalsAAResult::getModRefBehavior(const Function *F) { +  FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + +  if (FunctionInfo *FI = getFunctionInfo(F)) { +    if (!isModOrRefSet(FI->getModRefInfo())) +      Min = FMRB_DoesNotAccessMemory; +    else if (!isModSet(FI->getModRefInfo())) +      Min = FMRB_OnlyReadsMemory; +  } + +  return FunctionModRefBehavior(AAResultBase::getModRefBehavior(F) & Min); +} + +FunctionModRefBehavior +GlobalsAAResult::getModRefBehavior(ImmutableCallSite CS) { +  FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + +  if (!CS.hasOperandBundles()) +    if (const Function *F = CS.getCalledFunction()) +      if (FunctionInfo *FI = getFunctionInfo(F)) { +        if (!isModOrRefSet(FI->getModRefInfo())) +          Min = FMRB_DoesNotAccessMemory; +        else if (!isModSet(FI->getModRefInfo())) +          Min = FMRB_OnlyReadsMemory; +      } + +  return FunctionModRefBehavior(AAResultBase::getModRefBehavior(CS) & Min); +} + +/// Returns the function info for the function, or null if we don't have +/// anything useful to say about it. +GlobalsAAResult::FunctionInfo * +GlobalsAAResult::getFunctionInfo(const Function *F) { +  auto I = FunctionInfos.find(F); +  if (I != FunctionInfos.end()) +    return &I->second; +  return nullptr; +} + +/// AnalyzeGlobals - Scan through the users of all of the internal +/// GlobalValue's in the program.  If none of them have their "address taken" +/// (really, their address passed to something nontrivial), record this fact, +/// and record the functions that they are used directly in. +void GlobalsAAResult::AnalyzeGlobals(Module &M) { +  SmallPtrSet<Function *, 32> TrackedFunctions; +  for (Function &F : M) +    if (F.hasLocalLinkage()) +      if (!AnalyzeUsesOfPointer(&F)) { +        // Remember that we are tracking this global. +        NonAddressTakenGlobals.insert(&F); +        TrackedFunctions.insert(&F); +        Handles.emplace_front(*this, &F); +        Handles.front().I = Handles.begin(); +        ++NumNonAddrTakenFunctions; +      } + +  SmallPtrSet<Function *, 16> Readers, Writers; +  for (GlobalVariable &GV : M.globals()) +    if (GV.hasLocalLinkage()) { +      if (!AnalyzeUsesOfPointer(&GV, &Readers, +                                GV.isConstant() ? nullptr : &Writers)) { +        // Remember that we are tracking this global, and the mod/ref fns +        NonAddressTakenGlobals.insert(&GV); +        Handles.emplace_front(*this, &GV); +        Handles.front().I = Handles.begin(); + +        for (Function *Reader : Readers) { +          if (TrackedFunctions.insert(Reader).second) { +            Handles.emplace_front(*this, Reader); +            Handles.front().I = Handles.begin(); +          } +          FunctionInfos[Reader].addModRefInfoForGlobal(GV, ModRefInfo::Ref); +        } + +        if (!GV.isConstant()) // No need to keep track of writers to constants +          for (Function *Writer : Writers) { +            if (TrackedFunctions.insert(Writer).second) { +              Handles.emplace_front(*this, Writer); +              Handles.front().I = Handles.begin(); +            } +            FunctionInfos[Writer].addModRefInfoForGlobal(GV, ModRefInfo::Mod); +          } +        ++NumNonAddrTakenGlobalVars; + +        // If this global holds a pointer type, see if it is an indirect global. 
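+        // At the source level an "indirect global" is, roughly, an internal
+        // pointer variable that only ever holds freshly allocated memory
+        // (hypothetical sketch, not code from this module):
+#if 0
+        static int *Buffer = nullptr;             // internal global holding a pointer
+        void init() { Buffer = new int[64]; }     // the only stores are fresh allocations
+        int get(unsigned i) { return Buffer[i]; } // loads use the memory directly
+#endif
+        // AnalyzeIndirectGlobalMemory below checks for essentially this shape
+        // before recording the global in IndirectGlobals.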
+        if (GV.getValueType()->isPointerTy() && +            AnalyzeIndirectGlobalMemory(&GV)) +          ++NumIndirectGlobalVars; +      } +      Readers.clear(); +      Writers.clear(); +    } +} + +/// AnalyzeUsesOfPointer - Look at all of the users of the specified pointer. +/// If this is used by anything complex (i.e., the address escapes), return +/// true.  Also, while we are at it, keep track of those functions that read and +/// write to the value. +/// +/// If OkayStoreDest is non-null, stores into this global are allowed. +bool GlobalsAAResult::AnalyzeUsesOfPointer(Value *V, +                                           SmallPtrSetImpl<Function *> *Readers, +                                           SmallPtrSetImpl<Function *> *Writers, +                                           GlobalValue *OkayStoreDest) { +  if (!V->getType()->isPointerTy()) +    return true; + +  for (Use &U : V->uses()) { +    User *I = U.getUser(); +    if (LoadInst *LI = dyn_cast<LoadInst>(I)) { +      if (Readers) +        Readers->insert(LI->getParent()->getParent()); +    } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { +      if (V == SI->getOperand(1)) { +        if (Writers) +          Writers->insert(SI->getParent()->getParent()); +      } else if (SI->getOperand(1) != OkayStoreDest) { +        return true; // Storing the pointer +      } +    } else if (Operator::getOpcode(I) == Instruction::GetElementPtr) { +      if (AnalyzeUsesOfPointer(I, Readers, Writers)) +        return true; +    } else if (Operator::getOpcode(I) == Instruction::BitCast) { +      if (AnalyzeUsesOfPointer(I, Readers, Writers, OkayStoreDest)) +        return true; +    } else if (auto CS = CallSite(I)) { +      // Make sure that this is just the function being called, not that it is +      // passing into the function. +      if (CS.isDataOperand(&U)) { +        // Detect calls to free. +        if (CS.isArgOperand(&U) && isFreeCall(I, &TLI)) { +          if (Writers) +            Writers->insert(CS->getParent()->getParent()); +        } else { +          return true; // Argument of an unknown call. +        } +      } +    } else if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) { +      if (!isa<ConstantPointerNull>(ICI->getOperand(1))) +        return true; // Allow comparison against null. +    } else if (Constant *C = dyn_cast<Constant>(I)) { +      // Ignore constants which don't have any live uses. +      if (isa<GlobalValue>(C) || C->isConstantUsed()) +        return true; +    } else { +      return true; +    } +  } + +  return false; +} + +/// AnalyzeIndirectGlobalMemory - We found an non-address-taken global variable +/// which holds a pointer type.  See if the global always points to non-aliased +/// heap memory: that is, all initializers of the globals are allocations, and +/// those allocations have no use other than initialization of the global. +/// Further, all loads out of GV must directly use the memory, not store the +/// pointer somewhere.  If this is true, we consider the memory pointed to by +/// GV to be owned by GV and can disambiguate other pointers from it. +bool GlobalsAAResult::AnalyzeIndirectGlobalMemory(GlobalVariable *GV) { +  // Keep track of values related to the allocation of the memory, f.e. the +  // value produced by the malloc call and any casts. +  std::vector<Value *> AllocRelatedValues; + +  // If the initializer is a valid pointer, bail. +  if (Constant *C = GV->getInitializer()) +    if (!C->isNullValue()) +      return false; + +  // Walk the user list of the global.  
If we find anything other than a direct +  // load or store, bail out. +  for (User *U : GV->users()) { +    if (LoadInst *LI = dyn_cast<LoadInst>(U)) { +      // The pointer loaded from the global can only be used in simple ways: +      // we allow addressing of it and loading storing to it.  We do *not* allow +      // storing the loaded pointer somewhere else or passing to a function. +      if (AnalyzeUsesOfPointer(LI)) +        return false; // Loaded pointer escapes. +      // TODO: Could try some IP mod/ref of the loaded pointer. +    } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) { +      // Storing the global itself. +      if (SI->getOperand(0) == GV) +        return false; + +      // If storing the null pointer, ignore it. +      if (isa<ConstantPointerNull>(SI->getOperand(0))) +        continue; + +      // Check the value being stored. +      Value *Ptr = GetUnderlyingObject(SI->getOperand(0), +                                       GV->getParent()->getDataLayout()); + +      if (!isAllocLikeFn(Ptr, &TLI)) +        return false; // Too hard to analyze. + +      // Analyze all uses of the allocation.  If any of them are used in a +      // non-simple way (e.g. stored to another global) bail out. +      if (AnalyzeUsesOfPointer(Ptr, /*Readers*/ nullptr, /*Writers*/ nullptr, +                               GV)) +        return false; // Loaded pointer escapes. + +      // Remember that this allocation is related to the indirect global. +      AllocRelatedValues.push_back(Ptr); +    } else { +      // Something complex, bail out. +      return false; +    } +  } + +  // Okay, this is an indirect global.  Remember all of the allocations for +  // this global in AllocsForIndirectGlobals. +  while (!AllocRelatedValues.empty()) { +    AllocsForIndirectGlobals[AllocRelatedValues.back()] = GV; +    Handles.emplace_front(*this, AllocRelatedValues.back()); +    Handles.front().I = Handles.begin(); +    AllocRelatedValues.pop_back(); +  } +  IndirectGlobals.insert(GV); +  Handles.emplace_front(*this, GV); +  Handles.front().I = Handles.begin(); +  return true; +} + +void GlobalsAAResult::CollectSCCMembership(CallGraph &CG) { +  // We do a bottom-up SCC traversal of the call graph.  In other words, we +  // visit all callees before callers (leaf-first). +  unsigned SCCID = 0; +  for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) { +    const std::vector<CallGraphNode *> &SCC = *I; +    assert(!SCC.empty() && "SCC with no functions?"); + +    for (auto *CGN : SCC) +      if (Function *F = CGN->getFunction()) +        FunctionToSCCMap[F] = SCCID; +    ++SCCID; +  } +} + +/// AnalyzeCallGraph - At this point, we know the functions where globals are +/// immediately stored to and read from.  Propagate this information up the call +/// graph to all callers and compute the mod/ref info for all memory for each +/// function. +void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { +  // We do a bottom-up SCC traversal of the call graph.  In other words, we +  // visit all callees before callers (leaf-first). +  for (scc_iterator<CallGraph *> I = scc_begin(&CG); !I.isAtEnd(); ++I) { +    const std::vector<CallGraphNode *> &SCC = *I; +    assert(!SCC.empty() && "SCC with no functions?"); + +    Function *F = SCC[0]->getFunction(); + +    if (!F || !F->isDefinitionExact()) { +      // Calls externally or not exact - can't say anything useful. Remove any +      // existing function records (may have been created when scanning +      // globals). 
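+      // "Not exact" includes interposable definitions; a hypothetical sketch
+      // of the kind of function this guards against (GNU attribute syntax):
+#if 0
+      __attribute__((weak)) int bump_counter() { return 0; }
+#endif
+      // The body seen here may be replaced by another definition at link
+      // time, so recording "touches no globals" based on it would be unsound.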
+      for (auto *Node : SCC) +        FunctionInfos.erase(Node->getFunction()); +      continue; +    } + +    FunctionInfo &FI = FunctionInfos[F]; +    Handles.emplace_front(*this, F); +    Handles.front().I = Handles.begin(); +    bool KnowNothing = false; + +    // Collect the mod/ref properties due to called functions.  We only compute +    // one mod-ref set. +    for (unsigned i = 0, e = SCC.size(); i != e && !KnowNothing; ++i) { +      if (!F) { +        KnowNothing = true; +        break; +      } + +      if (F->isDeclaration() || F->hasFnAttribute(Attribute::OptimizeNone)) { +        // Try to get mod/ref behaviour from function attributes. +        if (F->doesNotAccessMemory()) { +          // Can't do better than that! +        } else if (F->onlyReadsMemory()) { +          FI.addModRefInfo(ModRefInfo::Ref); +          if (!F->isIntrinsic() && !F->onlyAccessesArgMemory()) +            // This function might call back into the module and read a global - +            // consider every global as possibly being read by this function. +            FI.setMayReadAnyGlobal(); +        } else { +          FI.addModRefInfo(ModRefInfo::ModRef); +          // Can't say anything useful unless it's an intrinsic - they don't +          // read or write global variables of the kind considered here. +          KnowNothing = !F->isIntrinsic(); +        } +        continue; +      } + +      for (CallGraphNode::iterator CI = SCC[i]->begin(), E = SCC[i]->end(); +           CI != E && !KnowNothing; ++CI) +        if (Function *Callee = CI->second->getFunction()) { +          if (FunctionInfo *CalleeFI = getFunctionInfo(Callee)) { +            // Propagate function effect up. +            FI.addFunctionInfo(*CalleeFI); +          } else { +            // Can't say anything about it.  However, if it is inside our SCC, +            // then nothing needs to be done. +            CallGraphNode *CalleeNode = CG[Callee]; +            if (!is_contained(SCC, CalleeNode)) +              KnowNothing = true; +          } +        } else { +          KnowNothing = true; +        } +    } + +    // If we can't say anything useful about this SCC, remove all SCC functions +    // from the FunctionInfos map. +    if (KnowNothing) { +      for (auto *Node : SCC) +        FunctionInfos.erase(Node->getFunction()); +      continue; +    } + +    // Scan the function bodies for explicit loads or stores. +    for (auto *Node : SCC) { +      if (isModAndRefSet(FI.getModRefInfo())) +        break; // The mod/ref lattice saturates here. + +      // Don't prove any properties based on the implementation of an optnone +      // function. Function attributes were already used as a best approximation +      // above. +      if (Node->getFunction()->hasFnAttribute(Attribute::OptimizeNone)) +        continue; + +      for (Instruction &I : instructions(Node->getFunction())) { +        if (isModAndRefSet(FI.getModRefInfo())) +          break; // The mod/ref lattice saturates here. + +        // We handle calls specially because the graph-relevant aspects are +        // handled above. +        if (auto CS = CallSite(&I)) { +          if (isAllocationFn(&I, &TLI) || isFreeCall(&I, &TLI)) { +            // FIXME: It is completely unclear why this is necessary and not +            // handled by the above graph code. +            FI.addModRefInfo(ModRefInfo::ModRef); +          } else if (Function *Callee = CS.getCalledFunction()) { +            // The callgraph doesn't include intrinsic calls. 
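+            // For example, a call to the llvm.memcpy family is an intrinsic
+            // call with no call graph edge, so its effect has to be folded in
+            // here via getModRefBehavior() rather than via callee FunctionInfo.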
+            if (Callee->isIntrinsic()) {
+              if (isa<DbgInfoIntrinsic>(I))
+                // Don't let dbg intrinsics affect alias info.
+                continue;
+
+              FunctionModRefBehavior Behaviour =
+                  AAResultBase::getModRefBehavior(Callee);
+              FI.addModRefInfo(createModRefInfo(Behaviour));
+            }
+          }
+          continue;
+        }
+
+        // For all non-call instructions we use the primary predicates for
+        // whether they read or write memory.
+        if (I.mayReadFromMemory())
+          FI.addModRefInfo(ModRefInfo::Ref);
+        if (I.mayWriteToMemory())
+          FI.addModRefInfo(ModRefInfo::Mod);
+      }
+    }
+
+    if (!isModSet(FI.getModRefInfo()))
+      ++NumReadMemFunctions;
+    if (!isModOrRefSet(FI.getModRefInfo()))
+      ++NumNoMemFunctions;
+
+    // Finally, now that we know the full effect on this SCC, clone the
+    // information to each function in the SCC.
+    // FI is a reference into FunctionInfos, so copy it now so that it doesn't
+    // get invalidated if DenseMap decides to re-hash.
+    FunctionInfo CachedFI = FI;
+    for (unsigned i = 1, e = SCC.size(); i != e; ++i)
+      FunctionInfos[SCC[i]->getFunction()] = CachedFI;
+  }
+}
+
+// GV is a non-escaping global. V is a pointer address that has been loaded from.
+// If we can prove that V must escape, we can conclude that a load from V cannot
+// alias GV.
+static bool isNonEscapingGlobalNoAliasWithLoad(const GlobalValue *GV,
+                                               const Value *V,
+                                               int &Depth,
+                                               const DataLayout &DL) {
+  SmallPtrSet<const Value *, 8> Visited;
+  SmallVector<const Value *, 8> Inputs;
+  Visited.insert(V);
+  Inputs.push_back(V);
+  do {
+    const Value *Input = Inputs.pop_back_val();
+
+    if (isa<GlobalValue>(Input) || isa<Argument>(Input) || isa<CallInst>(Input) ||
+        isa<InvokeInst>(Input))
+      // Arguments to functions or returns from functions are inherently
+      // escaping, so we can immediately classify those as not aliasing any
+      // non-addr-taken globals.
+      //
+      // (Transitive) loads from a global are also safe - if this aliased
+      // another global, its address would escape, so no alias.
+      continue;
+
+    // Recurse through a limited number of selects, loads and PHIs. This is an
+    // arbitrary depth of 4; lower numbers could be used to fix compile time
+    // issues if needed, but this is generally expected to only be important
+    // for small depths.
+    if (++Depth > 4)
+      return false;
+
+    if (auto *LI = dyn_cast<LoadInst>(Input)) {
+      Inputs.push_back(GetUnderlyingObject(LI->getPointerOperand(), DL));
+      continue;
+    }
+    if (auto *SI = dyn_cast<SelectInst>(Input)) {
+      const Value *LHS = GetUnderlyingObject(SI->getTrueValue(), DL);
+      const Value *RHS = GetUnderlyingObject(SI->getFalseValue(), DL);
+      if (Visited.insert(LHS).second)
+        Inputs.push_back(LHS);
+      if (Visited.insert(RHS).second)
+        Inputs.push_back(RHS);
+      continue;
+    }
+    if (auto *PN = dyn_cast<PHINode>(Input)) {
+      for (const Value *Op : PN->incoming_values()) {
+        Op = GetUnderlyingObject(Op, DL);
+        if (Visited.insert(Op).second)
+          Inputs.push_back(Op);
+      }
+      continue;
+    }
+
+    return false;
+  } while (!Inputs.empty());
+
+  // All inputs were known to be no-alias.
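+  // (The loop above is the usual visited-set worklist idiom; a generic,
+  // hypothetical sketch of that shape, with relevantOperandsOf standing in
+  // for the select/PHI/load handling:)
+#if 0
+  SmallPtrSet<const Value *, 8> Visited;
+  SmallVector<const Value *, 8> Worklist;
+  Visited.insert(Root);                             // Root: hypothetical start value
+  Worklist.push_back(Root);
+  while (!Worklist.empty()) {
+    const Value *Cur = Worklist.pop_back_val();
+    for (const Value *Op : relevantOperandsOf(Cur)) // hypothetical helper
+      if (Visited.insert(Op).second)                // true only on first visit
+        Worklist.push_back(Op);
+  }
+#endif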
+  return true;
+}
+
+// There are particular cases where we can conclude no-alias between
+// a non-addr-taken global and some other underlying object. Specifically,
+// a non-addr-taken global is known not to escape from any function. It is
+// also incorrect for a transformation to introduce an escape of a global in
+// a way that is observable when it was not there previously. One function
+// being transformed to introduce an escape which could possibly be observed
+// (via loading from a global or the return value for example) within another
+// function is never safe. If the observation is made through non-atomic
+// operations on different threads, it is a data-race and UB. If the
+// observation is well defined, by being observed the transformation would have
+// changed program behavior by introducing the observed escape, making it an
+// invalid transform.
+//
+// This property does require that transformations which *temporarily* escape
+// a global that was not previously escaped, prior to restoring it, cannot rely
+// on the results of GMR::alias. This seems a reasonable restriction, although
+// currently there is no way to enforce it. There is also no realistic
+// optimization pass that would make this mistake. The closest example is
+// a transformation pass which does reg2mem of SSA values but stores them into
+// global variables temporarily before restoring the global variable's value.
+// This could be useful to expose "benign" races for example. However, it seems
+// reasonable to require that a pass which introduces escapes of global
+// variables in this way either not trust AA results while the escape is
+// active, or be forced to operate as a module pass that cannot co-exist
+// with an alias analysis such as GMR.
+bool GlobalsAAResult::isNonEscapingGlobalNoAlias(const GlobalValue *GV,
+                                                 const Value *V) {
+  // In order to know that the underlying object cannot alias the
+  // non-addr-taken global, we must know that it would have to be an escape.
+  // Thus if the underlying object is a function argument, a load from
+  // a global, or the return of a function, it cannot alias. We can also
+  // recurse through PHI nodes and select nodes provided all of their inputs
+  // resolve to one of these known-escaping roots.
+  SmallPtrSet<const Value *, 8> Visited;
+  SmallVector<const Value *, 8> Inputs;
+  Visited.insert(V);
+  Inputs.push_back(V);
+  int Depth = 0;
+  do {
+    const Value *Input = Inputs.pop_back_val();
+
+    if (auto *InputGV = dyn_cast<GlobalValue>(Input)) {
+      // If one input is the very global we're querying against, then we can't
+      // conclude no-alias.
+      if (InputGV == GV)
+        return false;
+
+      // Distinct GlobalVariables never alias, unless overridden or zero-sized.
+      // FIXME: The condition can be refined, but be conservative for now.
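+      // The source-level picture for the happy path below (hypothetical
+      // sketch): two distinct, defined, non-interposable globals of non-zero
+      // size occupy disjoint storage, so accesses through them cannot alias.
+#if 0
+      static int A[4];
+      static int B[4];
+      // &A[i] and &B[j] are never equal for in-bounds i and j.
+#endif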
+      auto *GVar = dyn_cast<GlobalVariable>(GV);
+      auto *InputGVar = dyn_cast<GlobalVariable>(InputGV);
+      if (GVar && InputGVar &&
+          !GVar->isDeclaration() && !InputGVar->isDeclaration() &&
+          !GVar->isInterposable() && !InputGVar->isInterposable()) {
+        Type *GVType = GVar->getInitializer()->getType();
+        Type *InputGVType = InputGVar->getInitializer()->getType();
+        if (GVType->isSized() && InputGVType->isSized() &&
+            (DL.getTypeAllocSize(GVType) > 0) &&
+            (DL.getTypeAllocSize(InputGVType) > 0))
+          continue;
+      }
+
+      // Conservatively return false, even though we could be smarter
+      // (e.g. look through GlobalAliases).
+      return false;
+    }
+
+    if (isa<Argument>(Input) || isa<CallInst>(Input) ||
+        isa<InvokeInst>(Input)) {
+      // Arguments to functions or returns from functions are inherently
+      // escaping, so we can immediately classify those as not aliasing any
+      // non-addr-taken globals.
+      continue;
+    }
+
+    // Recurse through a limited number of selects, loads and PHIs. This is an
+    // arbitrary depth of 4; lower numbers could be used to fix compile time
+    // issues if needed, but this is generally expected to only be important
+    // for small depths.
+    if (++Depth > 4)
+      return false;
+
+    if (auto *LI = dyn_cast<LoadInst>(Input)) {
+      // A pointer loaded from a global would have been captured, and we know
+      // that the global is non-escaping, so no alias.
+      const Value *Ptr = GetUnderlyingObject(LI->getPointerOperand(), DL);
+      if (isNonEscapingGlobalNoAliasWithLoad(GV, Ptr, Depth, DL))
+        // The load does not alias with GV.
+        continue;
+      // Otherwise, a load could come from anywhere, so bail.
+      return false;
+    }
+    if (auto *SI = dyn_cast<SelectInst>(Input)) {
+      const Value *LHS = GetUnderlyingObject(SI->getTrueValue(), DL);
+      const Value *RHS = GetUnderlyingObject(SI->getFalseValue(), DL);
+      if (Visited.insert(LHS).second)
+        Inputs.push_back(LHS);
+      if (Visited.insert(RHS).second)
+        Inputs.push_back(RHS);
+      continue;
+    }
+    if (auto *PN = dyn_cast<PHINode>(Input)) {
+      for (const Value *Op : PN->incoming_values()) {
+        Op = GetUnderlyingObject(Op, DL);
+        if (Visited.insert(Op).second)
+          Inputs.push_back(Op);
+      }
+      continue;
+    }
+
+    // FIXME: It would be good to handle other obvious no-alias cases here, but
+    // it isn't clear how to do so reasonably without building a small version
+    // of BasicAA into this code. We could recurse into AAResultBase::alias
+    // here but that seems likely to go poorly as we're inside the
+    // implementation of such a query. Until then, just conservatively return
+    // false.
+    return false;
+  } while (!Inputs.empty());
+
+  // If all the inputs to V were definitively no-alias, then V is no-alias.
+  return true;
+}
+
+/// alias - If one of the pointers is to a global that we are tracking, and the
+/// other is some random pointer, we know there cannot be an alias, because the
+/// address of the global isn't taken.
+AliasResult GlobalsAAResult::alias(const MemoryLocation &LocA,
+                                   const MemoryLocation &LocB) {
+  // Get the base object these pointers point to.
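+  // "Base object" means the allocation, global or argument the pointer was
+  // derived from, with GEPs and casts stripped.  A hypothetical sketch:
+#if 0
+  static int G[16];
+  int *P = &G[4];   // the underlying object of P (after stripping) is G
+#endif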
+  const Value *UV1 = GetUnderlyingObject(LocA.Ptr, DL); +  const Value *UV2 = GetUnderlyingObject(LocB.Ptr, DL); + +  // If either of the underlying values is a global, they may be non-addr-taken +  // globals, which we can answer queries about. +  const GlobalValue *GV1 = dyn_cast<GlobalValue>(UV1); +  const GlobalValue *GV2 = dyn_cast<GlobalValue>(UV2); +  if (GV1 || GV2) { +    // If the global's address is taken, pretend we don't know it's a pointer to +    // the global. +    if (GV1 && !NonAddressTakenGlobals.count(GV1)) +      GV1 = nullptr; +    if (GV2 && !NonAddressTakenGlobals.count(GV2)) +      GV2 = nullptr; + +    // If the two pointers are derived from two different non-addr-taken +    // globals we know these can't alias. +    if (GV1 && GV2 && GV1 != GV2) +      return NoAlias; + +    // If one is and the other isn't, it isn't strictly safe but we can fake +    // this result if necessary for performance. This does not appear to be +    // a common problem in practice. +    if (EnableUnsafeGlobalsModRefAliasResults) +      if ((GV1 || GV2) && GV1 != GV2) +        return NoAlias; + +    // Check for a special case where a non-escaping global can be used to +    // conclude no-alias. +    if ((GV1 || GV2) && GV1 != GV2) { +      const GlobalValue *GV = GV1 ? GV1 : GV2; +      const Value *UV = GV1 ? UV2 : UV1; +      if (isNonEscapingGlobalNoAlias(GV, UV)) +        return NoAlias; +    } + +    // Otherwise if they are both derived from the same addr-taken global, we +    // can't know the two accesses don't overlap. +  } + +  // These pointers may be based on the memory owned by an indirect global.  If +  // so, we may be able to handle this.  First check to see if the base pointer +  // is a direct load from an indirect global. +  GV1 = GV2 = nullptr; +  if (const LoadInst *LI = dyn_cast<LoadInst>(UV1)) +    if (GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0))) +      if (IndirectGlobals.count(GV)) +        GV1 = GV; +  if (const LoadInst *LI = dyn_cast<LoadInst>(UV2)) +    if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getOperand(0))) +      if (IndirectGlobals.count(GV)) +        GV2 = GV; + +  // These pointers may also be from an allocation for the indirect global.  If +  // so, also handle them. +  if (!GV1) +    GV1 = AllocsForIndirectGlobals.lookup(UV1); +  if (!GV2) +    GV2 = AllocsForIndirectGlobals.lookup(UV2); + +  // Now that we know whether the two pointers are related to indirect globals, +  // use this to disambiguate the pointers. If the pointers are based on +  // different indirect globals they cannot alias. +  if (GV1 && GV2 && GV1 != GV2) +    return NoAlias; + +  // If one is based on an indirect global and the other isn't, it isn't +  // strictly safe but we can fake this result if necessary for performance. +  // This does not appear to be a common problem in practice. +  if (EnableUnsafeGlobalsModRefAliasResults) +    if ((GV1 || GV2) && GV1 != GV2) +      return NoAlias; + +  return AAResultBase::alias(LocA, LocB); +} + +ModRefInfo GlobalsAAResult::getModRefInfoForArgument(ImmutableCallSite CS, +                                                     const GlobalValue *GV) { +  if (CS.doesNotAccessMemory()) +    return ModRefInfo::NoModRef; +  ModRefInfo ConservativeResult = +      CS.onlyReadsMemory() ? ModRefInfo::Ref : ModRefInfo::ModRef; + +  // Iterate through all the arguments to the called function. If any argument +  // is based on GV, return the conservative result. 
+  for (auto &A : CS.args()) { +    SmallVector<Value*, 4> Objects; +    GetUnderlyingObjects(A, Objects, DL); + +    // All objects must be identified. +    if (!all_of(Objects, isIdentifiedObject) && +        // Try ::alias to see if all objects are known not to alias GV. +        !all_of(Objects, [&](Value *V) { +          return this->alias(MemoryLocation(V), MemoryLocation(GV)) == NoAlias; +        })) +      return ConservativeResult; + +    if (is_contained(Objects, GV)) +      return ConservativeResult; +  } + +  // We identified all objects in the argument list, and none of them were GV. +  return ModRefInfo::NoModRef; +} + +ModRefInfo GlobalsAAResult::getModRefInfo(ImmutableCallSite CS, +                                          const MemoryLocation &Loc) { +  ModRefInfo Known = ModRefInfo::ModRef; + +  // If we are asking for mod/ref info of a direct call with a pointer to a +  // global we are tracking, return information if we have it. +  if (const GlobalValue *GV = +          dyn_cast<GlobalValue>(GetUnderlyingObject(Loc.Ptr, DL))) +    if (GV->hasLocalLinkage()) +      if (const Function *F = CS.getCalledFunction()) +        if (NonAddressTakenGlobals.count(GV)) +          if (const FunctionInfo *FI = getFunctionInfo(F)) +            Known = unionModRef(FI->getModRefInfoForGlobal(*GV), +                                getModRefInfoForArgument(CS, GV)); + +  if (!isModOrRefSet(Known)) +    return ModRefInfo::NoModRef; // No need to query other mod/ref analyses +  return intersectModRef(Known, AAResultBase::getModRefInfo(CS, Loc)); +} + +GlobalsAAResult::GlobalsAAResult(const DataLayout &DL, +                                 const TargetLibraryInfo &TLI) +    : AAResultBase(), DL(DL), TLI(TLI) {} + +GlobalsAAResult::GlobalsAAResult(GlobalsAAResult &&Arg) +    : AAResultBase(std::move(Arg)), DL(Arg.DL), TLI(Arg.TLI), +      NonAddressTakenGlobals(std::move(Arg.NonAddressTakenGlobals)), +      IndirectGlobals(std::move(Arg.IndirectGlobals)), +      AllocsForIndirectGlobals(std::move(Arg.AllocsForIndirectGlobals)), +      FunctionInfos(std::move(Arg.FunctionInfos)), +      Handles(std::move(Arg.Handles)) { +  // Update the parent for each DeletionCallbackHandle. +  for (auto &H : Handles) { +    assert(H.GAR == &Arg); +    H.GAR = this; +  } +} + +GlobalsAAResult::~GlobalsAAResult() {} + +/*static*/ GlobalsAAResult +GlobalsAAResult::analyzeModule(Module &M, const TargetLibraryInfo &TLI, +                               CallGraph &CG) { +  GlobalsAAResult Result(M.getDataLayout(), TLI); + +  // Discover which functions aren't recursive, to feed into AnalyzeGlobals. +  Result.CollectSCCMembership(CG); + +  // Find non-addr taken globals. +  Result.AnalyzeGlobals(M); + +  // Propagate on CG. 
+  Result.AnalyzeCallGraph(CG, M); + +  return Result; +} + +AnalysisKey GlobalsAA::Key; + +GlobalsAAResult GlobalsAA::run(Module &M, ModuleAnalysisManager &AM) { +  return GlobalsAAResult::analyzeModule(M, +                                        AM.getResult<TargetLibraryAnalysis>(M), +                                        AM.getResult<CallGraphAnalysis>(M)); +} + +char GlobalsAAWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(GlobalsAAWrapperPass, "globals-aa", +                      "Globals Alias Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(GlobalsAAWrapperPass, "globals-aa", +                    "Globals Alias Analysis", false, true) + +ModulePass *llvm::createGlobalsAAWrapperPass() { +  return new GlobalsAAWrapperPass(); +} + +GlobalsAAWrapperPass::GlobalsAAWrapperPass() : ModulePass(ID) { +  initializeGlobalsAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool GlobalsAAWrapperPass::runOnModule(Module &M) { +  Result.reset(new GlobalsAAResult(GlobalsAAResult::analyzeModule( +      M, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), +      getAnalysis<CallGraphWrapperPass>().getCallGraph()))); +  return false; +} + +bool GlobalsAAWrapperPass::doFinalization(Module &M) { +  Result.reset(); +  return false; +} + +void GlobalsAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<CallGraphWrapperPass>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +} diff --git a/contrib/llvm/lib/Analysis/IVUsers.cpp b/contrib/llvm/lib/Analysis/IVUsers.cpp new file mode 100644 index 000000000000..609e5e3a1448 --- /dev/null +++ b/contrib/llvm/lib/Analysis/IVUsers.cpp @@ -0,0 +1,427 @@ +//===- IVUsers.cpp - Induction Variable Users -------------------*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements bookkeeping for "interesting" users of expressions +// computed from induction variables. 
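+// For example, in a loop such as "for (i = 0; i != n; ++i) a[i] = 0;" the
+// address computation a+i is an affine function of the induction variable, and
+// the instructions using it are the kind of users recorded here (primarily for
+// the benefit of Loop Strength Reduction).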
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IVUsers.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +using namespace llvm; + +#define DEBUG_TYPE "iv-users" + +AnalysisKey IVUsersAnalysis::Key; + +IVUsers IVUsersAnalysis::run(Loop &L, LoopAnalysisManager &AM, +                             LoopStandardAnalysisResults &AR) { +  return IVUsers(&L, &AR.AC, &AR.LI, &AR.DT, &AR.SE); +} + +char IVUsersWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(IVUsersWrapperPass, "iv-users", +                      "Induction Variable Users", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(IVUsersWrapperPass, "iv-users", "Induction Variable Users", +                    false, true) + +Pass *llvm::createIVUsersPass() { return new IVUsersWrapperPass(); } + +/// isInteresting - Test whether the given expression is "interesting" when +/// used by the given expression, within the context of analyzing the +/// given loop. +static bool isInteresting(const SCEV *S, const Instruction *I, const Loop *L, +                          ScalarEvolution *SE, LoopInfo *LI) { +  // An addrec is interesting if it's affine or if it has an interesting start. +  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { +    // Keep things simple. Don't touch loop-variant strides unless they're +    // only used outside the loop and we can simplify them. +    if (AR->getLoop() == L) +      return AR->isAffine() || +             (!L->contains(I) && +              SE->getSCEVAtScope(AR, LI->getLoopFor(I->getParent())) != AR); +    // Otherwise recurse to see if the start value is interesting, and that +    // the step value is not interesting, since we don't yet know how to +    // do effective SCEV expansions for addrecs with interesting steps. +    return isInteresting(AR->getStart(), I, L, SE, LI) && +          !isInteresting(AR->getStepRecurrence(*SE), I, L, SE, LI); +  } + +  // An add is interesting if exactly one of its operands is interesting. +  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { +    bool AnyInterestingYet = false; +    for (const auto *Op : Add->operands()) +      if (isInteresting(Op, I, L, SE, LI)) { +        if (AnyInterestingYet) +          return false; +        AnyInterestingYet = true; +      } +    return AnyInterestingYet; +  } + +  // Nothing else is interesting here. +  return false; +} + +/// Return true if all loop headers that dominate this block are in simplified +/// form. 
+static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT, +                                 const LoopInfo *LI, +                                 SmallPtrSetImpl<Loop*> &SimpleLoopNests) { +  Loop *NearestLoop = nullptr; +  for (DomTreeNode *Rung = DT->getNode(BB); +       Rung; Rung = Rung->getIDom()) { +    BasicBlock *DomBB = Rung->getBlock(); +    Loop *DomLoop = LI->getLoopFor(DomBB); +    if (DomLoop && DomLoop->getHeader() == DomBB) { +      // If the domtree walk reaches a loop with no preheader, return false. +      if (!DomLoop->isLoopSimplifyForm()) +        return false; +      // If we have already checked this loop nest, stop checking. +      if (SimpleLoopNests.count(DomLoop)) +        break; +      // If we have not already checked this loop nest, remember the loop +      // header nearest to BB. The nearest loop may not contain BB. +      if (!NearestLoop) +        NearestLoop = DomLoop; +    } +  } +  if (NearestLoop) +    SimpleLoopNests.insert(NearestLoop); +  return true; +} + +/// IVUseShouldUsePostIncValue - We have discovered a "User" of an IV expression +/// and now we need to decide whether the user should use the preinc or post-inc +/// value.  If this user should use the post-inc version of the IV, return true. +/// +/// Choosing wrong here can break dominance properties (if we choose to use the +/// post-inc value when we cannot) or it can end up adding extra live-ranges to +/// the loop, resulting in reg-reg copies (if we use the pre-inc value when we +/// should use the post-inc value). +static bool IVUseShouldUsePostIncValue(Instruction *User, Value *Operand, +                                       const Loop *L, DominatorTree *DT) { +  // If the user is in the loop, use the preinc value. +  if (L->contains(User)) +    return false; + +  BasicBlock *LatchBlock = L->getLoopLatch(); +  if (!LatchBlock) +    return false; + +  // Ok, the user is outside of the loop.  If it is dominated by the latch +  // block, use the post-inc value. +  if (DT->dominates(LatchBlock, User->getParent())) +    return true; + +  // There is one case we have to be careful of: PHI nodes.  These little guys +  // can live in blocks that are not dominated by the latch block, but (since +  // their uses occur in the predecessor block, not the block the PHI lives in) +  // should still use the post-inc value.  Check for this case now. +  PHINode *PN = dyn_cast<PHINode>(User); +  if (!PN || !Operand) +    return false; // not a phi, not dominated by latch block. + +  // Look at all of the uses of Operand by the PHI node.  If any use corresponds +  // to a block that is not dominated by the latch block, give up and use the +  // preincremented value. +  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) +    if (PN->getIncomingValue(i) == Operand && +        !DT->dominates(LatchBlock, PN->getIncomingBlock(i))) +      return false; + +  // Okay, all uses of Operand by PN are in predecessor blocks that really are +  // dominated by the latch block.  Use the post-incremented value. +  return true; +} + +/// AddUsersImpl - Inspect the specified instruction.  If it is a +/// reducible SCEV, recursively add its users to the IVUsesByStride set and +/// return true.  Otherwise, return false. 
+bool IVUsers::AddUsersImpl(Instruction *I, +                           SmallPtrSetImpl<Loop*> &SimpleLoopNests) { +  const DataLayout &DL = I->getModule()->getDataLayout(); + +  // Add this IV user to the Processed set before returning false to ensure that +  // all IV users are members of the set. See IVUsers::isIVUserOrOperand. +  if (!Processed.insert(I).second) +    return true;    // Instruction already handled. + +  if (!SE->isSCEVable(I->getType())) +    return false;   // Void and FP expressions cannot be reduced. + +  // IVUsers is used by LSR which assumes that all SCEV expressions are safe to +  // pass to SCEVExpander. Expressions are not safe to expand if they represent +  // operations that are not safe to speculate, namely integer division. +  if (!isa<PHINode>(I) && !isSafeToSpeculativelyExecute(I)) +    return false; + +  // LSR is not APInt clean, do not touch integers bigger than 64-bits. +  // Also avoid creating IVs of non-native types. For example, we don't want a +  // 64-bit IV in 32-bit code just because the loop has one 64-bit cast. +  uint64_t Width = SE->getTypeSizeInBits(I->getType()); +  if (Width > 64 || !DL.isLegalInteger(Width)) +    return false; + +  // Don't attempt to promote ephemeral values to indvars. They will be removed +  // later anyway. +  if (EphValues.count(I)) +    return false; + +  // Get the symbolic expression for this instruction. +  const SCEV *ISE = SE->getSCEV(I); + +  // If we've come to an uninteresting expression, stop the traversal and +  // call this a user. +  if (!isInteresting(ISE, I, L, SE, LI)) +    return false; + +  SmallPtrSet<Instruction *, 4> UniqueUsers; +  for (Use &U : I->uses()) { +    Instruction *User = cast<Instruction>(U.getUser()); +    if (!UniqueUsers.insert(User).second) +      continue; + +    // Do not infinitely recurse on PHI nodes. +    if (isa<PHINode>(User) && Processed.count(User)) +      continue; + +    // Only consider IVUsers that are dominated by simplified loop +    // headers. Otherwise, SCEVExpander will crash. +    BasicBlock *UseBB = User->getParent(); +    // A phi's use is live out of its predecessor block. +    if (PHINode *PHI = dyn_cast<PHINode>(User)) { +      unsigned OperandNo = U.getOperandNo(); +      unsigned ValNo = PHINode::getIncomingValueNumForOperand(OperandNo); +      UseBB = PHI->getIncomingBlock(ValNo); +    } +    if (!isSimplifiedLoopNest(UseBB, DT, LI, SimpleLoopNests)) +      return false; + +    // Descend recursively, but not into PHI nodes outside the current loop. +    // It's important to see the entire expression outside the loop to get +    // choices that depend on addressing mode use right, although we won't +    // consider references outside the loop in all cases. +    // If User is already in Processed, we don't want to recurse into it again, +    // but do want to record a second reference in the same instruction. 
+    bool AddUserToIVUsers = false; +    if (LI->getLoopFor(User->getParent()) != L) { +      if (isa<PHINode>(User) || Processed.count(User) || +          !AddUsersImpl(User, SimpleLoopNests)) { +        LLVM_DEBUG(dbgs() << "FOUND USER in other loop: " << *User << '\n' +                          << "   OF SCEV: " << *ISE << '\n'); +        AddUserToIVUsers = true; +      } +    } else if (Processed.count(User) || !AddUsersImpl(User, SimpleLoopNests)) { +      LLVM_DEBUG(dbgs() << "FOUND USER: " << *User << '\n' +                        << "   OF SCEV: " << *ISE << '\n'); +      AddUserToIVUsers = true; +    } + +    if (AddUserToIVUsers) { +      // Okay, we found a user that we cannot reduce. +      IVStrideUse &NewUse = AddUser(User, I); +      // Autodetect the post-inc loop set, populating NewUse.PostIncLoops. +      // The regular return value here is discarded; instead of recording +      // it, we just recompute it when we need it. +      const SCEV *OriginalISE = ISE; + +      auto NormalizePred = [&](const SCEVAddRecExpr *AR) { +        auto *L = AR->getLoop(); +        bool Result = IVUseShouldUsePostIncValue(User, I, L, DT); +        if (Result) +          NewUse.PostIncLoops.insert(L); +        return Result; +      }; + +      ISE = normalizeForPostIncUseIf(ISE, NormalizePred, *SE); + +      // PostIncNormalization effectively simplifies the expression under +      // pre-increment assumptions. Those assumptions (no wrapping) might not +      // hold for the post-inc value. Catch such cases by making sure the +      // transformation is invertible. +      if (OriginalISE != ISE) { +        const SCEV *DenormalizedISE = +            denormalizeForPostIncUse(ISE, NewUse.PostIncLoops, *SE); + +        // If we normalized the expression, but denormalization doesn't give the +        // original one, discard this user. +        if (OriginalISE != DenormalizedISE) { +          LLVM_DEBUG(dbgs() +                     << "   DISCARDING (NORMALIZATION ISN'T INVERTIBLE): " +                     << *ISE << '\n'); +          IVUses.pop_back(); +          return false; +        } +      } +      LLVM_DEBUG(if (SE->getSCEV(I) != ISE) dbgs() +                 << "   NORMALIZED TO: " << *ISE << '\n'); +    } +  } +  return true; +} + +bool IVUsers::AddUsersIfInteresting(Instruction *I) { +  // SCEVExpander can only handle users that are dominated by simplified loop +  // entries. Keep track of all loops that are only dominated by other simple +  // loops so we don't traverse the domtree for each user. +  SmallPtrSet<Loop*,16> SimpleLoopNests; + +  return AddUsersImpl(I, SimpleLoopNests); +} + +IVStrideUse &IVUsers::AddUser(Instruction *User, Value *Operand) { +  IVUses.push_back(new IVStrideUse(this, User, Operand)); +  return IVUses.back(); +} + +IVUsers::IVUsers(Loop *L, AssumptionCache *AC, LoopInfo *LI, DominatorTree *DT, +                 ScalarEvolution *SE) +    : L(L), AC(AC), LI(LI), DT(DT), SE(SE), IVUses() { +  // Collect ephemeral values so that AddUsersIfInteresting skips them. +  EphValues.clear(); +  CodeMetrics::collectEphemeralValues(L, AC, EphValues); + +  // Find all uses of induction variables in this loop, and categorize +  // them by stride.  Start by finding all of the PHI nodes in the header for +  // this loop.  If they are induction variables, inspect their uses. 
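+  // For instance, given "for (int i = 0; i < n; ++i) sum += a[i];" the header
+  // PHIs are the loop-carried values for "i" and "sum"; only PHIs whose
+  // evolution SCEV can analyze, and whose users pass the "interesting" test,
+  // end up contributing IVStrideUse entries.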
+  for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) +    (void)AddUsersIfInteresting(&*I); +} + +void IVUsers::print(raw_ostream &OS, const Module *M) const { +  OS << "IV Users for loop "; +  L->getHeader()->printAsOperand(OS, false); +  if (SE->hasLoopInvariantBackedgeTakenCount(L)) { +    OS << " with backedge-taken count " << *SE->getBackedgeTakenCount(L); +  } +  OS << ":\n"; + +  for (const IVStrideUse &IVUse : IVUses) { +    OS << "  "; +    IVUse.getOperandValToReplace()->printAsOperand(OS, false); +    OS << " = " << *getReplacementExpr(IVUse); +    for (auto PostIncLoop : IVUse.PostIncLoops) { +      OS << " (post-inc with loop "; +      PostIncLoop->getHeader()->printAsOperand(OS, false); +      OS << ")"; +    } +    OS << " in  "; +    if (IVUse.getUser()) +      IVUse.getUser()->print(OS); +    else +      OS << "Printing <null> User"; +    OS << '\n'; +  } +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void IVUsers::dump() const { print(dbgs()); } +#endif + +void IVUsers::releaseMemory() { +  Processed.clear(); +  IVUses.clear(); +} + +IVUsersWrapperPass::IVUsersWrapperPass() : LoopPass(ID) { +  initializeIVUsersWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void IVUsersWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.addRequired<AssumptionCacheTracker>(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<ScalarEvolutionWrapperPass>(); +  AU.setPreservesAll(); +} + +bool IVUsersWrapperPass::runOnLoop(Loop *L, LPPassManager &LPM) { +  auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( +      *L->getHeader()->getParent()); +  auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + +  IU.reset(new IVUsers(L, AC, LI, DT, SE)); +  return false; +} + +void IVUsersWrapperPass::print(raw_ostream &OS, const Module *M) const { +  IU->print(OS, M); +} + +void IVUsersWrapperPass::releaseMemory() { IU->releaseMemory(); } + +/// getReplacementExpr - Return a SCEV expression which computes the +/// value of the OperandValToReplace. +const SCEV *IVUsers::getReplacementExpr(const IVStrideUse &IU) const { +  return SE->getSCEV(IU.getOperandValToReplace()); +} + +/// getExpr - Return the expression for the use. +const SCEV *IVUsers::getExpr(const IVStrideUse &IU) const { +  return normalizeForPostIncUse(getReplacementExpr(IU), IU.getPostIncLoops(), +                                *SE); +} + +static const SCEVAddRecExpr *findAddRecForLoop(const SCEV *S, const Loop *L) { +  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { +    if (AR->getLoop() == L) +      return AR; +    return findAddRecForLoop(AR->getStart(), L); +  } + +  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { +    for (const auto *Op : Add->operands()) +      if (const SCEVAddRecExpr *AR = findAddRecForLoop(Op, L)) +        return AR; +    return nullptr; +  } + +  return nullptr; +} + +const SCEV *IVUsers::getStride(const IVStrideUse &IU, const Loop *L) const { +  if (const SCEVAddRecExpr *AR = findAddRecForLoop(getExpr(IU), L)) +    return AR->getStepRecurrence(*SE); +  return nullptr; +} + +void IVStrideUse::transformToPostInc(const Loop *L) { +  PostIncLoops.insert(L); +} + +void IVStrideUse::deleted() { +  // Remove this user from the list. 
+  Parent->Processed.erase(this->getUser()); +  Parent->IVUses.erase(this); +  // this now dangles! +} diff --git a/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp new file mode 100644 index 000000000000..4659c0a00629 --- /dev/null +++ b/contrib/llvm/lib/Analysis/IndirectCallPromotionAnalysis.cpp @@ -0,0 +1,107 @@ +//===-- IndirectCallPromotionAnalysis.cpp - Find promotion candidates ===// +// +//                      The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Helper methods for identifying profitable indirect call promotion +// candidates for an instruction when the indirect-call value profile metadata +// is available. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Debug.h" +#include <string> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "pgo-icall-prom-analysis" + +// The percent threshold for the direct-call target (this call site vs the +// remaining call count) for it to be considered as the promotion target. +static cl::opt<unsigned> ICPRemainingPercentThreshold( +    "icp-remaining-percent-threshold", cl::init(30), cl::Hidden, cl::ZeroOrMore, +    cl::desc("The percentage threshold against remaining unpromoted indirect " +             "call count for the promotion")); + +// The percent threshold for the direct-call target (this call site vs the +// total call count) for it to be considered as the promotion target. +static cl::opt<unsigned> +    ICPTotalPercentThreshold("icp-total-percent-threshold", cl::init(5), +                             cl::Hidden, cl::ZeroOrMore, +                             cl::desc("The percentage threshold against total " +                                      "count for the promotion")); + +// Set the maximum number of targets to promote for a single indirect-call +// callsite. +static cl::opt<unsigned> +    MaxNumPromotions("icp-max-prom", cl::init(3), cl::Hidden, cl::ZeroOrMore, +                     cl::desc("Max number of promotions for a single indirect " +                              "call callsite")); + +ICallPromotionAnalysis::ICallPromotionAnalysis() { +  ValueDataArray = llvm::make_unique<InstrProfValueData[]>(MaxNumPromotions); +} + +bool ICallPromotionAnalysis::isPromotionProfitable(uint64_t Count, +                                                   uint64_t TotalCount, +                                                   uint64_t RemainingCount) { +  return Count * 100 >= ICPRemainingPercentThreshold * RemainingCount && +         Count * 100 >= ICPTotalPercentThreshold * TotalCount; +} + +// Indirect-call promotion heuristic. The direct targets are sorted based on +// the count. Stop at the first target that is not promoted. Returns the +// number of candidates deemed profitable. 
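+//
+// A worked example (sketch) with the default thresholds (30% of the remaining
+// count, 5% of the total count) and MaxNumPromotions = 3, assuming TotalCount
+// is 1000 and the sorted target counts are {600, 300, 40, ...}:
+//   target 0: 600*100 >= 30*1000 and 600*100 >= 5*1000  -> promote (400 left)
+//   target 1: 300*100 >= 30*400  and 300*100 >= 5*1000  -> promote (100 left)
+//   target 2:  40*100 >= 30*100  but 40*100 <  5*1000   -> stop; returns 2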
+uint32_t ICallPromotionAnalysis::getProfitablePromotionCandidates( +    const Instruction *Inst, uint32_t NumVals, uint64_t TotalCount) { +  ArrayRef<InstrProfValueData> ValueDataRef(ValueDataArray.get(), NumVals); + +  LLVM_DEBUG(dbgs() << " \nWork on callsite " << *Inst +                    << " Num_targets: " << NumVals << "\n"); + +  uint32_t I = 0; +  uint64_t RemainingCount = TotalCount; +  for (; I < MaxNumPromotions && I < NumVals; I++) { +    uint64_t Count = ValueDataRef[I].Count; +    assert(Count <= RemainingCount); +    LLVM_DEBUG(dbgs() << " Candidate " << I << " Count=" << Count +                      << "  Target_func: " << ValueDataRef[I].Value << "\n"); + +    if (!isPromotionProfitable(Count, TotalCount, RemainingCount)) { +      LLVM_DEBUG(dbgs() << " Not promote: Cold target.\n"); +      return I; +    } +    RemainingCount -= Count; +  } +  return I; +} + +ArrayRef<InstrProfValueData> +ICallPromotionAnalysis::getPromotionCandidatesForInstruction( +    const Instruction *I, uint32_t &NumVals, uint64_t &TotalCount, +    uint32_t &NumCandidates) { +  bool Res = +      getValueProfDataFromInst(*I, IPVK_IndirectCallTarget, MaxNumPromotions, +                               ValueDataArray.get(), NumVals, TotalCount); +  if (!Res) { +    NumCandidates = 0; +    return ArrayRef<InstrProfValueData>(); +  } +  NumCandidates = getProfitablePromotionCandidates(I, NumVals, TotalCount); +  return ArrayRef<InstrProfValueData>(ValueDataArray.get(), NumVals); +} diff --git a/contrib/llvm/lib/Analysis/InlineCost.cpp b/contrib/llvm/lib/Analysis/InlineCost.cpp new file mode 100644 index 000000000000..a6cccc3b5910 --- /dev/null +++ b/contrib/llvm/lib/Analysis/InlineCost.cpp @@ -0,0 +1,2153 @@ +//===- InlineCost.cpp - Cost analysis for inliner -------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements inline cost analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/InlineCost.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "inline-cost" + +STATISTIC(NumCallsAnalyzed, "Number of call sites analyzed"); + +static cl::opt<int> InlineThreshold( +    "inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore, +    cl::desc("Control the amount of inlining to perform (default = 225)")); + +static cl::opt<int> HintThreshold( +    "inlinehint-threshold", cl::Hidden, cl::init(325), +    cl::desc("Threshold for inlining functions with inline hint")); + +static cl::opt<int> +    ColdCallSiteThreshold("inline-cold-callsite-threshold", cl::Hidden, +                          cl::init(45), +                          cl::desc("Threshold for inlining cold callsites")); + +// We introduce this threshold to help performance of instrumentation based +// PGO before we actually hook up inliner with analysis passes such as BPI and +// BFI. 
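+//
+// Like the options above, the thresholds below are ordinary cl::opt knobs and
+// can be overridden from the command line, e.g. (illustrative invocation):
+//   opt -inline -inline-threshold=500 -inlinehint-threshold=400 input.bc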
+static cl::opt<int> ColdThreshold( +    "inlinecold-threshold", cl::Hidden, cl::init(45), +    cl::desc("Threshold for inlining functions with cold attribute")); + +static cl::opt<int> +    HotCallSiteThreshold("hot-callsite-threshold", cl::Hidden, cl::init(3000), +                         cl::ZeroOrMore, +                         cl::desc("Threshold for hot callsites ")); + +static cl::opt<int> LocallyHotCallSiteThreshold( +    "locally-hot-callsite-threshold", cl::Hidden, cl::init(525), cl::ZeroOrMore, +    cl::desc("Threshold for locally hot callsites ")); + +static cl::opt<int> ColdCallSiteRelFreq( +    "cold-callsite-rel-freq", cl::Hidden, cl::init(2), cl::ZeroOrMore, +    cl::desc("Maxmimum block frequency, expressed as a percentage of caller's " +             "entry frequency, for a callsite to be cold in the absence of " +             "profile information.")); + +static cl::opt<int> HotCallSiteRelFreq( +    "hot-callsite-rel-freq", cl::Hidden, cl::init(60), cl::ZeroOrMore, +    cl::desc("Minimum block frequency, expressed as a multiple of caller's " +             "entry frequency, for a callsite to be hot in the absence of " +             "profile information.")); + +static cl::opt<bool> OptComputeFullInlineCost( +    "inline-cost-full", cl::Hidden, cl::init(false), +    cl::desc("Compute the full inline cost of a call site even when the cost " +             "exceeds the threshold.")); + +namespace { + +class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { +  typedef InstVisitor<CallAnalyzer, bool> Base; +  friend class InstVisitor<CallAnalyzer, bool>; + +  /// The TargetTransformInfo available for this compilation. +  const TargetTransformInfo &TTI; + +  /// Getter for the cache of @llvm.assume intrinsics. +  std::function<AssumptionCache &(Function &)> &GetAssumptionCache; + +  /// Getter for BlockFrequencyInfo +  Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI; + +  /// Profile summary information. +  ProfileSummaryInfo *PSI; + +  /// The called function. +  Function &F; + +  // Cache the DataLayout since we use it a lot. +  const DataLayout &DL; + +  /// The OptimizationRemarkEmitter available for this compilation. +  OptimizationRemarkEmitter *ORE; + +  /// The candidate callsite being analyzed. Please do not use this to do +  /// analysis in the caller function; we want the inline cost query to be +  /// easily cacheable. Instead, use the cover function paramHasAttr. +  CallSite CandidateCS; + +  /// Tunable parameters that control the analysis. +  const InlineParams &Params; + +  int Threshold; +  int Cost; +  bool ComputeFullInlineCost; + +  bool IsCallerRecursive; +  bool IsRecursiveCall; +  bool ExposesReturnsTwice; +  bool HasDynamicAlloca; +  bool ContainsNoDuplicateCall; +  bool HasReturn; +  bool HasIndirectBr; +  bool HasUninlineableIntrinsic; +  bool UsesVarArgs; + +  /// Number of bytes allocated statically by the callee. +  uint64_t AllocatedSize; +  unsigned NumInstructions, NumVectorInstructions; +  int VectorBonus, TenPercentVectorBonus; +  // Bonus to be applied when the callee has only one reachable basic block. +  int SingleBBBonus; + +  /// While we walk the potentially-inlined instructions, we build up and +  /// maintain a mapping of simplified values specific to this callsite. The +  /// idea is to propagate any special information we have about arguments to +  /// this call through the inlinable section of the function, and account for +  /// likely simplifications post-inlining. 
The most important aspect we track +  /// is CFG altering simplifications -- when we prove a basic block dead, that +  /// can cause dramatic shifts in the cost of inlining a function. +  DenseMap<Value *, Constant *> SimplifiedValues; + +  /// Keep track of the values which map back (through function arguments) to +  /// allocas on the caller stack which could be simplified through SROA. +  DenseMap<Value *, Value *> SROAArgValues; + +  /// The mapping of caller Alloca values to their accumulated cost savings. If +  /// we have to disable SROA for one of the allocas, this tells us how much +  /// cost must be added. +  DenseMap<Value *, int> SROAArgCosts; + +  /// Keep track of values which map to a pointer base and constant offset. +  DenseMap<Value *, std::pair<Value *, APInt>> ConstantOffsetPtrs; + +  /// Keep track of dead blocks due to the constant arguments. +  SetVector<BasicBlock *> DeadBlocks; + +  /// The mapping of the blocks to their known unique successors due to the +  /// constant arguments. +  DenseMap<BasicBlock *, BasicBlock *> KnownSuccessors; + +  /// Model the elimination of repeated loads that is expected to happen +  /// whenever we simplify away the stores that would otherwise cause them to be +  /// loads. +  bool EnableLoadElimination; +  SmallPtrSet<Value *, 16> LoadAddrSet; +  int LoadEliminationCost; + +  // Custom simplification helper routines. +  bool isAllocaDerivedArg(Value *V); +  bool lookupSROAArgAndCost(Value *V, Value *&Arg, +                            DenseMap<Value *, int>::iterator &CostIt); +  void disableSROA(DenseMap<Value *, int>::iterator CostIt); +  void disableSROA(Value *V); +  void findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB); +  void accumulateSROACost(DenseMap<Value *, int>::iterator CostIt, +                          int InstructionCost); +  void disableLoadElimination(); +  bool isGEPFree(GetElementPtrInst &GEP); +  bool canFoldInboundsGEP(GetElementPtrInst &I); +  bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset); +  bool simplifyCallSite(Function *F, CallSite CS); +  template <typename Callable> +  bool simplifyInstruction(Instruction &I, Callable Evaluate); +  ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V); + +  /// Return true if the given argument to the function being considered for +  /// inlining has the given attribute set either at the call site or the +  /// function declaration.  Primarily used to inspect call site specific +  /// attributes since these can be more precise than the ones on the callee +  /// itself. +  bool paramHasAttr(Argument *A, Attribute::AttrKind Attr); + +  /// Return true if the given value is known non null within the callee if +  /// inlined through this particular callsite. +  bool isKnownNonNullInCallee(Value *V); + +  /// Update Threshold based on callsite properties such as callee +  /// attributes and callee hotness for PGO builds. The Callee is explicitly +  /// passed to support analyzing indirect calls whose target is inferred by +  /// analysis. +  void updateThreshold(CallSite CS, Function &Callee); + +  /// Return true if size growth is allowed when inlining the callee at CS. +  bool allowSizeGrowth(CallSite CS); + +  /// Return true if \p CS is a cold callsite. +  bool isColdCallSite(CallSite CS, BlockFrequencyInfo *CallerBFI); + +  /// Return a higher threshold if \p CS is a hot callsite. +  Optional<int> getHotCallSiteThreshold(CallSite CS, +                                        BlockFrequencyInfo *CallerBFI); + +  // Custom analysis routines. 
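The SROA bookkeeping declared above (SROAArgCosts together with the SROACostSavings counters) amounts to a refund scheme: savings accrue per candidate alloca while SROA still looks viable, and disabling a candidate pushes its accumulated savings back into the running cost. A minimal standalone model of that accounting, with illustrative names rather than the LLVM types:

#include <iostream>
#include <map>
#include <string>

// Simplified stand-in for the per-candidate SROA savings ledger.
struct SroaLedger {
  int Cost = 0;        // running inline cost
  int SavingsLost = 0; // savings refunded after a candidate was disabled
  std::map<std::string, int> Savings; // per-alloca accumulated savings

  // An instruction that SROA would delete credits its cost to the candidate.
  void accumulate(const std::string &Alloca, int InstrCost) {
    Savings[Alloca] += InstrCost;
  }

  // Disabling SROA for a candidate adds its accumulated savings back to Cost.
  void disable(const std::string &Alloca) {
    auto It = Savings.find(Alloca);
    if (It == Savings.end())
      return;
    Cost += It->second;
    SavingsLost += It->second;
    Savings.erase(It);
  }
};

int main() {
  SroaLedger L;
  L.accumulate("arg0.alloca", 5); // a load SROA would remove
  L.accumulate("arg0.alloca", 5); // a store SROA would remove
  L.disable("arg0.alloca");       // e.g. the pointer escaped
  std::cout << L.Cost << '\n';    // prints 10: savings refunded into cost
}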
+  bool analyzeBlock(BasicBlock *BB, SmallPtrSetImpl<const Value *> &EphValues); + +  // Disable several entry points to the visitor so we don't accidentally use +  // them by declaring but not defining them here. +  void visit(Module *); +  void visit(Module &); +  void visit(Function *); +  void visit(Function &); +  void visit(BasicBlock *); +  void visit(BasicBlock &); + +  // Provide base case for our instruction visit. +  bool visitInstruction(Instruction &I); + +  // Our visit overrides. +  bool visitAlloca(AllocaInst &I); +  bool visitPHI(PHINode &I); +  bool visitGetElementPtr(GetElementPtrInst &I); +  bool visitBitCast(BitCastInst &I); +  bool visitPtrToInt(PtrToIntInst &I); +  bool visitIntToPtr(IntToPtrInst &I); +  bool visitCastInst(CastInst &I); +  bool visitUnaryInstruction(UnaryInstruction &I); +  bool visitCmpInst(CmpInst &I); +  bool visitSub(BinaryOperator &I); +  bool visitBinaryOperator(BinaryOperator &I); +  bool visitLoad(LoadInst &I); +  bool visitStore(StoreInst &I); +  bool visitExtractValue(ExtractValueInst &I); +  bool visitInsertValue(InsertValueInst &I); +  bool visitCallSite(CallSite CS); +  bool visitReturnInst(ReturnInst &RI); +  bool visitBranchInst(BranchInst &BI); +  bool visitSelectInst(SelectInst &SI); +  bool visitSwitchInst(SwitchInst &SI); +  bool visitIndirectBrInst(IndirectBrInst &IBI); +  bool visitResumeInst(ResumeInst &RI); +  bool visitCleanupReturnInst(CleanupReturnInst &RI); +  bool visitCatchReturnInst(CatchReturnInst &RI); +  bool visitUnreachableInst(UnreachableInst &I); + +public: +  CallAnalyzer(const TargetTransformInfo &TTI, +               std::function<AssumptionCache &(Function &)> &GetAssumptionCache, +               Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI, +               ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE, +               Function &Callee, CallSite CSArg, const InlineParams &Params) +      : TTI(TTI), GetAssumptionCache(GetAssumptionCache), GetBFI(GetBFI), +        PSI(PSI), F(Callee), DL(F.getParent()->getDataLayout()), ORE(ORE), +        CandidateCS(CSArg), Params(Params), Threshold(Params.DefaultThreshold), +        Cost(0), ComputeFullInlineCost(OptComputeFullInlineCost || +                                       Params.ComputeFullInlineCost || ORE), +        IsCallerRecursive(false), IsRecursiveCall(false), +        ExposesReturnsTwice(false), HasDynamicAlloca(false), +        ContainsNoDuplicateCall(false), HasReturn(false), HasIndirectBr(false), +        HasUninlineableIntrinsic(false), UsesVarArgs(false), AllocatedSize(0), +        NumInstructions(0), NumVectorInstructions(0), VectorBonus(0), +        SingleBBBonus(0), EnableLoadElimination(true), LoadEliminationCost(0), +        NumConstantArgs(0), NumConstantOffsetPtrArgs(0), NumAllocaArgs(0), +        NumConstantPtrCmps(0), NumConstantPtrDiffs(0), +        NumInstructionsSimplified(0), SROACostSavings(0), +        SROACostSavingsLost(0) {} + +  bool analyzeCall(CallSite CS); + +  int getThreshold() { return Threshold; } +  int getCost() { return Cost; } + +  // Keep a bunch of stats about the cost savings found so we can print them +  // out when debugging. 
+  unsigned NumConstantArgs; +  unsigned NumConstantOffsetPtrArgs; +  unsigned NumAllocaArgs; +  unsigned NumConstantPtrCmps; +  unsigned NumConstantPtrDiffs; +  unsigned NumInstructionsSimplified; +  unsigned SROACostSavings; +  unsigned SROACostSavingsLost; + +  void dump(); +}; + +} // namespace + +/// Test whether the given value is an Alloca-derived function argument. +bool CallAnalyzer::isAllocaDerivedArg(Value *V) { +  return SROAArgValues.count(V); +} + +/// Lookup the SROA-candidate argument and cost iterator which V maps to. +/// Returns false if V does not map to a SROA-candidate. +bool CallAnalyzer::lookupSROAArgAndCost( +    Value *V, Value *&Arg, DenseMap<Value *, int>::iterator &CostIt) { +  if (SROAArgValues.empty() || SROAArgCosts.empty()) +    return false; + +  DenseMap<Value *, Value *>::iterator ArgIt = SROAArgValues.find(V); +  if (ArgIt == SROAArgValues.end()) +    return false; + +  Arg = ArgIt->second; +  CostIt = SROAArgCosts.find(Arg); +  return CostIt != SROAArgCosts.end(); +} + +/// Disable SROA for the candidate marked by this cost iterator. +/// +/// This marks the candidate as no longer viable for SROA, and adds the cost +/// savings associated with it back into the inline cost measurement. +void CallAnalyzer::disableSROA(DenseMap<Value *, int>::iterator CostIt) { +  // If we're no longer able to perform SROA we need to undo its cost savings +  // and prevent subsequent analysis. +  Cost += CostIt->second; +  SROACostSavings -= CostIt->second; +  SROACostSavingsLost += CostIt->second; +  SROAArgCosts.erase(CostIt); +  disableLoadElimination(); +} + +/// If 'V' maps to a SROA candidate, disable SROA for it. +void CallAnalyzer::disableSROA(Value *V) { +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(V, SROAArg, CostIt)) +    disableSROA(CostIt); +} + +/// Accumulate the given cost for a particular SROA candidate. +void CallAnalyzer::accumulateSROACost(DenseMap<Value *, int>::iterator CostIt, +                                      int InstructionCost) { +  CostIt->second += InstructionCost; +  SROACostSavings += InstructionCost; +} + +void CallAnalyzer::disableLoadElimination() { +  if (EnableLoadElimination) { +    Cost += LoadEliminationCost; +    LoadEliminationCost = 0; +    EnableLoadElimination = false; +  } +} + +/// Accumulate a constant GEP offset into an APInt if possible. +/// +/// Returns false if unable to compute the offset for any reason. Respects any +/// simplified values known during the analysis of this callsite. +bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) { +  unsigned IntPtrWidth = DL.getIndexTypeSizeInBits(GEP.getType()); +  assert(IntPtrWidth == Offset.getBitWidth()); + +  for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP); +       GTI != GTE; ++GTI) { +    ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand()); +    if (!OpC) +      if (Constant *SimpleOp = SimplifiedValues.lookup(GTI.getOperand())) +        OpC = dyn_cast<ConstantInt>(SimpleOp); +    if (!OpC) +      return false; +    if (OpC->isZero()) +      continue; + +    // Handle a struct index, which adds its field offset to the pointer. 
+    if (StructType *STy = GTI.getStructTypeOrNull()) { +      unsigned ElementIdx = OpC->getZExtValue(); +      const StructLayout *SL = DL.getStructLayout(STy); +      Offset += APInt(IntPtrWidth, SL->getElementOffset(ElementIdx)); +      continue; +    } + +    APInt TypeSize(IntPtrWidth, DL.getTypeAllocSize(GTI.getIndexedType())); +    Offset += OpC->getValue().sextOrTrunc(IntPtrWidth) * TypeSize; +  } +  return true; +} + +/// Use TTI to check whether a GEP is free. +/// +/// Respects any simplified values known during the analysis of this callsite. +bool CallAnalyzer::isGEPFree(GetElementPtrInst &GEP) { +  SmallVector<Value *, 4> Operands; +  Operands.push_back(GEP.getOperand(0)); +  for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) +    if (Constant *SimpleOp = SimplifiedValues.lookup(*I)) +       Operands.push_back(SimpleOp); +     else +       Operands.push_back(*I); +  return TargetTransformInfo::TCC_Free == TTI.getUserCost(&GEP, Operands); +} + +bool CallAnalyzer::visitAlloca(AllocaInst &I) { +  // Check whether inlining will turn a dynamic alloca into a static +  // alloca and handle that case. +  if (I.isArrayAllocation()) { +    Constant *Size = SimplifiedValues.lookup(I.getArraySize()); +    if (auto *AllocSize = dyn_cast_or_null<ConstantInt>(Size)) { +      Type *Ty = I.getAllocatedType(); +      AllocatedSize = SaturatingMultiplyAdd( +          AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty), AllocatedSize); +      return Base::visitAlloca(I); +    } +  } + +  // Accumulate the allocated size. +  if (I.isStaticAlloca()) { +    Type *Ty = I.getAllocatedType(); +    AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty), AllocatedSize); +  } + +  // We will happily inline static alloca instructions. +  if (I.isStaticAlloca()) +    return Base::visitAlloca(I); + +  // FIXME: This is overly conservative. Dynamic allocas are inefficient for +  // a variety of reasons, and so we would like to not inline them into +  // functions which don't currently have a dynamic alloca. This simply +  // disables inlining altogether in the presence of a dynamic alloca. +  HasDynamicAlloca = true; +  return false; +} + +bool CallAnalyzer::visitPHI(PHINode &I) { +  // FIXME: We need to propagate SROA *disabling* through phi nodes, even +  // though we don't want to propagate it's bonuses. The idea is to disable +  // SROA if it *might* be used in an inappropriate manner. + +  // Phi nodes are always zero-cost. +  // FIXME: Pointer sizes may differ between different address spaces, so do we +  // need to use correct address space in the call to getPointerSizeInBits here? +  // Or could we skip the getPointerSizeInBits call completely? As far as I can +  // see the ZeroOffset is used as a dummy value, so we can probably use any +  // bit width for the ZeroOffset? +  APInt ZeroOffset = APInt::getNullValue(DL.getPointerSizeInBits(0)); +  bool CheckSROA = I.getType()->isPointerTy(); + +  // Track the constant or pointer with constant offset we've seen so far. +  Constant *FirstC = nullptr; +  std::pair<Value *, APInt> FirstBaseAndOffset = {nullptr, ZeroOffset}; +  Value *FirstV = nullptr; + +  for (unsigned i = 0, e = I.getNumIncomingValues(); i != e; ++i) { +    BasicBlock *Pred = I.getIncomingBlock(i); +    // If the incoming block is dead, skip the incoming block. +    if (DeadBlocks.count(Pred)) +      continue; +    // If the parent block of phi is not the known successor of the incoming +    // block, skip the incoming block. 
+    BasicBlock *KnownSuccessor = KnownSuccessors[Pred]; +    if (KnownSuccessor && KnownSuccessor != I.getParent()) +      continue; + +    Value *V = I.getIncomingValue(i); +    // If the incoming value is this phi itself, skip the incoming value. +    if (&I == V) +      continue; + +    Constant *C = dyn_cast<Constant>(V); +    if (!C) +      C = SimplifiedValues.lookup(V); + +    std::pair<Value *, APInt> BaseAndOffset = {nullptr, ZeroOffset}; +    if (!C && CheckSROA) +      BaseAndOffset = ConstantOffsetPtrs.lookup(V); + +    if (!C && !BaseAndOffset.first) +      // The incoming value is neither a constant nor a pointer with constant +      // offset, exit early. +      return true; + +    if (FirstC) { +      if (FirstC == C) +        // If we've seen a constant incoming value before and it is the same +        // constant we see this time, continue checking the next incoming value. +        continue; +      // Otherwise early exit because we either see a different constant or saw +      // a constant before but we have a pointer with constant offset this time. +      return true; +    } + +    if (FirstV) { +      // The same logic as above, but check pointer with constant offset here. +      if (FirstBaseAndOffset == BaseAndOffset) +        continue; +      return true; +    } + +    if (C) { +      // This is the 1st time we've seen a constant, record it. +      FirstC = C; +      continue; +    } + +    // The remaining case is that this is the 1st time we've seen a pointer with +    // constant offset, record it. +    FirstV = V; +    FirstBaseAndOffset = BaseAndOffset; +  } + +  // Check if we can map phi to a constant. +  if (FirstC) { +    SimplifiedValues[&I] = FirstC; +    return true; +  } + +  // Check if we can map phi to a pointer with constant offset. +  if (FirstBaseAndOffset.first) { +    ConstantOffsetPtrs[&I] = FirstBaseAndOffset; + +    Value *SROAArg; +    DenseMap<Value *, int>::iterator CostIt; +    if (lookupSROAArgAndCost(FirstV, SROAArg, CostIt)) +      SROAArgValues[&I] = SROAArg; +  } + +  return true; +} + +/// Check we can fold GEPs of constant-offset call site argument pointers. +/// This requires target data and inbounds GEPs. +/// +/// \return true if the specified GEP can be folded. +bool CallAnalyzer::canFoldInboundsGEP(GetElementPtrInst &I) { +  // Check if we have a base + offset for the pointer. +  std::pair<Value *, APInt> BaseAndOffset = +      ConstantOffsetPtrs.lookup(I.getPointerOperand()); +  if (!BaseAndOffset.first) +    return false; + +  // Check if the offset of this GEP is constant, and if so accumulate it +  // into Offset. +  if (!accumulateGEPOffset(cast<GEPOperator>(I), BaseAndOffset.second)) +    return false; + +  // Add the result as a new mapping to Base + Offset. +  ConstantOffsetPtrs[&I] = BaseAndOffset; + +  return true; +} + +bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  bool SROACandidate = +      lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt); + +  // Lambda to check whether a GEP's indices are all constant. 
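canFoldInboundsGEP above relies on collapsing a chain of constant GEP indices into a single (base, byte offset) pair: a struct index contributes the field's offset from the layout, and an array-style index contributes index times element size. A rough standalone sketch of that arithmetic, using a hand-written layout description instead of DataLayout:

#include <cstdint>
#include <iostream>
#include <vector>

// One GEP index step: either a field in a struct (byte offset known from the
// layout) or an element in an array (index scaled by the element size).
struct IndexStep {
  bool IsStructField;
  int64_t FieldByteOffset; // used when IsStructField is true
  int64_t ElementSize;     // used for array-style indices
  int64_t Index;           // used for array-style indices
};

// Accumulate the total constant byte offset of a GEP-like index chain.
int64_t accumulateConstantOffset(const std::vector<IndexStep> &Steps) {
  int64_t Offset = 0;
  for (const IndexStep &S : Steps)
    Offset += S.IsStructField ? S.FieldByteOffset : S.ElementSize * S.Index;
  return Offset;
}

int main() {
  // Roughly: the field at offset 8 of element 2, for a 16-byte element type.
  std::vector<IndexStep> Steps = {
      {false, 0, /*ElementSize=*/16, /*Index=*/2}, // array index: 32 bytes
      {true, /*FieldByteOffset=*/8, 0, 0},         // struct field at offset 8
  };
  std::cout << accumulateConstantOffset(Steps) << '\n'; // prints 40
}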
+  auto IsGEPOffsetConstant = [&](GetElementPtrInst &GEP) { +    for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end(); I != E; ++I) +      if (!isa<Constant>(*I) && !SimplifiedValues.lookup(*I)) +        return false; +    return true; +  }; + +  if ((I.isInBounds() && canFoldInboundsGEP(I)) || IsGEPOffsetConstant(I)) { +    if (SROACandidate) +      SROAArgValues[&I] = SROAArg; + +    // Constant GEPs are modeled as free. +    return true; +  } + +  // Variable GEPs will require math and will disable SROA. +  if (SROACandidate) +    disableSROA(CostIt); +  return isGEPFree(I); +} + +/// Simplify \p I if its operands are constants and update SimplifiedValues. +/// \p Evaluate is a callable specific to instruction type that evaluates the +/// instruction when all the operands are constants. +template <typename Callable> +bool CallAnalyzer::simplifyInstruction(Instruction &I, Callable Evaluate) { +  SmallVector<Constant *, 2> COps; +  for (Value *Op : I.operands()) { +    Constant *COp = dyn_cast<Constant>(Op); +    if (!COp) +      COp = SimplifiedValues.lookup(Op); +    if (!COp) +      return false; +    COps.push_back(COp); +  } +  auto *C = Evaluate(COps); +  if (!C) +    return false; +  SimplifiedValues[&I] = C; +  return true; +} + +bool CallAnalyzer::visitBitCast(BitCastInst &I) { +  // Propagate constants through bitcasts. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getBitCast(COps[0], I.getType()); +      })) +    return true; + +  // Track base/offsets through casts +  std::pair<Value *, APInt> BaseAndOffset = +      ConstantOffsetPtrs.lookup(I.getOperand(0)); +  // Casts don't change the offset, just wrap it up. +  if (BaseAndOffset.first) +    ConstantOffsetPtrs[&I] = BaseAndOffset; + +  // Also look for SROA candidates here. +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt)) +    SROAArgValues[&I] = SROAArg; + +  // Bitcasts are always zero cost. +  return true; +} + +bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { +  // Propagate constants through ptrtoint. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getPtrToInt(COps[0], I.getType()); +      })) +    return true; + +  // Track base/offset pairs when converted to a plain integer provided the +  // integer is large enough to represent the pointer. +  unsigned IntegerSize = I.getType()->getScalarSizeInBits(); +  unsigned AS = I.getOperand(0)->getType()->getPointerAddressSpace(); +  if (IntegerSize >= DL.getPointerSizeInBits(AS)) { +    std::pair<Value *, APInt> BaseAndOffset = +        ConstantOffsetPtrs.lookup(I.getOperand(0)); +    if (BaseAndOffset.first) +      ConstantOffsetPtrs[&I] = BaseAndOffset; +  } + +  // This is really weird. Technically, ptrtoint will disable SROA. However, +  // unless that ptrtoint is *used* somewhere in the live basic blocks after +  // inlining, it will be nuked, and SROA should proceed. All of the uses which +  // would block SROA would also block SROA if applied directly to a pointer, +  // and so we can just add the integer in here. The only places where SROA is +  // preserved either cannot fire on an integer, or won't in-and-of themselves +  // disable SROA (ext) w/o some later use that we would see and disable. 
+  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt)) +    SROAArgValues[&I] = SROAArg; + +  return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I); +} + +bool CallAnalyzer::visitIntToPtr(IntToPtrInst &I) { +  // Propagate constants through ptrtoint. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getIntToPtr(COps[0], I.getType()); +      })) +    return true; + +  // Track base/offset pairs when round-tripped through a pointer without +  // modifications provided the integer is not too large. +  Value *Op = I.getOperand(0); +  unsigned IntegerSize = Op->getType()->getScalarSizeInBits(); +  if (IntegerSize <= DL.getPointerTypeSizeInBits(I.getType())) { +    std::pair<Value *, APInt> BaseAndOffset = ConstantOffsetPtrs.lookup(Op); +    if (BaseAndOffset.first) +      ConstantOffsetPtrs[&I] = BaseAndOffset; +  } + +  // "Propagate" SROA here in the same manner as we do for ptrtoint above. +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(Op, SROAArg, CostIt)) +    SROAArgValues[&I] = SROAArg; + +  return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I); +} + +bool CallAnalyzer::visitCastInst(CastInst &I) { +  // Propagate constants through ptrtoint. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getCast(I.getOpcode(), COps[0], I.getType()); +      })) +    return true; + +  // Disable SROA in the face of arbitrary casts we don't whitelist elsewhere. +  disableSROA(I.getOperand(0)); + +  // If this is a floating-point cast, and the target says this operation +  // is expensive, this may eventually become a library call. Treat the cost +  // as such. +  switch (I.getOpcode()) { +  case Instruction::FPTrunc: +  case Instruction::FPExt: +  case Instruction::UIToFP: +  case Instruction::SIToFP: +  case Instruction::FPToUI: +  case Instruction::FPToSI: +    if (TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive) +      Cost += InlineConstants::CallPenalty; +  default: +    break; +  } + +  return TargetTransformInfo::TCC_Free == TTI.getUserCost(&I); +} + +bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) { +  Value *Operand = I.getOperand(0); +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantFoldInstOperands(&I, COps[0], DL); +      })) +    return true; + +  // Disable any SROA on the argument to arbitrary unary operators. +  disableSROA(Operand); + +  return false; +} + +bool CallAnalyzer::paramHasAttr(Argument *A, Attribute::AttrKind Attr) { +  return CandidateCS.paramHasAttr(A->getArgNo(), Attr); +} + +bool CallAnalyzer::isKnownNonNullInCallee(Value *V) { +  // Does the *call site* have the NonNull attribute set on an argument?  We +  // use the attribute on the call site to memoize any analysis done in the +  // caller. This will also trip if the callee function has a non-null +  // parameter attribute, but that's a less interesting case because hopefully +  // the callee would already have been simplified based on that. +  if (Argument *A = dyn_cast<Argument>(V)) +    if (paramHasAttr(A, Attribute::NonNull)) +      return true; + +  // Is this an alloca in the caller?  This is distinct from the attribute case +  // above because attributes aren't updated within the inliner itself and we +  // always want to catch the alloca derived case. 
+  if (isAllocaDerivedArg(V)) +    // We can actually predict the result of comparisons between an +    // alloca-derived value and null. Note that this fires regardless of +    // SROA firing. +    return true; + +  return false; +} + +bool CallAnalyzer::allowSizeGrowth(CallSite CS) { +  // If the normal destination of the invoke or the parent block of the call +  // site is unreachable-terminated, there is little point in inlining this +  // unless there is literally zero cost. +  // FIXME: Note that it is possible that an unreachable-terminated block has a +  // hot entry. For example, in below scenario inlining hot_call_X() may be +  // beneficial : +  // main() { +  //   hot_call_1(); +  //   ... +  //   hot_call_N() +  //   exit(0); +  // } +  // For now, we are not handling this corner case here as it is rare in real +  // code. In future, we should elaborate this based on BPI and BFI in more +  // general threshold adjusting heuristics in updateThreshold(). +  Instruction *Instr = CS.getInstruction(); +  if (InvokeInst *II = dyn_cast<InvokeInst>(Instr)) { +    if (isa<UnreachableInst>(II->getNormalDest()->getTerminator())) +      return false; +  } else if (isa<UnreachableInst>(Instr->getParent()->getTerminator())) +    return false; + +  return true; +} + +bool CallAnalyzer::isColdCallSite(CallSite CS, BlockFrequencyInfo *CallerBFI) { +  // If global profile summary is available, then callsite's coldness is +  // determined based on that. +  if (PSI && PSI->hasProfileSummary()) +    return PSI->isColdCallSite(CS, CallerBFI); + +  // Otherwise we need BFI to be available. +  if (!CallerBFI) +    return false; + +  // Determine if the callsite is cold relative to caller's entry. We could +  // potentially cache the computation of scaled entry frequency, but the added +  // complexity is not worth it unless this scaling shows up high in the +  // profiles. +  const BranchProbability ColdProb(ColdCallSiteRelFreq, 100); +  auto CallSiteBB = CS.getInstruction()->getParent(); +  auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB); +  auto CallerEntryFreq = +      CallerBFI->getBlockFreq(&(CS.getCaller()->getEntryBlock())); +  return CallSiteFreq < CallerEntryFreq * ColdProb; +} + +Optional<int> +CallAnalyzer::getHotCallSiteThreshold(CallSite CS, +                                      BlockFrequencyInfo *CallerBFI) { + +  // If global profile summary is available, then callsite's hotness is +  // determined based on that. +  if (PSI && PSI->hasProfileSummary() && PSI->isHotCallSite(CS, CallerBFI)) +    return Params.HotCallSiteThreshold; + +  // Otherwise we need BFI to be available and to have a locally hot callsite +  // threshold. +  if (!CallerBFI || !Params.LocallyHotCallSiteThreshold) +    return None; + +  // Determine if the callsite is hot relative to caller's entry. We could +  // potentially cache the computation of scaled entry frequency, but the added +  // complexity is not worth it unless this scaling shows up high in the +  // profiles. +  auto CallSiteBB = CS.getInstruction()->getParent(); +  auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB).getFrequency(); +  auto CallerEntryFreq = CallerBFI->getEntryFreq(); +  if (CallSiteFreq >= CallerEntryFreq * HotCallSiteRelFreq) +    return Params.LocallyHotCallSiteThreshold; + +  // Otherwise treat it normally. +  return None; +} + +void CallAnalyzer::updateThreshold(CallSite CS, Function &Callee) { +  // If no size growth is allowed for this inlining, set Threshold to 0. 
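When no profile summary is available, isColdCallSite and getHotCallSiteThreshold above fall back to comparing the callsite's block frequency against the caller's entry frequency: below about ColdCallSiteRelFreq percent of the entry frequency (default 2) the site is treated as cold, and at or above HotCallSiteRelFreq times the entry frequency (default 60) it qualifies for the locally hot threshold. A small standalone sketch of those two comparisons, with the defaults hard-coded and LLVM's BranchProbability scaling simplified to plain integer arithmetic:

#include <cstdint>
#include <iostream>

// Defaults mirroring the cold-callsite-rel-freq / hot-callsite-rel-freq knobs.
constexpr uint64_t ColdRelFreqPercent = 2;  // below 2% of entry frequency: cold
constexpr uint64_t HotRelFreqMultiple = 60; // at least 60x entry frequency: hot

bool isRelativelyCold(uint64_t CallSiteFreq, uint64_t CallerEntryFreq) {
  // Mirrors CallSiteFreq < CallerEntryFreq * (ColdRelFreqPercent / 100)
  // without using floating point.
  return CallSiteFreq * 100 < CallerEntryFreq * ColdRelFreqPercent;
}

bool isRelativelyHot(uint64_t CallSiteFreq, uint64_t CallerEntryFreq) {
  return CallSiteFreq >= CallerEntryFreq * HotRelFreqMultiple;
}

int main() {
  // Caller entry executes 1000 times.
  std::cout << isRelativelyCold(15, 1000) << '\n';   // 1: 1.5% of entries
  std::cout << isRelativelyHot(90000, 1000) << '\n'; // 1: 90x the entry count
  std::cout << isRelativelyCold(500, 1000) << '\n';  // 0: not cold
}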
+  if (!allowSizeGrowth(CS)) {
+    Threshold = 0;
+    return;
+  }
+
+  Function *Caller = CS.getCaller();
+
+  // return min(A, B) if B is valid.
+  auto MinIfValid = [](int A, Optional<int> B) {
+    return B ? std::min(A, B.getValue()) : A;
+  };
+
+  // return max(A, B) if B is valid.
+  auto MaxIfValid = [](int A, Optional<int> B) {
+    return B ? std::max(A, B.getValue()) : A;
+  };
+
+  // Various bonus percentages. These are multiplied by Threshold to get the
+  // bonus values.
+  // SingleBBBonus: This bonus is applied if the callee has a single reachable
+  // basic block at the given callsite context. This is speculatively applied
+  // and withdrawn if more than one basic block is seen.
+  //
+  // Vector bonuses: We want to more aggressively inline vector-dense kernels
+  // and apply this bonus based on the percentage of vector instructions. A
+  // bonus is applied if the vector instructions exceed 50% and half that amount
+  // is applied if it exceeds 10%. Note that these bonuses are somewhat
+  // arbitrary and evolved over time by accident as much as because they are
+  // principled bonuses.
+  // FIXME: It would be nice to base the bonus values on something more
+  // scientific.
+  //
+  // LastCallToStaticBonus: This large bonus is applied to ensure the inlining
+  // of the last call to a static function as inlining such functions is
+  // guaranteed to reduce code size.
+  //
+  // These bonus percentages may be set to 0 based on properties of the caller
+  // and the callsite.
+  int SingleBBBonusPercent = 50;
+  int VectorBonusPercent = 150;
+  int LastCallToStaticBonus = InlineConstants::LastCallToStaticBonus;
+
+  // Lambda to set all the above bonus and bonus percentages to 0.
+  auto DisallowAllBonuses = [&]() {
+    SingleBBBonusPercent = 0;
+    VectorBonusPercent = 0;
+    LastCallToStaticBonus = 0;
+  };
+
+  // Use the OptMinSizeThreshold or OptSizeThreshold knob if they are available
+  // and reduce the threshold if the caller has the necessary attribute.
+  if (Caller->optForMinSize()) {
+    Threshold = MinIfValid(Threshold, Params.OptMinSizeThreshold);
+    // For minsize, we want to disable the single BB bonus and the vector
+    // bonuses, but not the last-call-to-static bonus. Inlining the last call to
+    // a static function will, at the minimum, eliminate the parameter setup and
+    // call/return instructions.
+    SingleBBBonusPercent = 0;
+    VectorBonusPercent = 0;
+  } else if (Caller->optForSize())
+    Threshold = MinIfValid(Threshold, Params.OptSizeThreshold);
+
+  // Adjust the threshold based on inlinehint attribute and profile based
+  // hotness information if the caller does not have MinSize attribute.
+  if (!Caller->optForMinSize()) {
+    if (Callee.hasFnAttribute(Attribute::InlineHint))
+      Threshold = MaxIfValid(Threshold, Params.HintThreshold);
+
+    // FIXME: After switching to the new passmanager, simplify the logic below
+    // by checking only the callsite hotness/coldness as we will reliably
+    // have local profile information.
+    //
+    // Callsite hotness and coldness can be determined if sample profile is
+    // used (which adds hotness metadata to calls) or if caller's
+    // BlockFrequencyInfo is available.
+    BlockFrequencyInfo *CallerBFI = GetBFI ?
&((*GetBFI)(*Caller)) : nullptr; +    auto HotCallSiteThreshold = getHotCallSiteThreshold(CS, CallerBFI); +    if (!Caller->optForSize() && HotCallSiteThreshold) { +      LLVM_DEBUG(dbgs() << "Hot callsite.\n"); +      // FIXME: This should update the threshold only if it exceeds the +      // current threshold, but AutoFDO + ThinLTO currently relies on this +      // behavior to prevent inlining of hot callsites during ThinLTO +      // compile phase. +      Threshold = HotCallSiteThreshold.getValue(); +    } else if (isColdCallSite(CS, CallerBFI)) { +      LLVM_DEBUG(dbgs() << "Cold callsite.\n"); +      // Do not apply bonuses for a cold callsite including the +      // LastCallToStatic bonus. While this bonus might result in code size +      // reduction, it can cause the size of a non-cold caller to increase +      // preventing it from being inlined. +      DisallowAllBonuses(); +      Threshold = MinIfValid(Threshold, Params.ColdCallSiteThreshold); +    } else if (PSI) { +      // Use callee's global profile information only if we have no way of +      // determining this via callsite information. +      if (PSI->isFunctionEntryHot(&Callee)) { +        LLVM_DEBUG(dbgs() << "Hot callee.\n"); +        // If callsite hotness can not be determined, we may still know +        // that the callee is hot and treat it as a weaker hint for threshold +        // increase. +        Threshold = MaxIfValid(Threshold, Params.HintThreshold); +      } else if (PSI->isFunctionEntryCold(&Callee)) { +        LLVM_DEBUG(dbgs() << "Cold callee.\n"); +        // Do not apply bonuses for a cold callee including the +        // LastCallToStatic bonus. While this bonus might result in code size +        // reduction, it can cause the size of a non-cold caller to increase +        // preventing it from being inlined. +        DisallowAllBonuses(); +        Threshold = MinIfValid(Threshold, Params.ColdThreshold); +      } +    } +  } + +  // Finally, take the target-specific inlining threshold multiplier into +  // account. +  Threshold *= TTI.getInliningThresholdMultiplier(); + +  SingleBBBonus = Threshold * SingleBBBonusPercent / 100; +  VectorBonus = Threshold * VectorBonusPercent / 100; + +  bool OnlyOneCallAndLocalLinkage = +      F.hasLocalLinkage() && F.hasOneUse() && &F == CS.getCalledFunction(); +  // If there is only one call of the function, and it has internal linkage, +  // the cost of inlining it drops dramatically. It may seem odd to update +  // Cost in updateThreshold, but the bonus depends on the logic in this method. +  if (OnlyOneCallAndLocalLinkage) +    Cost -= LastCallToStaticBonus; +} + +bool CallAnalyzer::visitCmpInst(CmpInst &I) { +  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); +  // First try to handle simplified comparisons. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getCompare(I.getPredicate(), COps[0], COps[1]); +      })) +    return true; + +  if (I.getOpcode() == Instruction::FCmp) +    return false; + +  // Otherwise look for a comparison between constant offset pointers with +  // a common base. +  Value *LHSBase, *RHSBase; +  APInt LHSOffset, RHSOffset; +  std::tie(LHSBase, LHSOffset) = ConstantOffsetPtrs.lookup(LHS); +  if (LHSBase) { +    std::tie(RHSBase, RHSOffset) = ConstantOffsetPtrs.lookup(RHS); +    if (RHSBase && LHSBase == RHSBase) { +      // We have common bases, fold the icmp to a constant based on the +      // offsets. 
+      Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset); +      Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset); +      if (Constant *C = ConstantExpr::getICmp(I.getPredicate(), CLHS, CRHS)) { +        SimplifiedValues[&I] = C; +        ++NumConstantPtrCmps; +        return true; +      } +    } +  } + +  // If the comparison is an equality comparison with null, we can simplify it +  // if we know the value (argument) can't be null +  if (I.isEquality() && isa<ConstantPointerNull>(I.getOperand(1)) && +      isKnownNonNullInCallee(I.getOperand(0))) { +    bool IsNotEqual = I.getPredicate() == CmpInst::ICMP_NE; +    SimplifiedValues[&I] = IsNotEqual ? ConstantInt::getTrue(I.getType()) +                                      : ConstantInt::getFalse(I.getType()); +    return true; +  } +  // Finally check for SROA candidates in comparisons. +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(I.getOperand(0), SROAArg, CostIt)) { +    if (isa<ConstantPointerNull>(I.getOperand(1))) { +      accumulateSROACost(CostIt, InlineConstants::InstrCost); +      return true; +    } + +    disableSROA(CostIt); +  } + +  return false; +} + +bool CallAnalyzer::visitSub(BinaryOperator &I) { +  // Try to handle a special case: we can fold computing the difference of two +  // constant-related pointers. +  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); +  Value *LHSBase, *RHSBase; +  APInt LHSOffset, RHSOffset; +  std::tie(LHSBase, LHSOffset) = ConstantOffsetPtrs.lookup(LHS); +  if (LHSBase) { +    std::tie(RHSBase, RHSOffset) = ConstantOffsetPtrs.lookup(RHS); +    if (RHSBase && LHSBase == RHSBase) { +      // We have common bases, fold the subtract to a constant based on the +      // offsets. +      Constant *CLHS = ConstantInt::get(LHS->getContext(), LHSOffset); +      Constant *CRHS = ConstantInt::get(RHS->getContext(), RHSOffset); +      if (Constant *C = ConstantExpr::getSub(CLHS, CRHS)) { +        SimplifiedValues[&I] = C; +        ++NumConstantPtrDiffs; +        return true; +      } +    } +  } + +  // Otherwise, fall back to the generic logic for simplifying and handling +  // instructions. +  return Base::visitSub(I); +} + +bool CallAnalyzer::visitBinaryOperator(BinaryOperator &I) { +  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); +  Constant *CLHS = dyn_cast<Constant>(LHS); +  if (!CLHS) +    CLHS = SimplifiedValues.lookup(LHS); +  Constant *CRHS = dyn_cast<Constant>(RHS); +  if (!CRHS) +    CRHS = SimplifiedValues.lookup(RHS); + +  Value *SimpleV = nullptr; +  if (auto FI = dyn_cast<FPMathOperator>(&I)) +    SimpleV = SimplifyFPBinOp(I.getOpcode(), CLHS ? CLHS : LHS, +                              CRHS ? CRHS : RHS, FI->getFastMathFlags(), DL); +  else +    SimpleV = +        SimplifyBinOp(I.getOpcode(), CLHS ? CLHS : LHS, CRHS ? CRHS : RHS, DL); + +  if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) +    SimplifiedValues[&I] = C; + +  if (SimpleV) +    return true; + +  // Disable any SROA on arguments to arbitrary, unsimplified binary operators. +  disableSROA(LHS); +  disableSROA(RHS); + +  // If the instruction is floating point, and the target says this operation +  // is expensive, this may eventually become a library call. Treat the cost +  // as such. 
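visitCmpInst and visitSub above both exploit the same fact: once two pointers are known to be a common base plus constant offsets, their equality or difference is decided by the offsets alone. A standalone sketch of that folding, with ConstantOffsetPtrs entries reduced to a base id and an integer offset:

#include <cstdint>
#include <iostream>
#include <optional>

// A pointer tracked as (base id, constant byte offset), in the spirit of the
// ConstantOffsetPtrs map above.
struct TrackedPtr {
  int BaseId;
  int64_t Offset;
};

// p == q folds to a constant when the bases match.
std::optional<bool> foldEquality(TrackedPtr A, TrackedPtr B) {
  if (A.BaseId != B.BaseId)
    return std::nullopt; // different bases: cannot fold here
  return A.Offset == B.Offset;
}

// p - q folds to the offset difference when the bases match.
std::optional<int64_t> foldDifference(TrackedPtr A, TrackedPtr B) {
  if (A.BaseId != B.BaseId)
    return std::nullopt;
  return A.Offset - B.Offset;
}

int main() {
  TrackedPtr P{/*BaseId=*/1, /*Offset=*/16}, Q{/*BaseId=*/1, /*Offset=*/4};
  std::cout << *foldEquality(P, Q) << '\n';   // 0: same base, different offsets
  std::cout << *foldDifference(P, Q) << '\n'; // 12
}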
+  if (I.getType()->isFloatingPointTy() && +      TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive) +    Cost += InlineConstants::CallPenalty; + +  return false; +} + +bool CallAnalyzer::visitLoad(LoadInst &I) { +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt)) { +    if (I.isSimple()) { +      accumulateSROACost(CostIt, InlineConstants::InstrCost); +      return true; +    } + +    disableSROA(CostIt); +  } + +  // If the data is already loaded from this address and hasn't been clobbered +  // by any stores or calls, this load is likely to be redundant and can be +  // eliminated. +  if (EnableLoadElimination && +      !LoadAddrSet.insert(I.getPointerOperand()).second && I.isUnordered()) { +    LoadEliminationCost += InlineConstants::InstrCost; +    return true; +  } + +  return false; +} + +bool CallAnalyzer::visitStore(StoreInst &I) { +  Value *SROAArg; +  DenseMap<Value *, int>::iterator CostIt; +  if (lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt)) { +    if (I.isSimple()) { +      accumulateSROACost(CostIt, InlineConstants::InstrCost); +      return true; +    } + +    disableSROA(CostIt); +  } + +  // The store can potentially clobber loads and prevent repeated loads from +  // being eliminated. +  // FIXME: +  // 1. We can probably keep an initial set of eliminatable loads substracted +  // from the cost even when we finally see a store. We just need to disable +  // *further* accumulation of elimination savings. +  // 2. We should probably at some point thread MemorySSA for the callee into +  // this and then use that to actually compute *really* precise savings. +  disableLoadElimination(); +  return false; +} + +bool CallAnalyzer::visitExtractValue(ExtractValueInst &I) { +  // Constant folding for extract value is trivial. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getExtractValue(COps[0], I.getIndices()); +      })) +    return true; + +  // SROA can look through these but give them a cost. +  return false; +} + +bool CallAnalyzer::visitInsertValue(InsertValueInst &I) { +  // Constant folding for insert value is trivial. +  if (simplifyInstruction(I, [&](SmallVectorImpl<Constant *> &COps) { +        return ConstantExpr::getInsertValue(/*AggregateOperand*/ COps[0], +                                            /*InsertedValueOperand*/ COps[1], +                                            I.getIndices()); +      })) +    return true; + +  // SROA can look through these but give them a cost. +  return false; +} + +/// Try to simplify a call site. +/// +/// Takes a concrete function and callsite and tries to actually simplify it by +/// analyzing the arguments and call itself with instsimplify. Returns true if +/// it has simplified the callsite to some other entity (a constant), making it +/// free. +bool CallAnalyzer::simplifyCallSite(Function *F, CallSite CS) { +  // FIXME: Using the instsimplify logic directly for this is inefficient +  // because we have to continually rebuild the argument list even when no +  // simplifications can be performed. Until that is fixed with remapping +  // inside of instsimplify, directly constant fold calls here. +  if (!canConstantFoldCallTo(CS, F)) +    return false; + +  // Try to re-map the arguments to constants. 
+  SmallVector<Constant *, 4> ConstantArgs; +  ConstantArgs.reserve(CS.arg_size()); +  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; +       ++I) { +    Constant *C = dyn_cast<Constant>(*I); +    if (!C) +      C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(*I)); +    if (!C) +      return false; // This argument doesn't map to a constant. + +    ConstantArgs.push_back(C); +  } +  if (Constant *C = ConstantFoldCall(CS, F, ConstantArgs)) { +    SimplifiedValues[CS.getInstruction()] = C; +    return true; +  } + +  return false; +} + +bool CallAnalyzer::visitCallSite(CallSite CS) { +  if (CS.hasFnAttr(Attribute::ReturnsTwice) && +      !F.hasFnAttribute(Attribute::ReturnsTwice)) { +    // This aborts the entire analysis. +    ExposesReturnsTwice = true; +    return false; +  } +  if (CS.isCall() && cast<CallInst>(CS.getInstruction())->cannotDuplicate()) +    ContainsNoDuplicateCall = true; + +  if (Function *F = CS.getCalledFunction()) { +    // When we have a concrete function, first try to simplify it directly. +    if (simplifyCallSite(F, CS)) +      return true; + +    // Next check if it is an intrinsic we know about. +    // FIXME: Lift this into part of the InstVisitor. +    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { +      switch (II->getIntrinsicID()) { +      default: +        if (!CS.onlyReadsMemory() && !isAssumeLikeIntrinsic(II)) +          disableLoadElimination(); +        return Base::visitCallSite(CS); + +      case Intrinsic::load_relative: +        // This is normally lowered to 4 LLVM instructions. +        Cost += 3 * InlineConstants::InstrCost; +        return false; + +      case Intrinsic::memset: +      case Intrinsic::memcpy: +      case Intrinsic::memmove: +        disableLoadElimination(); +        // SROA can usually chew through these intrinsics, but they aren't free. +        return false; +      case Intrinsic::icall_branch_funnel: +      case Intrinsic::localescape: +        HasUninlineableIntrinsic = true; +        return false; +      case Intrinsic::vastart: +      case Intrinsic::vaend: +        UsesVarArgs = true; +        return false; +      } +    } + +    if (F == CS.getInstruction()->getFunction()) { +      // This flag will fully abort the analysis, so don't bother with anything +      // else. +      IsRecursiveCall = true; +      return false; +    } + +    if (TTI.isLoweredToCall(F)) { +      // We account for the average 1 instruction per call argument setup +      // here. +      Cost += CS.arg_size() * InlineConstants::InstrCost; + +      // Everything other than inline ASM will also have a significant cost +      // merely from making the call. +      if (!isa<InlineAsm>(CS.getCalledValue())) +        Cost += InlineConstants::CallPenalty; +    } + +    if (!CS.onlyReadsMemory()) +      disableLoadElimination(); +    return Base::visitCallSite(CS); +  } + +  // Otherwise we're in a very special case -- an indirect function call. See +  // if we can be particularly clever about this. +  Value *Callee = CS.getCalledValue(); + +  // First, pay the price of the argument setup. We account for the average +  // 1 instruction per call argument setup here. +  Cost += CS.arg_size() * InlineConstants::InstrCost; + +  // Next, check if this happens to be an indirect function call to a known +  // function in this inline context. If not, we've done all we can. 
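For a call that survives simplification, the accounting above charges roughly one instruction of setup per argument plus a flat call penalty whenever the callee is really lowered to a call (inline asm pays only the per-argument part). A minimal arithmetic sketch of that charge; the two constants are written out here as assumed illustrative values rather than taken from InlineConstants:

#include <iostream>

// Illustrative stand-ins; InlineConstants defines the real values.
constexpr int InstrCost = 5;
constexpr int CallPenalty = 25;

// Overhead charged for a direct call that is not folded away.
int callSiteOverhead(unsigned NumArgs, bool LoweredToCall, bool IsInlineAsm) {
  int Cost = static_cast<int>(NumArgs) * InstrCost; // argument setup
  if (LoweredToCall && !IsInlineAsm)
    Cost += CallPenalty; // the call itself
  return Cost;
}

int main() {
  std::cout << callSiteOverhead(3, true, false) << '\n'; // 3*5 + 25 = 40
  std::cout << callSiteOverhead(3, true, true) << '\n';  // inline asm: 15
}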
+  Function *F = dyn_cast_or_null<Function>(SimplifiedValues.lookup(Callee)); +  if (!F) { +    if (!CS.onlyReadsMemory()) +      disableLoadElimination(); +    return Base::visitCallSite(CS); +  } + +  // If we have a constant that we are calling as a function, we can peer +  // through it and see the function target. This happens not infrequently +  // during devirtualization and so we want to give it a hefty bonus for +  // inlining, but cap that bonus in the event that inlining wouldn't pan +  // out. Pretend to inline the function, with a custom threshold. +  auto IndirectCallParams = Params; +  IndirectCallParams.DefaultThreshold = InlineConstants::IndirectCallThreshold; +  CallAnalyzer CA(TTI, GetAssumptionCache, GetBFI, PSI, ORE, *F, CS, +                  IndirectCallParams); +  if (CA.analyzeCall(CS)) { +    // We were able to inline the indirect call! Subtract the cost from the +    // threshold to get the bonus we want to apply, but don't go below zero. +    Cost -= std::max(0, CA.getThreshold() - CA.getCost()); +  } + +  if (!F->onlyReadsMemory()) +    disableLoadElimination(); +  return Base::visitCallSite(CS); +} + +bool CallAnalyzer::visitReturnInst(ReturnInst &RI) { +  // At least one return instruction will be free after inlining. +  bool Free = !HasReturn; +  HasReturn = true; +  return Free; +} + +bool CallAnalyzer::visitBranchInst(BranchInst &BI) { +  // We model unconditional branches as essentially free -- they really +  // shouldn't exist at all, but handling them makes the behavior of the +  // inliner more regular and predictable. Interestingly, conditional branches +  // which will fold away are also free. +  return BI.isUnconditional() || isa<ConstantInt>(BI.getCondition()) || +         dyn_cast_or_null<ConstantInt>( +             SimplifiedValues.lookup(BI.getCondition())); +} + +bool CallAnalyzer::visitSelectInst(SelectInst &SI) { +  bool CheckSROA = SI.getType()->isPointerTy(); +  Value *TrueVal = SI.getTrueValue(); +  Value *FalseVal = SI.getFalseValue(); + +  Constant *TrueC = dyn_cast<Constant>(TrueVal); +  if (!TrueC) +    TrueC = SimplifiedValues.lookup(TrueVal); +  Constant *FalseC = dyn_cast<Constant>(FalseVal); +  if (!FalseC) +    FalseC = SimplifiedValues.lookup(FalseVal); +  Constant *CondC = +      dyn_cast_or_null<Constant>(SimplifiedValues.lookup(SI.getCondition())); + +  if (!CondC) { +    // Select C, X, X => X +    if (TrueC == FalseC && TrueC) { +      SimplifiedValues[&SI] = TrueC; +      return true; +    } + +    if (!CheckSROA) +      return Base::visitSelectInst(SI); + +    std::pair<Value *, APInt> TrueBaseAndOffset = +        ConstantOffsetPtrs.lookup(TrueVal); +    std::pair<Value *, APInt> FalseBaseAndOffset = +        ConstantOffsetPtrs.lookup(FalseVal); +    if (TrueBaseAndOffset == FalseBaseAndOffset && TrueBaseAndOffset.first) { +      ConstantOffsetPtrs[&SI] = TrueBaseAndOffset; + +      Value *SROAArg; +      DenseMap<Value *, int>::iterator CostIt; +      if (lookupSROAArgAndCost(TrueVal, SROAArg, CostIt)) +        SROAArgValues[&SI] = SROAArg; +      return true; +    } + +    return Base::visitSelectInst(SI); +  } + +  // Select condition is a constant. +  Value *SelectedV = CondC->isAllOnesValue() +                         ? TrueVal +                         : (CondC->isNullValue()) ? FalseVal : nullptr; +  if (!SelectedV) { +    // Condition is a vector constant that is not all 1s or all 0s.  If all +    // operands are constants, ConstantExpr::getSelect() can handle the cases +    // such as select vectors. 
+    if (TrueC && FalseC) { +      if (auto *C = ConstantExpr::getSelect(CondC, TrueC, FalseC)) { +        SimplifiedValues[&SI] = C; +        return true; +      } +    } +    return Base::visitSelectInst(SI); +  } + +  // Condition is either all 1s or all 0s. SI can be simplified. +  if (Constant *SelectedC = dyn_cast<Constant>(SelectedV)) { +    SimplifiedValues[&SI] = SelectedC; +    return true; +  } + +  if (!CheckSROA) +    return true; + +  std::pair<Value *, APInt> BaseAndOffset = +      ConstantOffsetPtrs.lookup(SelectedV); +  if (BaseAndOffset.first) { +    ConstantOffsetPtrs[&SI] = BaseAndOffset; + +    Value *SROAArg; +    DenseMap<Value *, int>::iterator CostIt; +    if (lookupSROAArgAndCost(SelectedV, SROAArg, CostIt)) +      SROAArgValues[&SI] = SROAArg; +  } + +  return true; +} + +bool CallAnalyzer::visitSwitchInst(SwitchInst &SI) { +  // We model unconditional switches as free, see the comments on handling +  // branches. +  if (isa<ConstantInt>(SI.getCondition())) +    return true; +  if (Value *V = SimplifiedValues.lookup(SI.getCondition())) +    if (isa<ConstantInt>(V)) +      return true; + +  // Assume the most general case where the switch is lowered into +  // either a jump table, bit test, or a balanced binary tree consisting of +  // case clusters without merging adjacent clusters with the same +  // destination. We do not consider the switches that are lowered with a mix +  // of jump table/bit test/binary search tree. The cost of the switch is +  // proportional to the size of the tree or the size of jump table range. +  // +  // NB: We convert large switches which are just used to initialize large phi +  // nodes to lookup tables instead in simplify-cfg, so this shouldn't prevent +  // inlining those. It will prevent inlining in cases where the optimization +  // does not (yet) fire. + +  // Maximum valid cost increased in this function. +  int CostUpperBound = INT_MAX - InlineConstants::InstrCost - 1; + +  // Exit early for a large switch, assuming one case needs at least one +  // instruction. +  // FIXME: This is not true for a bit test, but ignore such case for now to +  // save compile-time. +  int64_t CostLowerBound = +      std::min((int64_t)CostUpperBound, +               (int64_t)SI.getNumCases() * InlineConstants::InstrCost + Cost); + +  if (CostLowerBound > Threshold && !ComputeFullInlineCost) { +    Cost = CostLowerBound; +    return false; +  } + +  unsigned JumpTableSize = 0; +  unsigned NumCaseCluster = +      TTI.getEstimatedNumberOfCaseClusters(SI, JumpTableSize); + +  // If suitable for a jump table, consider the cost for the table size and +  // branch to destination. +  if (JumpTableSize) { +    int64_t JTCost = (int64_t)JumpTableSize * InlineConstants::InstrCost + +                     4 * InlineConstants::InstrCost; + +    Cost = std::min((int64_t)CostUpperBound, JTCost + Cost); +    return false; +  } + +  // Considering forming a binary search, we should find the number of nodes +  // which is same as the number of comparisons when lowered. For a given +  // number of clusters, n, we can define a recursive function, f(n), to find +  // the number of nodes in the tree. The recursion is : +  // f(n) = 1 + f(n/2) + f (n - n/2), when n > 3, +  // and f(n) = n, when n <= 3. +  // This will lead a binary tree where the leaf should be either f(2) or f(3) +  // when n > 3.  
So, the number of comparisons from leaves should be n, while
+  // the number of non-leaf nodes should be:
+  //   2^(log2(n) - 1) - 1
+  //   = 2^log2(n) * 2^-1 - 1
+  //   = n / 2 - 1.
+  // Considering comparisons from leaf and non-leaf nodes, we can estimate the
+  // number of comparisons in a simple closed form:
+  //   n + n / 2 - 1 = n * 3 / 2 - 1
+  if (NumCaseCluster <= 3) {
+    // Suppose a comparison includes one compare and one conditional branch.
+    Cost += NumCaseCluster * 2 * InlineConstants::InstrCost;
+    return false;
+  }
+
+  int64_t ExpectedNumberOfCompare = 3 * (int64_t)NumCaseCluster / 2 - 1;
+  int64_t SwitchCost =
+      ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost;
+
+  Cost = std::min((int64_t)CostUpperBound, SwitchCost + Cost);
+  return false;
+}
+
+bool CallAnalyzer::visitIndirectBrInst(IndirectBrInst &IBI) {
+  // We never want to inline functions that contain an indirectbr.  This is
+  // incorrect because all the blockaddresses (in static global initializers
+  // for example) would be referring to the original function, and this
+  // indirect jump would jump from the inlined copy of the function into the
+  // original function which is extremely undefined behavior.
+  // FIXME: This logic isn't really right; we can safely inline functions with
+  // indirectbrs as long as no other function or global references the
+  // blockaddress of a block within the current function.
+  HasIndirectBr = true;
+  return false;
+}
+
+bool CallAnalyzer::visitResumeInst(ResumeInst &RI) {
+  // FIXME: It's not clear that a single instruction is an accurate model for
+  // the inline cost of a resume instruction.
+  return false;
+}
+
+bool CallAnalyzer::visitCleanupReturnInst(CleanupReturnInst &CRI) {
+  // FIXME: It's not clear that a single instruction is an accurate model for
+  // the inline cost of a cleanupret instruction.
+  return false;
+}
+
+bool CallAnalyzer::visitCatchReturnInst(CatchReturnInst &CRI) {
+  // FIXME: It's not clear that a single instruction is an accurate model for
+  // the inline cost of a catchret instruction.
+  return false;
+}
+
+bool CallAnalyzer::visitUnreachableInst(UnreachableInst &I) {
+  // FIXME: It might be reasonable to discount the cost of instructions leading
+  // to unreachable as they have the lowest possible impact on both runtime and
+  // code size.
+  return true; // No actual code is needed for unreachable.
+}
+
+bool CallAnalyzer::visitInstruction(Instruction &I) {
+  // Some instructions are free. All of the free intrinsics can also be
+  // handled by SROA, etc.
+  if (TargetTransformInfo::TCC_Free == TTI.getUserCost(&I))
+    return true;
+
+  // We found something we don't understand or can't handle. Mark any SROA-able
+  // values in the operand list as no longer viable.
+  for (User::op_iterator OI = I.op_begin(), OE = I.op_end(); OI != OE; ++OI)
+    disableSROA(*OI);
+
+  return false;
+}
+
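The case-cluster model in visitSwitchInst above reduces to a short closed form: a jump table costs about one instruction per table entry plus a fixed dispatch overhead, a switch with at most three clusters pays a compare-and-branch per cluster, and anything larger is priced as a balanced binary tree with roughly 3n/2 - 1 compare-and-branch pairs. A standalone sketch of that estimate, with InstrCost hard-coded as an illustrative value and the upper-bound clamping omitted:

#include <cstdint>
#include <iostream>

constexpr int64_t InstrCost = 5; // illustrative stand-in for InlineConstants::InstrCost

// Rough per-switch cost following the model described above.
int64_t estimateSwitchCost(unsigned NumCaseClusters, unsigned JumpTableSize) {
  if (JumpTableSize) // jump table: one entry each, plus dispatch overhead
    return static_cast<int64_t>(JumpTableSize) * InstrCost + 4 * InstrCost;
  if (NumCaseClusters <= 3) // a compare and a conditional branch per cluster
    return static_cast<int64_t>(NumCaseClusters) * 2 * InstrCost;
  // Balanced binary tree: about n + n/2 - 1 comparisons, each a compare+branch.
  int64_t ExpectedCompares = 3 * static_cast<int64_t>(NumCaseClusters) / 2 - 1;
  return ExpectedCompares * 2 * InstrCost;
}

int main() {
  std::cout << estimateSwitchCost(2, 0) << '\n';  // 20: two clusters
  std::cout << estimateSwitchCost(8, 0) << '\n';  // 110: (12 - 1) compares
  std::cout << estimateSwitchCost(0, 16) << '\n'; // 100: 16-entry jump table
}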
+/// Analyze a basic block for its contribution to the inline cost.
+///
+/// This method walks the analyzer over every instruction in the given basic
+/// block and accounts for its cost during inlining at this callsite. It
+/// aborts early if the threshold has been exceeded or an impossible-to-inline
+/// construct has been detected. It returns false if inlining is no longer
+/// viable, and true if inlining remains viable.
+bool CallAnalyzer::analyzeBlock(BasicBlock *BB,
+                                SmallPtrSetImpl<const Value *> &EphValues) {
+  for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) {
+    // FIXME: Currently, the number of instructions in a function, regardless
+    // of our ability to simplify them to constants or dead code during
+    // inlining, is used by the vector bonus heuristic. As long as that's true,
+    // we have to special-case debug intrinsics here to prevent differences in
+    // inlining due to debug symbols. Eventually, the number of unsimplified
+    // instructions shouldn't factor into the cost computation, but until then,
+    // hack around it here.
+    if (isa<DbgInfoIntrinsic>(I))
+      continue;
+
+    // Skip ephemeral values.
+    if (EphValues.count(&*I))
+      continue;
+
+    ++NumInstructions;
+    if (isa<ExtractElementInst>(I) || I->getType()->isVectorTy())
+      ++NumVectorInstructions;
+
+    // If the instruction simplified to a constant, there is no cost to this
+    // instruction. Visit the instructions using our InstVisitor to account for
+    // all of the per-instruction logic. The visit tree returns true if we
+    // consumed the instruction in any way, and false if the instruction's base
+    // cost should count against inlining.
+    if (Base::visit(&*I))
+      ++NumInstructionsSimplified;
+    else
+      Cost += InlineConstants::InstrCost;
+
+    using namespace ore;
+    // If visiting this instruction detected an uninlinable pattern, abort.
+    if (IsRecursiveCall || ExposesReturnsTwice || HasDynamicAlloca ||
+        HasIndirectBr || HasUninlineableIntrinsic || UsesVarArgs) {
+      if (ORE)
+        ORE->emit([&]() {
+          return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline",
+                                          CandidateCS.getInstruction())
+                 << NV("Callee", &F)
+                 << " has uninlinable pattern and cost is not fully computed";
+        });
+      return false;
+    }
+
+    // If the caller is a recursive function then we don't want to inline
+    // functions which allocate a lot of stack space because it would increase
+    // the caller stack usage dramatically.
+    if (IsCallerRecursive &&
+        AllocatedSize > InlineConstants::TotalAllocaSizeRecursiveCaller) {
+      if (ORE)
+        ORE->emit([&]() {
+          return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline",
+                                          CandidateCS.getInstruction())
+                 << NV("Callee", &F)
+                 << " is recursive and allocates too much stack space. Cost is "
+                    "not fully computed";
+        });
+      return false;
+    }
+
+    // Check if we've passed the maximum possible threshold so we don't spin in
+    // huge basic blocks that will never inline.
+    if (Cost >= Threshold && !ComputeFullInlineCost)
+      return false;
+  }
+
+  return true;
+}
+
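analyzeBlock above boils down to: simplified instructions are free, every other instruction adds InstrCost, and the walk stops as soon as the running cost can no longer come in under the threshold. A minimal standalone skeleton of that accounting, with a bool-per-instruction stand-in for the visitor dispatch and assumed constants (the type and member names are illustrative, not LLVM's):

#include <cstdint>
#include <vector>

// Skeleton of the per-block accounting loop: each 'true' marks an instruction
// the visitor managed to simplify (free), each 'false' pays InstrCost, and the
// walk aborts once the threshold is reached.
struct MiniBlockCost {
  int64_t Cost = 0;
  int64_t Threshold = 225; // assumed default-style threshold, for illustration
  int64_t InstrCost = 5;   // assumed value of InlineConstants::InstrCost

  bool analyzeBlock(const std::vector<bool> &SimplifiedPerInst) {
    for (bool Simplified : SimplifiedPerInst) {
      if (!Simplified)
        Cost += InstrCost;
      if (Cost >= Threshold)
        return false; // stop early; the exact total no longer matters
    }
    return true;
  }
};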
+/// Compute the base pointer and cumulative constant offsets for V.
+///
+/// This strips all constant offsets off of V, leaving it the base pointer, and
+/// accumulates the total constant offset applied in the returned constant. It
+/// returns 0 if V is not a pointer, and returns the constant '0' if there are
+/// no constant offsets applied.
+ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) {
+  if (!V->getType()->isPointerTy())
+    return nullptr;
+
+  unsigned AS = V->getType()->getPointerAddressSpace();
+  unsigned IntPtrWidth = DL.getIndexSizeInBits(AS);
+  APInt Offset = APInt::getNullValue(IntPtrWidth);
+
+  // Even though we don't look through PHI nodes, we could be called on an
+  // instruction in an unreachable block, which may be on a cycle.
+  SmallPtrSet<Value *, 4> Visited;
+  Visited.insert(V);
+  do {
+    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
+      if (!GEP->isInBounds() || !accumulateGEPOffset(*GEP, Offset))
+        return nullptr;
+      V = GEP->getPointerOperand();
+    } else if (Operator::getOpcode(V) == Instruction::BitCast) {
+      V = cast<Operator>(V)->getOperand(0);
+    } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+      if (GA->isInterposable())
+        break;
+      V = GA->getAliasee();
+    } else {
+      break;
+    }
+    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
+  } while (Visited.insert(V).second);
+
+  Type *IntPtrTy = DL.getIntPtrType(V->getContext(), AS);
+  return cast<ConstantInt>(ConstantInt::get(IntPtrTy, Offset));
+}
+
+/// Find dead blocks due to deleted CFG edges during inlining.
+///
+/// If we know the successor of the current block, \p CurrBB, has to be \p
+/// NextBB, the other successors of \p CurrBB are dead if these successors have
+/// no live incoming CFG edges.  If one block is found to be dead, we can
+/// continue growing the dead block list by checking the successors of the dead
+/// blocks to see if all their incoming edges are dead or not.
+void CallAnalyzer::findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB) {
+  auto IsEdgeDead = [&](BasicBlock *Pred, BasicBlock *Succ) {
+    // A CFG edge is dead if the predecessor is dead or the predecessor has a
+    // known successor which is not the one under examination.
+    return (DeadBlocks.count(Pred) ||
+            (KnownSuccessors[Pred] && KnownSuccessors[Pred] != Succ));
+  };
+
+  auto IsNewlyDead = [&](BasicBlock *BB) {
+    // If all the edges to a block are dead, the block is also dead.
+    return (!DeadBlocks.count(BB) &&
+            llvm::all_of(predecessors(BB),
+                         [&](BasicBlock *P) { return IsEdgeDead(P, BB); }));
+  };
+
+  for (BasicBlock *Succ : successors(CurrBB)) {
+    if (Succ == NextBB || !IsNewlyDead(Succ))
+      continue;
+    SmallVector<BasicBlock *, 4> NewDead;
+    NewDead.push_back(Succ);
+    while (!NewDead.empty()) {
+      BasicBlock *Dead = NewDead.pop_back_val();
+      if (DeadBlocks.insert(Dead))
+        // Continue growing the dead block list.
+        for (BasicBlock *S : successors(Dead))
+          if (IsNewlyDead(S))
+            NewDead.push_back(S);
+    }
+  }
+}
+
+/// Analyze a call site for potential inlining.
+///
+/// Returns true if inlining this call is viable, and false if it is not
+/// viable. It computes the cost and adjusts the threshold based on numerous
+/// factors and heuristics. If this method returns false but the computed cost
+/// is below the computed threshold, then inlining was forcibly disabled by
+/// some artifact of the routine.
+bool CallAnalyzer::analyzeCall(CallSite CS) {
+  ++NumCallsAnalyzed;
+
+  // Perform some tweaks to the cost and threshold based on the direct
+  // callsite information.
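Looking back at findDeadBlocks above: dead blocks only ever grow out of edges ruled out by a constant-folded terminator, then propagate to blocks whose incoming edges are all dead. A compact analogue over a toy integer CFG, using plain standard containers rather than the LLVM types; all names here are illustrative only, and every block is assumed to have an entry in Succs and Preds.

#include <map>
#include <set>
#include <vector>

struct ToyCFG {
  std::map<int, std::vector<int>> Succs, Preds; // edges of the toy CFG
  std::map<int, int> KnownSuccessor;            // block -> proven-taken successor
  std::set<int> DeadBlocks;

  bool isEdgeDead(int Pred, int Succ) const {
    auto KS = KnownSuccessor.find(Pred);
    return DeadBlocks.count(Pred) ||
           (KS != KnownSuccessor.end() && KS->second != Succ);
  }

  bool isNewlyDead(int BB) const {
    if (DeadBlocks.count(BB))
      return false;
    for (int P : Preds.at(BB))
      if (!isEdgeDead(P, BB))
        return false;
    return true; // every incoming edge is dead
  }

  void findDeadBlocks(int CurrBB, int NextBB) {
    for (int Succ : Succs.at(CurrBB)) {
      if (Succ == NextBB || !isNewlyDead(Succ))
        continue;
      std::vector<int> Worklist{Succ};
      while (!Worklist.empty()) {
        int Dead = Worklist.back();
        Worklist.pop_back();
        if (DeadBlocks.insert(Dead).second)
          for (int S : Succs.at(Dead))
            if (isNewlyDead(S))
              Worklist.push_back(S);
      }
    }
  }
};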
+ +  // We want to more aggressively inline vector-dense kernels, so up the +  // threshold, and we'll lower it if the % of vector instructions gets too +  // low. Note that these bonuses are some what arbitrary and evolved over time +  // by accident as much as because they are principled bonuses. +  // +  // FIXME: It would be nice to remove all such bonuses. At least it would be +  // nice to base the bonus values on something more scientific. +  assert(NumInstructions == 0); +  assert(NumVectorInstructions == 0); + +  // Update the threshold based on callsite properties +  updateThreshold(CS, F); + +  // Speculatively apply all possible bonuses to Threshold. If cost exceeds +  // this Threshold any time, and cost cannot decrease, we can stop processing +  // the rest of the function body. +  Threshold += (SingleBBBonus + VectorBonus); + +  // Give out bonuses for the callsite, as the instructions setting them up +  // will be gone after inlining. +  Cost -= getCallsiteCost(CS, DL); + +  // If this function uses the coldcc calling convention, prefer not to inline +  // it. +  if (F.getCallingConv() == CallingConv::Cold) +    Cost += InlineConstants::ColdccPenalty; + +  // Check if we're done. This can happen due to bonuses and penalties. +  if (Cost >= Threshold && !ComputeFullInlineCost) +    return false; + +  if (F.empty()) +    return true; + +  Function *Caller = CS.getInstruction()->getFunction(); +  // Check if the caller function is recursive itself. +  for (User *U : Caller->users()) { +    CallSite Site(U); +    if (!Site) +      continue; +    Instruction *I = Site.getInstruction(); +    if (I->getFunction() == Caller) { +      IsCallerRecursive = true; +      break; +    } +  } + +  // Populate our simplified values by mapping from function arguments to call +  // arguments with known important simplifications. +  CallSite::arg_iterator CAI = CS.arg_begin(); +  for (Function::arg_iterator FAI = F.arg_begin(), FAE = F.arg_end(); +       FAI != FAE; ++FAI, ++CAI) { +    assert(CAI != CS.arg_end()); +    if (Constant *C = dyn_cast<Constant>(CAI)) +      SimplifiedValues[&*FAI] = C; + +    Value *PtrArg = *CAI; +    if (ConstantInt *C = stripAndComputeInBoundsConstantOffsets(PtrArg)) { +      ConstantOffsetPtrs[&*FAI] = std::make_pair(PtrArg, C->getValue()); + +      // We can SROA any pointer arguments derived from alloca instructions. +      if (isa<AllocaInst>(PtrArg)) { +        SROAArgValues[&*FAI] = PtrArg; +        SROAArgCosts[PtrArg] = 0; +      } +    } +  } +  NumConstantArgs = SimplifiedValues.size(); +  NumConstantOffsetPtrArgs = ConstantOffsetPtrs.size(); +  NumAllocaArgs = SROAArgValues.size(); + +  // FIXME: If a caller has multiple calls to a callee, we end up recomputing +  // the ephemeral values multiple times (and they're completely determined by +  // the callee, so this is purely duplicate work). +  SmallPtrSet<const Value *, 32> EphValues; +  CodeMetrics::collectEphemeralValues(&F, &GetAssumptionCache(F), EphValues); + +  // The worklist of live basic blocks in the callee *after* inlining. We avoid +  // adding basic blocks of the callee which can be proven to be dead for this +  // particular call site in order to get more accurate cost estimates. This +  // requires a somewhat heavyweight iteration pattern: we need to walk the +  // basic blocks in a breadth-first order as we insert live successors. To +  // accomplish this, prioritizing for small iterations because we exit after +  // crossing our threshold, we use a small-size optimized SetVector. 
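The worklist loop that follows iterates by index and re-reads the size on every step precisely so that successors inserted during the walk are themselves visited. A minimal sketch of that grow-while-iterating pattern, with plain standard containers standing in for llvm::SetVector and illustrative names:

#include <set>
#include <vector>

// Blocks are assumed to be numbered 0..Succs.size()-1. Re-reading
// Worklist.size() each iteration is what makes newly discovered successors
// get visited; the set deduplicates, like a SetVector.
static void visitReachable(int Entry,
                           const std::vector<std::vector<int>> &Succs) {
  std::vector<int> Worklist{Entry};
  std::set<int> InWorklist{Entry};
  for (size_t Idx = 0; Idx != Worklist.size(); ++Idx) { // size re-read on purpose
    int BB = Worklist[Idx];
    for (int S : Succs[BB])
      if (InWorklist.insert(S).second)
        Worklist.push_back(S);
  }
}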
+  typedef SetVector<BasicBlock *, SmallVector<BasicBlock *, 16>,
+                    SmallPtrSet<BasicBlock *, 16>>
+      BBSetVector;
+  BBSetVector BBWorklist;
+  BBWorklist.insert(&F.getEntryBlock());
+  bool SingleBB = true;
+  // Note that we *must not* cache the size; this loop grows the worklist.
+  for (unsigned Idx = 0; Idx != BBWorklist.size(); ++Idx) {
+    // Bail out the moment we cross the threshold. This means we'll under-count
+    // the cost, but only when undercounting doesn't matter.
+    if (Cost >= Threshold && !ComputeFullInlineCost)
+      break;
+
+    BasicBlock *BB = BBWorklist[Idx];
+    if (BB->empty())
+      continue;
+
+    // Disallow inlining a blockaddress. A blockaddress only has defined
+    // behavior for an indirect branch in the same function, and we do not
+    // currently support inlining indirect branches. But, the inliner may not
+    // see an indirect branch that ends up being dead code at a particular call
+    // site. If the blockaddress escapes the function, e.g., via a global
+    // variable, inlining may lead to an invalid cross-function reference.
+    if (BB->hasAddressTaken())
+      return false;
+
+    // Analyze the cost of this block. If we blow through the threshold, this
+    // returns false, and we can bail out.
+    if (!analyzeBlock(BB, EphValues))
+      return false;
+
+    TerminatorInst *TI = BB->getTerminator();
+
+    // Add in the live successors by first checking whether we have a
+    // terminator that may be simplified based on the values simplified by
+    // this call.
+    if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+      if (BI->isConditional()) {
+        Value *Cond = BI->getCondition();
+        if (ConstantInt *SimpleCond =
+                dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) {
+          BasicBlock *NextBB = BI->getSuccessor(SimpleCond->isZero() ? 1 : 0);
+          BBWorklist.insert(NextBB);
+          KnownSuccessors[BB] = NextBB;
+          findDeadBlocks(BB, NextBB);
+          continue;
+        }
+      }
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
+      Value *Cond = SI->getCondition();
+      if (ConstantInt *SimpleCond =
+              dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) {
+        BasicBlock *NextBB = SI->findCaseValue(SimpleCond)->getCaseSuccessor();
+        BBWorklist.insert(NextBB);
+        KnownSuccessors[BB] = NextBB;
+        findDeadBlocks(BB, NextBB);
+        continue;
+      }
+    }
+
+    // If we're unable to select a particular successor, just count all of
+    // them.
+    for (unsigned TIdx = 0, TSize = TI->getNumSuccessors(); TIdx != TSize;
+         ++TIdx)
+      BBWorklist.insert(TI->getSuccessor(TIdx));
+
+    // If we had any successors at this point, then post-inlining is likely to
+    // have them as well. Note that we assume any basic blocks which existed
+    // due to branches or switches which folded above will also fold after
+    // inlining.
+    if (SingleBB && TI->getNumSuccessors() > 1) {
+      // Take off the bonus we applied to the threshold.
+      Threshold -= SingleBBBonus;
+      SingleBB = false;
+    }
+  }
+
+  bool OnlyOneCallAndLocalLinkage =
+      F.hasLocalLinkage() && F.hasOneUse() && &F == CS.getCalledFunction();
+  // If this is a noduplicate call, we can still inline as long as
+  // inlining this would cause the removal of the caller (so the instruction
+  // is not actually duplicated, just moved).
+  if (!OnlyOneCallAndLocalLinkage && ContainsNoDuplicateCall) +    return false; + +  // We applied the maximum possible vector bonus at the beginning. Now, +  // subtract the excess bonus, if any, from the Threshold before +  // comparing against Cost. +  if (NumVectorInstructions <= NumInstructions / 10) +    Threshold -= VectorBonus; +  else if (NumVectorInstructions <= NumInstructions / 2) +    Threshold -= VectorBonus/2; + +  return Cost < std::max(1, Threshold); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +/// Dump stats about this call's analysis. +LLVM_DUMP_METHOD void CallAnalyzer::dump() { +#define DEBUG_PRINT_STAT(x) dbgs() << "      " #x ": " << x << "\n" +  DEBUG_PRINT_STAT(NumConstantArgs); +  DEBUG_PRINT_STAT(NumConstantOffsetPtrArgs); +  DEBUG_PRINT_STAT(NumAllocaArgs); +  DEBUG_PRINT_STAT(NumConstantPtrCmps); +  DEBUG_PRINT_STAT(NumConstantPtrDiffs); +  DEBUG_PRINT_STAT(NumInstructionsSimplified); +  DEBUG_PRINT_STAT(NumInstructions); +  DEBUG_PRINT_STAT(SROACostSavings); +  DEBUG_PRINT_STAT(SROACostSavingsLost); +  DEBUG_PRINT_STAT(LoadEliminationCost); +  DEBUG_PRINT_STAT(ContainsNoDuplicateCall); +  DEBUG_PRINT_STAT(Cost); +  DEBUG_PRINT_STAT(Threshold); +#undef DEBUG_PRINT_STAT +} +#endif + +/// Test that there are no attribute conflicts between Caller and Callee +///        that prevent inlining. +static bool functionsHaveCompatibleAttributes(Function *Caller, +                                              Function *Callee, +                                              TargetTransformInfo &TTI) { +  return TTI.areInlineCompatible(Caller, Callee) && +         AttributeFuncs::areInlineCompatible(*Caller, *Callee); +} + +int llvm::getCallsiteCost(CallSite CS, const DataLayout &DL) { +  int Cost = 0; +  for (unsigned I = 0, E = CS.arg_size(); I != E; ++I) { +    if (CS.isByValArgument(I)) { +      // We approximate the number of loads and stores needed by dividing the +      // size of the byval type by the target's pointer size. +      PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType()); +      unsigned TypeSize = DL.getTypeSizeInBits(PTy->getElementType()); +      unsigned AS = PTy->getAddressSpace(); +      unsigned PointerSize = DL.getPointerSizeInBits(AS); +      // Ceiling division. +      unsigned NumStores = (TypeSize + PointerSize - 1) / PointerSize; + +      // If it generates more than 8 stores it is likely to be expanded as an +      // inline memcpy so we take that as an upper bound. Otherwise we assume +      // one load and one store per word copied. +      // FIXME: The maxStoresPerMemcpy setting from the target should be used +      // here instead of a magic number of 8, but it's not available via +      // DataLayout. +      NumStores = std::min(NumStores, 8U); + +      Cost += 2 * NumStores * InlineConstants::InstrCost; +    } else { +      // For non-byval arguments subtract off one instruction per call +      // argument. +      Cost += InlineConstants::InstrCost; +    } +  } +  // The call instruction also disappears after inlining. 
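The byval branch above is a ceiling division capped at an assumed 8-word memcpy expansion limit. As a worked example with 64-bit pointers and InstrCost taken as 5: a 40-byte struct needs ceil(320/64) = 5 word copies, charged as 2 * 5 * 5 = 50, while anything needing more than 8 words is still charged as 8. A hypothetical standalone restatement of just that arithmetic:

#include <algorithm>
#include <cstdint>

// Sketch of the byval cost estimate above; InstrCost == 5 is assumed here.
static int byValCopyCost(uint64_t TypeSizeInBits, uint64_t PointerSizeInBits) {
  const int InstrCost = 5;
  // Ceiling division: number of pointer-sized words to copy.
  uint64_t NumStores =
      (TypeSizeInBits + PointerSizeInBits - 1) / PointerSizeInBits;
  NumStores = std::min<uint64_t>(NumStores, 8); // beyond this, assume memcpy
  return 2 * static_cast<int>(NumStores) * InstrCost; // one load + one store per word
}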
+  Cost += InlineConstants::InstrCost + InlineConstants::CallPenalty; +  return Cost; +} + +InlineCost llvm::getInlineCost( +    CallSite CS, const InlineParams &Params, TargetTransformInfo &CalleeTTI, +    std::function<AssumptionCache &(Function &)> &GetAssumptionCache, +    Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI, +    ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) { +  return getInlineCost(CS, CS.getCalledFunction(), Params, CalleeTTI, +                       GetAssumptionCache, GetBFI, PSI, ORE); +} + +InlineCost llvm::getInlineCost( +    CallSite CS, Function *Callee, const InlineParams &Params, +    TargetTransformInfo &CalleeTTI, +    std::function<AssumptionCache &(Function &)> &GetAssumptionCache, +    Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI, +    ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) { + +  // Cannot inline indirect calls. +  if (!Callee) +    return llvm::InlineCost::getNever(); + +  // Never inline calls with byval arguments that does not have the alloca +  // address space. Since byval arguments can be replaced with a copy to an +  // alloca, the inlined code would need to be adjusted to handle that the +  // argument is in the alloca address space (so it is a little bit complicated +  // to solve). +  unsigned AllocaAS = Callee->getParent()->getDataLayout().getAllocaAddrSpace(); +  for (unsigned I = 0, E = CS.arg_size(); I != E; ++I) +    if (CS.isByValArgument(I)) { +      PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType()); +      if (PTy->getAddressSpace() != AllocaAS) +        return llvm::InlineCost::getNever(); +    } + +  // Calls to functions with always-inline attributes should be inlined +  // whenever possible. +  if (CS.hasFnAttr(Attribute::AlwaysInline)) { +    if (isInlineViable(*Callee)) +      return llvm::InlineCost::getAlways(); +    return llvm::InlineCost::getNever(); +  } + +  // Never inline functions with conflicting attributes (unless callee has +  // always-inline attribute). +  Function *Caller = CS.getCaller(); +  if (!functionsHaveCompatibleAttributes(Caller, Callee, CalleeTTI)) +    return llvm::InlineCost::getNever(); + +  // Don't inline this call if the caller has the optnone attribute. +  if (Caller->hasFnAttribute(Attribute::OptimizeNone)) +    return llvm::InlineCost::getNever(); + +  // Don't inline a function that treats null pointer as valid into a caller +  // that does not have this attribute. +  if (!Caller->nullPointerIsDefined() && Callee->nullPointerIsDefined()) +    return llvm::InlineCost::getNever(); + +  // Don't inline functions which can be interposed at link-time.  Don't inline +  // functions marked noinline or call sites marked noinline. +  // Note: inlining non-exact non-interposable functions is fine, since we know +  // we have *a* correct implementation of the source level function. +  if (Callee->isInterposable() || Callee->hasFnAttribute(Attribute::NoInline) || +      CS.isNoInline()) +    return llvm::InlineCost::getNever(); + +  LLVM_DEBUG(llvm::dbgs() << "      Analyzing call of " << Callee->getName() +                          << "... (caller:" << Caller->getName() << ")\n"); + +  CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, ORE, *Callee, CS, +                  Params); +  bool ShouldInline = CA.analyzeCall(CS); + +  LLVM_DEBUG(CA.dump()); + +  // Check if there was a reason to force inlining or no inlining. 
+  if (!ShouldInline && CA.getCost() < CA.getThreshold()) +    return InlineCost::getNever(); +  if (ShouldInline && CA.getCost() >= CA.getThreshold()) +    return InlineCost::getAlways(); + +  return llvm::InlineCost::get(CA.getCost(), CA.getThreshold()); +} + +bool llvm::isInlineViable(Function &F) { +  bool ReturnsTwice = F.hasFnAttribute(Attribute::ReturnsTwice); +  for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) { +    // Disallow inlining of functions which contain indirect branches or +    // blockaddresses. +    if (isa<IndirectBrInst>(BI->getTerminator()) || BI->hasAddressTaken()) +      return false; + +    for (auto &II : *BI) { +      CallSite CS(&II); +      if (!CS) +        continue; + +      // Disallow recursive calls. +      if (&F == CS.getCalledFunction()) +        return false; + +      // Disallow calls which expose returns-twice to a function not previously +      // attributed as such. +      if (!ReturnsTwice && CS.isCall() && +          cast<CallInst>(CS.getInstruction())->canReturnTwice()) +        return false; + +      if (CS.getCalledFunction()) +        switch (CS.getCalledFunction()->getIntrinsicID()) { +        default: +          break; +        // Disallow inlining of @llvm.icall.branch.funnel because current +        // backend can't separate call targets from call arguments. +        case llvm::Intrinsic::icall_branch_funnel: +        // Disallow inlining functions that call @llvm.localescape. Doing this +        // correctly would require major changes to the inliner. +        case llvm::Intrinsic::localescape: +        // Disallow inlining of functions that access VarArgs. +        case llvm::Intrinsic::vastart: +        case llvm::Intrinsic::vaend: +          return false; +        } +    } +  } + +  return true; +} + +// APIs to create InlineParams based on command line flags and/or other +// parameters. + +InlineParams llvm::getInlineParams(int Threshold) { +  InlineParams Params; + +  // This field is the threshold to use for a callee by default. This is +  // derived from one or more of: +  //  * optimization or size-optimization levels, +  //  * a value passed to createFunctionInliningPass function, or +  //  * the -inline-threshold flag. +  //  If the -inline-threshold flag is explicitly specified, that is used +  //  irrespective of anything else. +  if (InlineThreshold.getNumOccurrences() > 0) +    Params.DefaultThreshold = InlineThreshold; +  else +    Params.DefaultThreshold = Threshold; + +  // Set the HintThreshold knob from the -inlinehint-threshold. +  Params.HintThreshold = HintThreshold; + +  // Set the HotCallSiteThreshold knob from the -hot-callsite-threshold. +  Params.HotCallSiteThreshold = HotCallSiteThreshold; + +  // If the -locally-hot-callsite-threshold is explicitly specified, use it to +  // populate LocallyHotCallSiteThreshold. Later, we populate +  // Params.LocallyHotCallSiteThreshold from -locally-hot-callsite-threshold if +  // we know that optimization level is O3 (in the getInlineParams variant that +  // takes the opt and size levels). +  // FIXME: Remove this check (and make the assignment unconditional) after +  // addressing size regression issues at O2. +  if (LocallyHotCallSiteThreshold.getNumOccurrences() > 0) +    Params.LocallyHotCallSiteThreshold = LocallyHotCallSiteThreshold; + +  // Set the ColdCallSiteThreshold knob from the -inline-cold-callsite-threshold. 
+  Params.ColdCallSiteThreshold = ColdCallSiteThreshold; + +  // Set the OptMinSizeThreshold and OptSizeThreshold params only if the +  // -inlinehint-threshold commandline option is not explicitly given. If that +  // option is present, then its value applies even for callees with size and +  // minsize attributes. +  // If the -inline-threshold is not specified, set the ColdThreshold from the +  // -inlinecold-threshold even if it is not explicitly passed. If +  // -inline-threshold is specified, then -inlinecold-threshold needs to be +  // explicitly specified to set the ColdThreshold knob +  if (InlineThreshold.getNumOccurrences() == 0) { +    Params.OptMinSizeThreshold = InlineConstants::OptMinSizeThreshold; +    Params.OptSizeThreshold = InlineConstants::OptSizeThreshold; +    Params.ColdThreshold = ColdThreshold; +  } else if (ColdThreshold.getNumOccurrences() > 0) { +    Params.ColdThreshold = ColdThreshold; +  } +  return Params; +} + +InlineParams llvm::getInlineParams() { +  return getInlineParams(InlineThreshold); +} + +// Compute the default threshold for inlining based on the opt level and the +// size opt level. +static int computeThresholdFromOptLevels(unsigned OptLevel, +                                         unsigned SizeOptLevel) { +  if (OptLevel > 2) +    return InlineConstants::OptAggressiveThreshold; +  if (SizeOptLevel == 1) // -Os +    return InlineConstants::OptSizeThreshold; +  if (SizeOptLevel == 2) // -Oz +    return InlineConstants::OptMinSizeThreshold; +  return InlineThreshold; +} + +InlineParams llvm::getInlineParams(unsigned OptLevel, unsigned SizeOptLevel) { +  auto Params = +      getInlineParams(computeThresholdFromOptLevels(OptLevel, SizeOptLevel)); +  // At O3, use the value of -locally-hot-callsite-threshold option to populate +  // Params.LocallyHotCallSiteThreshold. Below O3, this flag has effect only +  // when it is specified explicitly. +  if (OptLevel > 2) +    Params.LocallyHotCallSiteThreshold = LocallyHotCallSiteThreshold; +  return Params; +} diff --git a/contrib/llvm/lib/Analysis/InstCount.cpp b/contrib/llvm/lib/Analysis/InstCount.cpp new file mode 100644 index 000000000000..95ab6ee3db5b --- /dev/null +++ b/contrib/llvm/lib/Analysis/InstCount.cpp @@ -0,0 +1,79 @@ +//===-- InstCount.cpp - Collects the count of all instructions ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This pass collects the count of all instructions and reports them +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +#define DEBUG_TYPE "instcount" + +STATISTIC(TotalInsts , "Number of instructions (of all types)"); +STATISTIC(TotalBlocks, "Number of basic blocks"); +STATISTIC(TotalFuncs , "Number of non-external functions"); + +#define HANDLE_INST(N, OPCODE, CLASS) \ +  STATISTIC(Num ## OPCODE ## Inst, "Number of " #OPCODE " insts"); + +#include "llvm/IR/Instruction.def" + +namespace { +  class InstCount : public FunctionPass, public InstVisitor<InstCount> { +    friend class InstVisitor<InstCount>; + +    void visitFunction  (Function &F) { ++TotalFuncs; } +    void visitBasicBlock(BasicBlock &BB) { ++TotalBlocks; } + +#define HANDLE_INST(N, OPCODE, CLASS) \ +    void visit##OPCODE(CLASS &) { ++Num##OPCODE##Inst; ++TotalInsts; } + +#include "llvm/IR/Instruction.def" + +    void visitInstruction(Instruction &I) { +      errs() << "Instruction Count does not know about " << I; +      llvm_unreachable(nullptr); +    } +  public: +    static char ID; // Pass identification, replacement for typeid +    InstCount() : FunctionPass(ID) { +      initializeInstCountPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override; + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +    void print(raw_ostream &O, const Module *M) const override {} + +  }; +} + +char InstCount::ID = 0; +INITIALIZE_PASS(InstCount, "instcount", +                "Counts the various types of Instructions", false, true) + +FunctionPass *llvm::createInstCountPass() { return new InstCount(); } + +// InstCount::run - This is the main Analysis entry point for a +// function. +// +bool InstCount::runOnFunction(Function &F) { +  visit(F); +  return false; +} diff --git a/contrib/llvm/lib/Analysis/InstructionSimplify.cpp b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp new file mode 100644 index 000000000000..5e72798d459a --- /dev/null +++ b/contrib/llvm/lib/Analysis/InstructionSimplify.cpp @@ -0,0 +1,5181 @@ +//===- InstructionSimplify.cpp - Fold instruction operands ----------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements routines for folding instructions into simpler forms +// that do not require creating new instructions.  This does constant folding +// ("add i32 1, 1" -> "2") but can also handle non-constant operands, either +// returning a constant ("and i32 %x, 0" -> "0") or an already existing value +// ("and i32 %x, %x" -> "%x").  All operands are assumed to have already been +// simplified: This is usually true and assuming it simplifies the logic (if +// they have not been simplified then results are correct but maybe suboptimal). 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/CmpInstAnalysis.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Support/KnownBits.h" +#include <algorithm> +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "instsimplify" + +enum { RecursionLimit = 3 }; + +STATISTIC(NumExpand,  "Number of expansions"); +STATISTIC(NumReassoc, "Number of reassociations"); + +static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, +                            unsigned); +static Value *SimplifyFPBinOp(unsigned, Value *, Value *, const FastMathFlags &, +                              const SimplifyQuery &, unsigned); +static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, +                              unsigned); +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                               const SimplifyQuery &Q, unsigned MaxRecurse); +static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyCastInst(unsigned, Value *, Type *, +                               const SimplifyQuery &, unsigned); +static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &, +                              unsigned); + +static Value *foldSelectWithBinaryOp(Value *Cond, Value *TrueVal, +                                     Value *FalseVal) { +  BinaryOperator::BinaryOps BinOpCode; +  if (auto *BO = dyn_cast<BinaryOperator>(Cond)) +    BinOpCode = BO->getOpcode(); +  else +    return nullptr; + +  CmpInst::Predicate ExpectedPred, Pred1, Pred2; +  if (BinOpCode == BinaryOperator::Or) { +    ExpectedPred = ICmpInst::ICMP_NE; +  } else if (BinOpCode == BinaryOperator::And) { +    ExpectedPred = ICmpInst::ICMP_EQ; +  } else +    return nullptr; + +  // %A = icmp eq %TV, %FV +  // %B = icmp eq %X, %Y (and one of these is a select operand) +  // %C = and %A, %B +  // %D = select %C, %TV, %FV +  // --> +  // %FV + +  // %A = icmp ne %TV, %FV +  // %B = icmp ne %X, %Y (and one of these is a select operand) +  // %C = or %A, %B +  // %D = select %C, %TV, %FV +  // --> +  // %TV +  Value *X, *Y; +  if (!match(Cond, m_c_BinOp(m_c_ICmp(Pred1, m_Specific(TrueVal), +                                      m_Specific(FalseVal)), +                             m_ICmp(Pred2, m_Value(X), m_Value(Y)))) || +      Pred1 != Pred2 || Pred1 != ExpectedPred) +    return nullptr; + +  if (X == TrueVal || X == FalseVal || Y == TrueVal || Y == FalseVal) +    return BinOpCode == BinaryOperator::Or ? 
TrueVal : FalseVal; + +  return nullptr; +} + +/// For a boolean type or a vector of boolean type, return false or a vector +/// with every element false. +static Constant *getFalse(Type *Ty) { +  return ConstantInt::getFalse(Ty); +} + +/// For a boolean type or a vector of boolean type, return true or a vector +/// with every element true. +static Constant *getTrue(Type *Ty) { +  return ConstantInt::getTrue(Ty); +} + +/// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? +static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, +                          Value *RHS) { +  CmpInst *Cmp = dyn_cast<CmpInst>(V); +  if (!Cmp) +    return false; +  CmpInst::Predicate CPred = Cmp->getPredicate(); +  Value *CLHS = Cmp->getOperand(0), *CRHS = Cmp->getOperand(1); +  if (CPred == Pred && CLHS == LHS && CRHS == RHS) +    return true; +  return CPred == CmpInst::getSwappedPredicate(Pred) && CLHS == RHS && +    CRHS == LHS; +} + +/// Does the given value dominate the specified phi node? +static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { +  Instruction *I = dyn_cast<Instruction>(V); +  if (!I) +    // Arguments and constants dominate all instructions. +    return true; + +  // If we are processing instructions (and/or basic blocks) that have not been +  // fully added to a function, the parent nodes may still be null. Simply +  // return the conservative answer in these cases. +  if (!I->getParent() || !P->getParent() || !I->getFunction()) +    return false; + +  // If we have a DominatorTree then do a precise test. +  if (DT) +    return DT->dominates(I, P); + +  // Otherwise, if the instruction is in the entry block and is not an invoke, +  // then it obviously dominates all phi nodes. +  if (I->getParent() == &I->getFunction()->getEntryBlock() && +      !isa<InvokeInst>(I)) +    return true; + +  return false; +} + +/// Simplify "A op (B op' C)" by distributing op over op', turning it into +/// "(A op B) op' (A op C)".  Here "op" is given by Opcode and "op'" is +/// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. +/// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". +/// Returns the simplified value, or null if no simplification was performed. +static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, +                          Instruction::BinaryOps OpcodeToExpand, +                          const SimplifyQuery &Q, unsigned MaxRecurse) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  // Check whether the expression has the form "(A op' B) op C". +  if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) +    if (Op0->getOpcode() == OpcodeToExpand) { +      // It does!  Try turning it into "(A op C) op' (B op C)". +      Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; +      // Do "A op C" and "B op C" both simplify? +      if (Value *L = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) +        if (Value *R = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { +          // They do! Return "L op' R" if it simplifies or is already available. +          // If "L op' R" equals "A op' B" then "L op' R" is just the LHS. +          if ((L == A && R == B) || (Instruction::isCommutative(OpcodeToExpand) +                                     && L == B && R == A)) { +            ++NumExpand; +            return LHS; +          } +          // Otherwise return "L op' R" if it simplifies. 
+          if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { +            ++NumExpand; +            return V; +          } +        } +    } + +  // Check whether the expression has the form "A op (B op' C)". +  if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) +    if (Op1->getOpcode() == OpcodeToExpand) { +      // It does!  Try turning it into "(A op B) op' (A op C)". +      Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); +      // Do "A op B" and "A op C" both simplify? +      if (Value *L = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) +        if (Value *R = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) { +          // They do! Return "L op' R" if it simplifies or is already available. +          // If "L op' R" equals "B op' C" then "L op' R" is just the RHS. +          if ((L == B && R == C) || (Instruction::isCommutative(OpcodeToExpand) +                                     && L == C && R == B)) { +            ++NumExpand; +            return RHS; +          } +          // Otherwise return "L op' R" if it simplifies. +          if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { +            ++NumExpand; +            return V; +          } +        } +    } + +  return nullptr; +} + +/// Generic simplifications for associative binary operations. +/// Returns the simpler value, or null if none was found. +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, +                                       Value *LHS, Value *RHS, +                                       const SimplifyQuery &Q, +                                       unsigned MaxRecurse) { +  assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); + +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); +  BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); + +  // Transform: "(A op B) op C" ==> "A op (B op C)" if it simplifies completely. +  if (Op0 && Op0->getOpcode() == Opcode) { +    Value *A = Op0->getOperand(0); +    Value *B = Op0->getOperand(1); +    Value *C = RHS; + +    // Does "B op C" simplify? +    if (Value *V = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { +      // It does!  Return "A op V" if it simplifies or is already available. +      // If V equals B then "A op V" is just the LHS. +      if (V == B) return LHS; +      // Otherwise return "A op V" if it simplifies. +      if (Value *W = SimplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { +        ++NumReassoc; +        return W; +      } +    } +  } + +  // Transform: "A op (B op C)" ==> "(A op B) op C" if it simplifies completely. +  if (Op1 && Op1->getOpcode() == Opcode) { +    Value *A = LHS; +    Value *B = Op1->getOperand(0); +    Value *C = Op1->getOperand(1); + +    // Does "A op B" simplify? +    if (Value *V = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { +      // It does!  Return "V op C" if it simplifies or is already available. +      // If V equals B then "V op C" is just the RHS. +      if (V == B) return RHS; +      // Otherwise return "V op C" if it simplifies. +      if (Value *W = SimplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { +        ++NumReassoc; +        return W; +      } +    } +  } + +  // The remaining transforms require commutativity as well as associativity. +  if (!Instruction::isCommutative(Opcode)) +    return nullptr; + +  // Transform: "(A op B) op C" ==> "(C op A) op B" if it simplifies completely. 
+  if (Op0 && Op0->getOpcode() == Opcode) { +    Value *A = Op0->getOperand(0); +    Value *B = Op0->getOperand(1); +    Value *C = RHS; + +    // Does "C op A" simplify? +    if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { +      // It does!  Return "V op B" if it simplifies or is already available. +      // If V equals A then "V op B" is just the LHS. +      if (V == A) return LHS; +      // Otherwise return "V op B" if it simplifies. +      if (Value *W = SimplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { +        ++NumReassoc; +        return W; +      } +    } +  } + +  // Transform: "A op (B op C)" ==> "B op (C op A)" if it simplifies completely. +  if (Op1 && Op1->getOpcode() == Opcode) { +    Value *A = LHS; +    Value *B = Op1->getOperand(0); +    Value *C = Op1->getOperand(1); + +    // Does "C op A" simplify? +    if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { +      // It does!  Return "B op V" if it simplifies or is already available. +      // If V equals C then "B op V" is just the RHS. +      if (V == C) return RHS; +      // Otherwise return "B op V" if it simplifies. +      if (Value *W = SimplifyBinOp(Opcode, B, V, Q, MaxRecurse)) { +        ++NumReassoc; +        return W; +      } +    } +  } + +  return nullptr; +} + +/// In the case of a binary operation with a select instruction as an operand, +/// try to simplify the binop by seeing whether evaluating it on both branches +/// of the select results in the same value. Returns the common value if so, +/// otherwise returns null. +static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, +                                    Value *RHS, const SimplifyQuery &Q, +                                    unsigned MaxRecurse) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  SelectInst *SI; +  if (isa<SelectInst>(LHS)) { +    SI = cast<SelectInst>(LHS); +  } else { +    assert(isa<SelectInst>(RHS) && "No select instruction operand!"); +    SI = cast<SelectInst>(RHS); +  } + +  // Evaluate the BinOp on the true and false branches of the select. +  Value *TV; +  Value *FV; +  if (SI == LHS) { +    TV = SimplifyBinOp(Opcode, SI->getTrueValue(), RHS, Q, MaxRecurse); +    FV = SimplifyBinOp(Opcode, SI->getFalseValue(), RHS, Q, MaxRecurse); +  } else { +    TV = SimplifyBinOp(Opcode, LHS, SI->getTrueValue(), Q, MaxRecurse); +    FV = SimplifyBinOp(Opcode, LHS, SI->getFalseValue(), Q, MaxRecurse); +  } + +  // If they simplified to the same value, then return the common value. +  // If they both failed to simplify then return null. +  if (TV == FV) +    return TV; + +  // If one branch simplified to undef, return the other one. +  if (TV && isa<UndefValue>(TV)) +    return FV; +  if (FV && isa<UndefValue>(FV)) +    return TV; + +  // If applying the operation did not change the true and false select values, +  // then the result of the binop is the select itself. +  if (TV == SI->getTrueValue() && FV == SI->getFalseValue()) +    return SI; + +  // If one branch simplified and the other did not, and the simplified +  // value is equal to the unsimplified one, return the simplified value. +  // For example, select (cond, X, X & Z) & Z -> X & Z. +  if ((FV && !TV) || (TV && !FV)) { +    // Check that the simplified value has the form "X op Y" where "op" is the +    // same as the original operation. +    Instruction *Simplified = dyn_cast<Instruction>(FV ? 
FV : TV); +    if (Simplified && Simplified->getOpcode() == unsigned(Opcode)) { +      // The value that didn't simplify is "UnsimplifiedLHS op UnsimplifiedRHS". +      // We already know that "op" is the same as for the simplified value.  See +      // if the operands match too.  If so, return the simplified value. +      Value *UnsimplifiedBranch = FV ? SI->getTrueValue() : SI->getFalseValue(); +      Value *UnsimplifiedLHS = SI == LHS ? UnsimplifiedBranch : LHS; +      Value *UnsimplifiedRHS = SI == LHS ? RHS : UnsimplifiedBranch; +      if (Simplified->getOperand(0) == UnsimplifiedLHS && +          Simplified->getOperand(1) == UnsimplifiedRHS) +        return Simplified; +      if (Simplified->isCommutative() && +          Simplified->getOperand(1) == UnsimplifiedLHS && +          Simplified->getOperand(0) == UnsimplifiedRHS) +        return Simplified; +    } +  } + +  return nullptr; +} + +/// In the case of a comparison with a select instruction, try to simplify the +/// comparison by seeing whether both branches of the select result in the same +/// value. Returns the common value if so, otherwise returns null. +static Value *ThreadCmpOverSelect(CmpInst::Predicate Pred, Value *LHS, +                                  Value *RHS, const SimplifyQuery &Q, +                                  unsigned MaxRecurse) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  // Make sure the select is on the LHS. +  if (!isa<SelectInst>(LHS)) { +    std::swap(LHS, RHS); +    Pred = CmpInst::getSwappedPredicate(Pred); +  } +  assert(isa<SelectInst>(LHS) && "Not comparing with a select instruction!"); +  SelectInst *SI = cast<SelectInst>(LHS); +  Value *Cond = SI->getCondition(); +  Value *TV = SI->getTrueValue(); +  Value *FV = SI->getFalseValue(); + +  // Now that we have "cmp select(Cond, TV, FV), RHS", analyse it. +  // Does "cmp TV, RHS" simplify? +  Value *TCmp = SimplifyCmpInst(Pred, TV, RHS, Q, MaxRecurse); +  if (TCmp == Cond) { +    // It not only simplified, it simplified to the select condition.  Replace +    // it with 'true'. +    TCmp = getTrue(Cond->getType()); +  } else if (!TCmp) { +    // It didn't simplify.  However if "cmp TV, RHS" is equal to the select +    // condition then we can replace it with 'true'.  Otherwise give up. +    if (!isSameCompare(Cond, Pred, TV, RHS)) +      return nullptr; +    TCmp = getTrue(Cond->getType()); +  } + +  // Does "cmp FV, RHS" simplify? +  Value *FCmp = SimplifyCmpInst(Pred, FV, RHS, Q, MaxRecurse); +  if (FCmp == Cond) { +    // It not only simplified, it simplified to the select condition.  Replace +    // it with 'false'. +    FCmp = getFalse(Cond->getType()); +  } else if (!FCmp) { +    // It didn't simplify.  However if "cmp FV, RHS" is equal to the select +    // condition then we can replace it with 'false'.  Otherwise give up. +    if (!isSameCompare(Cond, Pred, FV, RHS)) +      return nullptr; +    FCmp = getFalse(Cond->getType()); +  } + +  // If both sides simplified to the same value, then use it as the result of +  // the original comparison. +  if (TCmp == FCmp) +    return TCmp; + +  // The remaining cases only make sense if the select condition has the same +  // type as the result of the comparison, so bail out if this is not so. +  if (Cond->getType()->isVectorTy() != RHS->getType()->isVectorTy()) +    return nullptr; +  // If the false value simplified to false, then the result of the compare +  // is equal to "Cond && TCmp".  
This also catches the case when the false +  // value simplified to false and the true value to true, returning "Cond". +  if (match(FCmp, m_Zero())) +    if (Value *V = SimplifyAndInst(Cond, TCmp, Q, MaxRecurse)) +      return V; +  // If the true value simplified to true, then the result of the compare +  // is equal to "Cond || FCmp". +  if (match(TCmp, m_One())) +    if (Value *V = SimplifyOrInst(Cond, FCmp, Q, MaxRecurse)) +      return V; +  // Finally, if the false value simplified to true and the true value to +  // false, then the result of the compare is equal to "!Cond". +  if (match(FCmp, m_One()) && match(TCmp, m_Zero())) +    if (Value *V = +        SimplifyXorInst(Cond, Constant::getAllOnesValue(Cond->getType()), +                        Q, MaxRecurse)) +      return V; + +  return nullptr; +} + +/// In the case of a binary operation with an operand that is a PHI instruction, +/// try to simplify the binop by seeing whether evaluating it on the incoming +/// phi values yields the same result for every value. If so returns the common +/// value, otherwise returns null. +static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS, +                                 Value *RHS, const SimplifyQuery &Q, +                                 unsigned MaxRecurse) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  PHINode *PI; +  if (isa<PHINode>(LHS)) { +    PI = cast<PHINode>(LHS); +    // Bail out if RHS and the phi may be mutually interdependent due to a loop. +    if (!valueDominatesPHI(RHS, PI, Q.DT)) +      return nullptr; +  } else { +    assert(isa<PHINode>(RHS) && "No PHI instruction operand!"); +    PI = cast<PHINode>(RHS); +    // Bail out if LHS and the phi may be mutually interdependent due to a loop. +    if (!valueDominatesPHI(LHS, PI, Q.DT)) +      return nullptr; +  } + +  // Evaluate the BinOp on the incoming phi values. +  Value *CommonValue = nullptr; +  for (Value *Incoming : PI->incoming_values()) { +    // If the incoming value is the phi node itself, it can safely be skipped. +    if (Incoming == PI) continue; +    Value *V = PI == LHS ? +      SimplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) : +      SimplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); +    // If the operation failed to simplify, or simplified to a different value +    // to previously, then give up. +    if (!V || (CommonValue && V != CommonValue)) +      return nullptr; +    CommonValue = V; +  } + +  return CommonValue; +} + +/// In the case of a comparison with a PHI instruction, try to simplify the +/// comparison by seeing whether comparing with all of the incoming phi values +/// yields the same result every time. If so returns the common result, +/// otherwise returns null. +static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, +                               const SimplifyQuery &Q, unsigned MaxRecurse) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  // Make sure the phi is on the LHS. +  if (!isa<PHINode>(LHS)) { +    std::swap(LHS, RHS); +    Pred = CmpInst::getSwappedPredicate(Pred); +  } +  assert(isa<PHINode>(LHS) && "Not comparing with a phi instruction!"); +  PHINode *PI = cast<PHINode>(LHS); + +  // Bail out if RHS and the phi may be mutually interdependent due to a loop. 
+  if (!valueDominatesPHI(RHS, PI, Q.DT)) +    return nullptr; + +  // Evaluate the BinOp on the incoming phi values. +  Value *CommonValue = nullptr; +  for (Value *Incoming : PI->incoming_values()) { +    // If the incoming value is the phi node itself, it can safely be skipped. +    if (Incoming == PI) continue; +    Value *V = SimplifyCmpInst(Pred, Incoming, RHS, Q, MaxRecurse); +    // If the operation failed to simplify, or simplified to a different value +    // to previously, then give up. +    if (!V || (CommonValue && V != CommonValue)) +      return nullptr; +    CommonValue = V; +  } + +  return CommonValue; +} + +static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode, +                                       Value *&Op0, Value *&Op1, +                                       const SimplifyQuery &Q) { +  if (auto *CLHS = dyn_cast<Constant>(Op0)) { +    if (auto *CRHS = dyn_cast<Constant>(Op1)) +      return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); + +    // Canonicalize the constant to the RHS if this is a commutative operation. +    if (Instruction::isCommutative(Opcode)) +      std::swap(Op0, Op1); +  } +  return nullptr; +} + +/// Given operands for an Add, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, +                              const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) +    return C; + +  // X + undef -> undef +  if (match(Op1, m_Undef())) +    return Op1; + +  // X + 0 -> X +  if (match(Op1, m_Zero())) +    return Op0; + +  // If two operands are negative, return 0. +  if (isKnownNegation(Op0, Op1)) +    return Constant::getNullValue(Op0->getType()); + +  // X + (Y - X) -> Y +  // (Y - X) + X -> Y +  // Eg: X + -X -> 0 +  Value *Y = nullptr; +  if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) || +      match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)))) +    return Y; + +  // X + ~X -> -1   since   ~X = -X-1 +  Type *Ty = Op0->getType(); +  if (match(Op0, m_Not(m_Specific(Op1))) || +      match(Op1, m_Not(m_Specific(Op0)))) +    return Constant::getAllOnesValue(Ty); + +  // add nsw/nuw (xor Y, signmask), signmask --> Y +  // The no-wrapping add guarantees that the top bit will be set by the add. +  // Therefore, the xor must be clearing the already set sign bit of Y. +  if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) && +      match(Op0, m_Xor(m_Value(Y), m_SignMask()))) +    return Y; + +  // add nuw %x, -1  ->  -1, because %x can only be 0. +  if (IsNUW && match(Op1, m_AllOnes())) +    return Op1; // Which is -1. + +  /// i1 add -> xor. +  if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) +    if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) +      return V; + +  // Try some generic simplifications for associative operations. +  if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q, +                                          MaxRecurse)) +    return V; + +  // Threading Add over selects and phi nodes is pointless, so don't bother. +  // Threading over the select in "A + select(cond, B, C)" means evaluating +  // "A+B" and "A+C" and seeing if they are equal; but they are equal if and +  // only if B and C are equal.  If B and C are equal then (since we assume +  // that operands have already been simplified) "select(cond, B, C)" should +  // have been simplified to the common value of B and C already.  
Analysing +  // "A+B" and "A+C" thus gains nothing, but costs compile time.  Similarly +  // for threading over phi nodes. + +  return nullptr; +} + +Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, +                             const SimplifyQuery &Query) { +  return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit); +} + +/// Compute the base pointer and cumulative constant offsets for V. +/// +/// This strips all constant offsets off of V, leaving it the base pointer, and +/// accumulates the total constant offset applied in the returned constant. It +/// returns 0 if V is not a pointer, and returns the constant '0' if there are +/// no constant offsets applied. +/// +/// This is very similar to GetPointerBaseWithConstantOffset except it doesn't +/// follow non-inbounds geps. This allows it to remain usable for icmp ult/etc. +/// folding. +static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, +                                                bool AllowNonInbounds = false) { +  assert(V->getType()->isPtrOrPtrVectorTy()); + +  Type *IntPtrTy = DL.getIntPtrType(V->getType())->getScalarType(); +  APInt Offset = APInt::getNullValue(IntPtrTy->getIntegerBitWidth()); + +  // Even though we don't look through PHI nodes, we could be called on an +  // instruction in an unreachable block, which may be on a cycle. +  SmallPtrSet<Value *, 4> Visited; +  Visited.insert(V); +  do { +    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { +      if ((!AllowNonInbounds && !GEP->isInBounds()) || +          !GEP->accumulateConstantOffset(DL, Offset)) +        break; +      V = GEP->getPointerOperand(); +    } else if (Operator::getOpcode(V) == Instruction::BitCast) { +      V = cast<Operator>(V)->getOperand(0); +    } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { +      if (GA->isInterposable()) +        break; +      V = GA->getAliasee(); +    } else { +      if (auto CS = CallSite(V)) +        if (Value *RV = CS.getReturnedArgOperand()) { +          V = RV; +          continue; +        } +      break; +    } +    assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!"); +  } while (Visited.insert(V).second); + +  Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); +  if (V->getType()->isVectorTy()) +    return ConstantVector::getSplat(V->getType()->getVectorNumElements(), +                                    OffsetIntPtr); +  return OffsetIntPtr; +} + +/// Compute the constant difference between two pointer values. +/// If the difference is not a constant, returns zero. +static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, +                                          Value *RHS) { +  Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS); +  Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS); + +  // If LHS and RHS are not related via constant offsets to the same base +  // value, there is nothing we can do here. +  if (LHS != RHS) +    return nullptr; + +  // Otherwise, the difference of LHS - RHS can be computed as: +  //    LHS - RHS +  //  = (LHSOffset + Base) - (RHSOffset + Base) +  //  = LHSOffset - RHSOffset +  return ConstantExpr::getSub(LHSOffset, RHSOffset); +} + +/// Given operands for a Sub, see if we can fold the result. +/// If not, this returns null. 
+static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +                              const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q)) +    return C; + +  // X - undef -> undef +  // undef - X -> undef +  if (match(Op0, m_Undef()) || match(Op1, m_Undef())) +    return UndefValue::get(Op0->getType()); + +  // X - 0 -> X +  if (match(Op1, m_Zero())) +    return Op0; + +  // X - X -> 0 +  if (Op0 == Op1) +    return Constant::getNullValue(Op0->getType()); + +  // Is this a negation? +  if (match(Op0, m_Zero())) { +    // 0 - X -> 0 if the sub is NUW. +    if (isNUW) +      return Constant::getNullValue(Op0->getType()); + +    KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    if (Known.Zero.isMaxSignedValue()) { +      // Op1 is either 0 or the minimum signed value. If the sub is NSW, then +      // Op1 must be 0 because negating the minimum signed value is undefined. +      if (isNSW) +        return Constant::getNullValue(Op0->getType()); + +      // 0 - X -> X if X is 0 or the minimum signed value. +      return Op1; +    } +  } + +  // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. +  // For example, (X + Y) - Y -> X; (Y + X) - Y -> X +  Value *X = nullptr, *Y = nullptr, *Z = Op1; +  if (MaxRecurse && match(Op0, m_Add(m_Value(X), m_Value(Y)))) { // (X + Y) - Z +    // See if "V === Y - Z" simplifies. +    if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse-1)) +      // It does!  Now see if "X + V" simplifies. +      if (Value *W = SimplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse-1)) { +        // It does, we successfully reassociated! +        ++NumReassoc; +        return W; +      } +    // See if "V === X - Z" simplifies. +    if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) +      // It does!  Now see if "Y + V" simplifies. +      if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse-1)) { +        // It does, we successfully reassociated! +        ++NumReassoc; +        return W; +      } +  } + +  // X - (Y + Z) -> (X - Y) - Z or (X - Z) - Y if everything simplifies. +  // For example, X - (X + 1) -> -1 +  X = Op0; +  if (MaxRecurse && match(Op1, m_Add(m_Value(Y), m_Value(Z)))) { // X - (Y + Z) +    // See if "V === X - Y" simplifies. +    if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) +      // It does!  Now see if "V - Z" simplifies. +      if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse-1)) { +        // It does, we successfully reassociated! +        ++NumReassoc; +        return W; +      } +    // See if "V === X - Z" simplifies. +    if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) +      // It does!  Now see if "V - Y" simplifies. +      if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse-1)) { +        // It does, we successfully reassociated! +        ++NumReassoc; +        return W; +      } +  } + +  // Z - (X - Y) -> (Z - X) + Y if everything simplifies. +  // For example, X - (X - Y) -> Y. +  Z = Op0; +  if (MaxRecurse && match(Op1, m_Sub(m_Value(X), m_Value(Y)))) // Z - (X - Y) +    // See if "V === Z - X" simplifies. +    if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse-1)) +      // It does!  Now see if "V + Y" simplifies. +      if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse-1)) { +        // It does, we successfully reassociated! 
+        ++NumReassoc; +        return W; +      } + +  // trunc(X) - trunc(Y) -> trunc(X - Y) if everything simplifies. +  if (MaxRecurse && match(Op0, m_Trunc(m_Value(X))) && +      match(Op1, m_Trunc(m_Value(Y)))) +    if (X->getType() == Y->getType()) +      // See if "V === X - Y" simplifies. +      if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) +        // It does!  Now see if "trunc V" simplifies. +        if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(), +                                        Q, MaxRecurse - 1)) +          // It does, return the simplified "trunc V". +          return W; + +  // Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...). +  if (match(Op0, m_PtrToInt(m_Value(X))) && +      match(Op1, m_PtrToInt(m_Value(Y)))) +    if (Constant *Result = computePointerDifference(Q.DL, X, Y)) +      return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); + +  // i1 sub -> xor. +  if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) +    if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) +      return V; + +  // Threading Sub over selects and phi nodes is pointless, so don't bother. +  // Threading over the select in "A - select(cond, B, C)" means evaluating +  // "A-B" and "A-C" and seeing if they are equal; but they are equal if and +  // only if B and C are equal.  If B and C are equal then (since we assume +  // that operands have already been simplified) "select(cond, B, C)" should +  // have been simplified to the common value of B and C already.  Analysing +  // "A-B" and "A-C" thus gains nothing, but costs compile time.  Similarly +  // for threading over phi nodes. + +  return nullptr; +} + +Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +                             const SimplifyQuery &Q) { +  return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit); +} + +/// Given operands for a Mul, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                              unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q)) +    return C; + +  // X * undef -> 0 +  // X * 0 -> 0 +  if (match(Op1, m_CombineOr(m_Undef(), m_Zero()))) +    return Constant::getNullValue(Op0->getType()); + +  // X * 1 -> X +  if (match(Op1, m_One())) +    return Op0; + +  // (X / Y) * Y -> X if the division is exact. +  Value *X = nullptr; +  if (match(Op0, m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) || // (X / Y) * Y +      match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))   // Y * (X / Y) +    return X; + +  // i1 mul -> and. +  if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) +    if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) +      return V; + +  // Try some generic simplifications for associative operations. +  if (Value *V = SimplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q, +                                          MaxRecurse)) +    return V; + +  // Mul distributes over Add. Try some generic simplifications based on this. +  if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add, +                             Q, MaxRecurse)) +    return V; + +  // If the operation is with the result of a select instruction, check whether +  // operating on either branch of the select always yields the same value. 
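+  // Editorial illustration (not from the upstream sources) of the exact-division
+  // fold earlier in this function, in hypothetical IR:
+  //   %q = udiv exact i32 %x, %y
+  //   %m = mul i32 %q, %y       ; simplifies to %x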
+  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +    if (Value *V = ThreadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q, +                                         MaxRecurse)) +      return V; + +  // If the operation is with the result of a phi instruction, check whether +  // operating on all incoming values of the phi always yields the same value. +  if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) +    if (Value *V = ThreadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q, +                                      MaxRecurse)) +      return V; + +  return nullptr; +} + +Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifyMulInst(Op0, Op1, Q, RecursionLimit); +} + +/// Check for common or similar folds of integer division or integer remainder. +/// This applies to all 4 opcodes (sdiv/udiv/srem/urem). +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { +  Type *Ty = Op0->getType(); + +  // X / undef -> undef +  // X % undef -> undef +  if (match(Op1, m_Undef())) +    return Op1; + +  // X / 0 -> undef +  // X % 0 -> undef +  // We don't need to preserve faults! +  if (match(Op1, m_Zero())) +    return UndefValue::get(Ty); + +  // If any element of a constant divisor vector is zero or undef, the whole op +  // is undef. +  auto *Op1C = dyn_cast<Constant>(Op1); +  if (Op1C && Ty->isVectorTy()) { +    unsigned NumElts = Ty->getVectorNumElements(); +    for (unsigned i = 0; i != NumElts; ++i) { +      Constant *Elt = Op1C->getAggregateElement(i); +      if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt))) +        return UndefValue::get(Ty); +    } +  } + +  // undef / X -> 0 +  // undef % X -> 0 +  if (match(Op0, m_Undef())) +    return Constant::getNullValue(Ty); + +  // 0 / X -> 0 +  // 0 % X -> 0 +  if (match(Op0, m_Zero())) +    return Constant::getNullValue(Op0->getType()); + +  // X / X -> 1 +  // X % X -> 0 +  if (Op0 == Op1) +    return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); + +  // X / 1 -> X +  // X % 1 -> 0 +  // If this is a boolean op (single-bit element type), we can't have +  // division-by-zero or remainder-by-zero, so assume the divisor is 1. +  // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1. +  Value *X; +  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) || +      (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) +    return IsDiv ? Op0 : Constant::getNullValue(Ty); + +  return nullptr; +} + +/// Given a predicate and two operands, return true if the comparison is true. +/// This is a helper for div/rem simplification where we return some other value +/// when we can prove a relationship between the operands. +static bool isICmpTrue(ICmpInst::Predicate Pred, Value *LHS, Value *RHS, +                       const SimplifyQuery &Q, unsigned MaxRecurse) { +  Value *V = SimplifyICmpInst(Pred, LHS, RHS, Q, MaxRecurse); +  Constant *C = dyn_cast_or_null<Constant>(V); +  return (C && C->isAllOnesValue()); +} + +/// Return true if we can simplify X / Y to 0. Remainder can adapt that answer +/// to simplify X % Y to X. +static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, +                      unsigned MaxRecurse, bool IsSigned) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return false; + +  if (IsSigned) { +    // |X| / |Y| --> 0 +    // +    // We require that 1 operand is a simple constant. 
That could be extended to +    // 2 variables if we computed the sign bit for each. +    // +    // Make sure that a constant is not the minimum signed value because taking +    // the abs() of that is undefined. +    Type *Ty = X->getType(); +    const APInt *C; +    if (match(X, m_APInt(C)) && !C->isMinSignedValue()) { +      // Is the variable divisor magnitude always greater than the constant +      // dividend magnitude? +      // |Y| > |C| --> Y < -abs(C) or Y > abs(C) +      Constant *PosDividendC = ConstantInt::get(Ty, C->abs()); +      Constant *NegDividendC = ConstantInt::get(Ty, -C->abs()); +      if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) || +          isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse)) +        return true; +    } +    if (match(Y, m_APInt(C))) { +      // Special-case: we can't take the abs() of a minimum signed value. If +      // that's the divisor, then all we have to do is prove that the dividend +      // is also not the minimum signed value. +      if (C->isMinSignedValue()) +        return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse); + +      // Is the variable dividend magnitude always less than the constant +      // divisor magnitude? +      // |X| < |C| --> X > -abs(C) and X < abs(C) +      Constant *PosDivisorC = ConstantInt::get(Ty, C->abs()); +      Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs()); +      if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) && +          isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse)) +        return true; +    } +    return false; +  } + +  // IsSigned == false. +  // Is the dividend unsigned less than the divisor? +  return isICmpTrue(ICmpInst::ICMP_ULT, X, Y, Q, MaxRecurse); +} + +/// These are simplifications common to SDiv and UDiv. +static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, +                          const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) +    return C; + +  if (Value *V = simplifyDivRem(Op0, Op1, true)) +    return V; + +  bool IsSigned = Opcode == Instruction::SDiv; + +  // (X * Y) / Y -> X if the multiplication does not overflow. +  Value *X; +  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { +    auto *Mul = cast<OverflowingBinaryOperator>(Op0); +    // If the Mul does not overflow, then we are good to go. +    if ((IsSigned && Mul->hasNoSignedWrap()) || +        (!IsSigned && Mul->hasNoUnsignedWrap())) +      return X; +    // If X has the form X = A / Y, then X * Y cannot overflow. +    if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) || +        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) +      return X; +  } + +  // (X rem Y) / Y -> 0 +  if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || +      (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) +    return Constant::getNullValue(Op0->getType()); + +  // (X /u C1) /u C2 -> 0 if C1 * C2 overflow +  ConstantInt *C1, *C2; +  if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && +      match(Op1, m_ConstantInt(C2))) { +    bool Overflow; +    (void)C1->getValue().umul_ov(C2->getValue(), Overflow); +    if (Overflow) +      return Constant::getNullValue(Op0->getType()); +  } + +  // If the operation is with the result of a select instruction, check whether +  // operating on either branch of the select always yields the same value. 
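+  // Editorial illustration (not from the upstream sources) of the select
+  // threading described above, in hypothetical IR; both arms divide to 0, so
+  // the whole udiv simplifies to 0:
+  //   %s = select i1 %c, i32 7, i32 5
+  //   %q = udiv i32 %s, 8       ; simplifies to 0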
+  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +    if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) +      return V; + +  // If the operation is with the result of a phi instruction, check whether +  // operating on all incoming values of the phi always yields the same value. +  if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) +    if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) +      return V; + +  if (isDivZero(Op0, Op1, Q, MaxRecurse, IsSigned)) +    return Constant::getNullValue(Op0->getType()); + +  return nullptr; +} + +/// These are simplifications common to SRem and URem. +static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, +                          const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) +    return C; + +  if (Value *V = simplifyDivRem(Op0, Op1, false)) +    return V; + +  // (X % Y) % Y -> X % Y +  if ((Opcode == Instruction::SRem && +       match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || +      (Opcode == Instruction::URem && +       match(Op0, m_URem(m_Value(), m_Specific(Op1))))) +    return Op0; + +  // (X << Y) % X -> 0 +  if ((Opcode == Instruction::SRem && +       match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) || +      (Opcode == Instruction::URem && +       match(Op0, m_NUWShl(m_Specific(Op1), m_Value())))) +    return Constant::getNullValue(Op0->getType()); + +  // If the operation is with the result of a select instruction, check whether +  // operating on either branch of the select always yields the same value. +  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +    if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) +      return V; + +  // If the operation is with the result of a phi instruction, check whether +  // operating on all incoming values of the phi always yields the same value. +  if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) +    if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) +      return V; + +  // If X / Y == 0, then X % Y == X. +  if (isDivZero(Op0, Op1, Q, MaxRecurse, Opcode == Instruction::SRem)) +    return Op0; + +  return nullptr; +} + +/// Given operands for an SDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                               unsigned MaxRecurse) { +  // If two operands are negated and no signed overflow, return -1. +  if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true)) +    return Constant::getAllOnesValue(Op0->getType()); + +  return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifySDivInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for a UDiv, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                               unsigned MaxRecurse) { +  return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifyUDivInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for an SRem, see if we can fold the result. +/// If not, this returns null. 
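+// Editorial illustration (not part of the upstream sources): two of the SRem
+// folds handled below, in hypothetical IR:
+//   %d = sext i1 %b to i32
+//   %r = srem i32 %x, %d        ; simplifies to 0 (divisor can only be 0 or -1)
+//   %n = sub i32 0, %x
+//   %s = srem i32 %x, %n        ; simplifies to 0 (operands are negations)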
+static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                               unsigned MaxRecurse) { +  // If the divisor is 0, the result is undefined, so assume the divisor is -1. +  // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0 +  Value *X; +  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) +    return ConstantInt::getNullValue(Op0->getType()); + +  // If the two operands are negated, return 0. +  if (isKnownNegation(Op0, Op1)) +    return ConstantInt::getNullValue(Op0->getType()); + +  return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifySRemInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for a URem, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                               unsigned MaxRecurse) { +  return simplifyRem(Instruction::URem, Op0, Op1, Q, MaxRecurse); +} + +Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit); +} + +/// Returns true if a shift by \c Amount always yields undef. +static bool isUndefShift(Value *Amount) { +  Constant *C = dyn_cast<Constant>(Amount); +  if (!C) +    return false; + +  // X shift by undef -> undef because it may shift by the bitwidth. +  if (isa<UndefValue>(C)) +    return true; + +  // Shifting by the bitwidth or more is undefined. +  if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) +    if (CI->getValue().getLimitedValue() >= +        CI->getType()->getScalarSizeInBits()) +      return true; + +  // If all lanes of a vector shift are undefined the whole shift is. +  if (isa<ConstantVector>(C) || isa<ConstantDataVector>(C)) { +    for (unsigned I = 0, E = C->getType()->getVectorNumElements(); I != E; ++I) +      if (!isUndefShift(C->getAggregateElement(I))) +        return false; +    return true; +  } + +  return false; +} + +/// Given operands for an Shl, LShr or AShr, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, +                            Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) +    return C; + +  // 0 shift by X -> 0 +  if (match(Op0, m_Zero())) +    return Constant::getNullValue(Op0->getType()); + +  // X shift by 0 -> X +  // Shift-by-sign-extended bool must be shift-by-0 because shift-by-all-ones +  // would be poison. +  Value *X; +  if (match(Op1, m_Zero()) || +      (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))) +    return Op0; + +  // Fold undefined shifts. +  if (isUndefShift(Op1)) +    return UndefValue::get(Op0->getType()); + +  // If the operation is with the result of a select instruction, check whether +  // operating on either branch of the select always yields the same value. +  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +    if (Value *V = ThreadBinOpOverSelect(Opcode, Op0, Op1, Q, MaxRecurse)) +      return V; + +  // If the operation is with the result of a phi instruction, check whether +  // operating on all incoming values of the phi always yields the same value. 
+  if (isa<PHINode>(Op0) || isa<PHINode>(Op1))
+    if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse))
+      return V;
+
+  // If any bits in the shift amount make that value greater than or equal to
+  // the number of bits in the type, the shift is undefined.
+  KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+  if (Known.One.getLimitedValue() >= Known.getBitWidth())
+    return UndefValue::get(Op0->getType());
+
+  // If all valid bits in the shift amount are known zero, the first operand is
+  // unchanged.
+  unsigned NumValidShiftBits = Log2_32_Ceil(Known.getBitWidth());
+  if (Known.countMinTrailingZeros() >= NumValidShiftBits)
+    return Op0;
+
+  return nullptr;
+}
+
+/// Given operands for an LShr or AShr, see if we can
+/// fold the result.  If not, this returns null.
+static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
+                                 Value *Op1, bool isExact, const SimplifyQuery &Q,
+                                 unsigned MaxRecurse) {
+  if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
+    return V;
+
+  // X >> X -> 0
+  if (Op0 == Op1)
+    return Constant::getNullValue(Op0->getType());
+
+  // undef >> X -> 0
+  // undef >> X -> undef (if it's exact)
+  if (match(Op0, m_Undef()))
+    return isExact ? Op0 : Constant::getNullValue(Op0->getType());
+
+  // The low bit cannot be shifted out of an exact shift if it is set.
+  if (isExact) {
+    KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+    if (Op0Known.One[0])
+      return Op0;
+  }
+
+  return nullptr;
+}
+
+/// Given operands for an Shl, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+                              const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (Value *V = SimplifyShift(Instruction::Shl, Op0, Op1, Q, MaxRecurse))
+    return V;
+
+  // undef << X -> 0
+  // undef << X -> undef (if it's NSW/NUW)
+  if (match(Op0, m_Undef()))
+    return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType());
+
+  // (X >> A) << A -> X
+  Value *X;
+  if (match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
+    return X;
+
+  // shl nuw i8 C, %x  ->  C  iff C has sign bit set.
+  if (isNUW && match(Op0, m_Negative()))
+    return Op0;
+  // NOTE: could use computeKnownBits() / LazyValueInfo,
+  // but the cost-benefit analysis suggests it isn't worth it.
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+                             const SimplifyQuery &Q) {
+  return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit);
+}
+
+/// Given operands for an LShr, see if we can fold the result.
+/// If not, this returns null.
+static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
+                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q,
+                                    MaxRecurse))
+      return V;
+
+  // (X << A) >> A -> X
+  Value *X;
+  if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
+    return X;
+
+  // ((X << A) | Y) >> A -> X  if effective width of Y is not larger than A.
+  // We can return X as we do in the above case since OR alters no bits in X.
+  // SimplifyDemandedBits in InstCombine can do more general optimization for
+  // bit manipulation. 
This pattern aims to provide opportunities for other +  // optimizers by supporting a simple but common case in InstSimplify. +  Value *Y; +  const APInt *ShRAmt, *ShLAmt; +  if (match(Op1, m_APInt(ShRAmt)) && +      match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) && +      *ShRAmt == *ShLAmt) { +    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    const unsigned Width = Op0->getType()->getScalarSizeInBits(); +    const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); +    if (ShRAmt->uge(EffWidthY)) +      return X; +  } + +  return nullptr; +} + +Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, +                              const SimplifyQuery &Q) { +  return ::SimplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit); +} + +/// Given operands for an AShr, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, +                               const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, +                                    MaxRecurse)) +    return V; + +  // all ones >>a X -> -1 +  // Do not return Op0 because it may contain undef elements if it's a vector. +  if (match(Op0, m_AllOnes())) +    return Constant::getAllOnesValue(Op0->getType()); + +  // (X << A) >> A -> X +  Value *X; +  if (match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1)))) +    return X; + +  // Arithmetic shifting an all-sign-bit value is a no-op. +  unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +  if (NumSignBits == Op0->getType()->getScalarSizeInBits()) +    return Op0; + +  return nullptr; +} + +Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, +                              const SimplifyQuery &Q) { +  return ::SimplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit); +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, +                                         ICmpInst *UnsignedICmp, bool IsAnd) { +  Value *X, *Y; + +  ICmpInst::Predicate EqPred; +  if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) || +      !ICmpInst::isEquality(EqPred)) +    return nullptr; + +  ICmpInst::Predicate UnsignedPred; +  if (match(UnsignedICmp, m_ICmp(UnsignedPred, m_Value(X), m_Specific(Y))) && +      ICmpInst::isUnsigned(UnsignedPred)) +    ; +  else if (match(UnsignedICmp, +                 m_ICmp(UnsignedPred, m_Specific(Y), m_Value(X))) && +           ICmpInst::isUnsigned(UnsignedPred)) +    UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); +  else +    return nullptr; + +  // X < Y && Y != 0  -->  X < Y +  // X < Y || Y != 0  -->  Y != 0 +  if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE) +    return IsAnd ? UnsignedICmp : ZeroICmp; + +  // X >= Y || Y != 0  -->  true +  // X >= Y || Y == 0  -->  X >= Y +  if (UnsignedPred == ICmpInst::ICMP_UGE && !IsAnd) { +    if (EqPred == ICmpInst::ICMP_NE) +      return getTrue(UnsignedICmp->getType()); +    return UnsignedICmp; +  } + +  // X < Y && Y == 0  -->  false +  if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_EQ && +      IsAnd) +    return getFalse(UnsignedICmp->getType()); + +  return nullptr; +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. 
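+// Editorial illustration (not part of the upstream sources): the same-operand
+// cases handled below, in hypothetical IR:
+//   (icmp ult i32 %x, %y) & (icmp ule i32 %x, %y)  -->  icmp ult i32 %x, %y
+//   (icmp slt i32 %x, %y) & (icmp sgt i32 %x, %y)  -->  false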
+static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { +  ICmpInst::Predicate Pred0, Pred1; +  Value *A ,*B; +  if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || +      !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) +    return nullptr; + +  // We have (icmp Pred0, A, B) & (icmp Pred1, A, B). +  // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we +  // can eliminate Op1 from this 'and'. +  if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) +    return Op0; + +  // Check for any combination of predicates that are guaranteed to be disjoint. +  if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || +      (Pred0 == ICmpInst::ICMP_EQ && ICmpInst::isFalseWhenEqual(Pred1)) || +      (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT) || +      (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT)) +    return getFalse(Op0->getType()); + +  return nullptr; +} + +/// Commuted variants are assumed to be handled by calling this function again +/// with the parameters swapped. +static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { +  ICmpInst::Predicate Pred0, Pred1; +  Value *A ,*B; +  if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || +      !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) +    return nullptr; + +  // We have (icmp Pred0, A, B) | (icmp Pred1, A, B). +  // If Op1 is always implied true by Op0, then Op0 is a subset of Op1, and we +  // can eliminate Op0 from this 'or'. +  if (ICmpInst::isImpliedTrueByMatchingCmp(Pred0, Pred1)) +    return Op1; + +  // Check for any combination of predicates that cover the entire range of +  // possibilities. +  if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) || +      (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) || +      (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) || +      (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE)) +    return getTrue(Op0->getType()); + +  return nullptr; +} + +/// Test if a pair of compares with a shared operand and 2 constants has an +/// empty set intersection, full set union, or if one compare is a superset of +/// the other. +static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1, +                                                bool IsAnd) { +  // Look for this pattern: {and/or} (icmp X, C0), (icmp X, C1)). +  if (Cmp0->getOperand(0) != Cmp1->getOperand(0)) +    return nullptr; + +  const APInt *C0, *C1; +  if (!match(Cmp0->getOperand(1), m_APInt(C0)) || +      !match(Cmp1->getOperand(1), m_APInt(C1))) +    return nullptr; + +  auto Range0 = ConstantRange::makeExactICmpRegion(Cmp0->getPredicate(), *C0); +  auto Range1 = ConstantRange::makeExactICmpRegion(Cmp1->getPredicate(), *C1); + +  // For and-of-compares, check if the intersection is empty: +  // (icmp X, C0) && (icmp X, C1) --> empty set --> false +  if (IsAnd && Range0.intersectWith(Range1).isEmptySet()) +    return getFalse(Cmp0->getType()); + +  // For or-of-compares, check if the union is full: +  // (icmp X, C0) || (icmp X, C1) --> full set --> true +  if (!IsAnd && Range0.unionWith(Range1).isFullSet()) +    return getTrue(Cmp0->getType()); + +  // Is one range a superset of the other? 
+  // If this is and-of-compares, take the smaller set: +  // (icmp sgt X, 4) && (icmp sgt X, 42) --> icmp sgt X, 42 +  // If this is or-of-compares, take the larger set: +  // (icmp sgt X, 4) || (icmp sgt X, 42) --> icmp sgt X, 4 +  if (Range0.contains(Range1)) +    return IsAnd ? Cmp1 : Cmp0; +  if (Range1.contains(Range0)) +    return IsAnd ? Cmp0 : Cmp1; + +  return nullptr; +} + +static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1, +                                           bool IsAnd) { +  ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate(); +  if (!match(Cmp0->getOperand(1), m_Zero()) || +      !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1) +    return nullptr; + +  if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ)) +    return nullptr; + +  // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)". +  Value *X = Cmp0->getOperand(0); +  Value *Y = Cmp1->getOperand(0); + +  // If one of the compares is a masked version of a (not) null check, then +  // that compare implies the other, so we eliminate the other. Optionally, look +  // through a pointer-to-int cast to match a null check of a pointer type. + +  // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0 +  // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0 +  // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0 +  // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0 +  if (match(Y, m_c_And(m_Specific(X), m_Value())) || +      match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value()))) +    return Cmp1; + +  // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0 +  // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0 +  // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0 +  // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? 
& [ptrtoint] Y) != 0 +  if (match(X, m_c_And(m_Specific(Y), m_Value())) || +      match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value()))) +    return Cmp0; + +  return nullptr; +} + +static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { +  // (icmp (add V, C0), C1) & (icmp V, C0) +  ICmpInst::Predicate Pred0, Pred1; +  const APInt *C0, *C1; +  Value *V; +  if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) +    return nullptr; + +  if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) +    return nullptr; + +  auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); +  if (AddInst->getOperand(1) != Op1->getOperand(1)) +    return nullptr; + +  Type *ITy = Op0->getType(); +  bool isNSW = AddInst->hasNoSignedWrap(); +  bool isNUW = AddInst->hasNoUnsignedWrap(); + +  const APInt Delta = *C1 - *C0; +  if (C0->isStrictlyPositive()) { +    if (Delta == 2) { +      if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT) +        return getFalse(ITy); +      if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && isNSW) +        return getFalse(ITy); +    } +    if (Delta == 1) { +      if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_SGT) +        return getFalse(ITy); +      if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && isNSW) +        return getFalse(ITy); +    } +  } +  if (C0->getBoolValue() && isNUW) { +    if (Delta == 2) +      if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT) +        return getFalse(ITy); +    if (Delta == 1) +      if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGT) +        return getFalse(ITy); +  } + +  return nullptr; +} + +static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { +  if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) +    return X; +  if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true)) +    return X; + +  if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1)) +    return X; +  if (Value *X = simplifyAndOfICmpsWithSameOperands(Op1, Op0)) +    return X; + +  if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true)) +    return X; + +  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true)) +    return X; + +  if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1)) +    return X; +  if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0)) +    return X; + +  return nullptr; +} + +static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { +  // (icmp (add V, C0), C1) | (icmp V, C0) +  ICmpInst::Predicate Pred0, Pred1; +  const APInt *C0, *C1; +  Value *V; +  if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_APInt(C0)), m_APInt(C1)))) +    return nullptr; + +  if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) +    return nullptr; + +  auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); +  if (AddInst->getOperand(1) != Op1->getOperand(1)) +    return nullptr; + +  Type *ITy = Op0->getType(); +  bool isNSW = AddInst->hasNoSignedWrap(); +  bool isNUW = AddInst->hasNoUnsignedWrap(); + +  const APInt Delta = *C1 - *C0; +  if (C0->isStrictlyPositive()) { +    if (Delta == 2) { +      if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) +        return getTrue(ITy); +      if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) +        return getTrue(ITy); +    } +    if (Delta == 1) { +      if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) +        return getTrue(ITy); +      if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == 
ICmpInst::ICMP_SLE && isNSW) +        return getTrue(ITy); +    } +  } +  if (C0->getBoolValue() && isNUW) { +    if (Delta == 2) +      if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) +        return getTrue(ITy); +    if (Delta == 1) +      if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) +        return getTrue(ITy); +  } + +  return nullptr; +} + +static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { +  if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) +    return X; +  if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false)) +    return X; + +  if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1)) +    return X; +  if (Value *X = simplifyOrOfICmpsWithSameOperands(Op1, Op0)) +    return X; + +  if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false)) +    return X; + +  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false)) +    return X; + +  if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1)) +    return X; +  if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0)) +    return X; + +  return nullptr; +} + +static Value *simplifyAndOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { +  Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); +  Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); +  if (LHS0->getType() != RHS0->getType()) +    return nullptr; + +  FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); +  if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || +      (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) { +    // (fcmp ord NNAN, X) & (fcmp ord X, Y) --> fcmp ord X, Y +    // (fcmp ord NNAN, X) & (fcmp ord Y, X) --> fcmp ord Y, X +    // (fcmp ord X, NNAN) & (fcmp ord X, Y) --> fcmp ord X, Y +    // (fcmp ord X, NNAN) & (fcmp ord Y, X) --> fcmp ord Y, X +    // (fcmp uno NNAN, X) | (fcmp uno X, Y) --> fcmp uno X, Y +    // (fcmp uno NNAN, X) | (fcmp uno Y, X) --> fcmp uno Y, X +    // (fcmp uno X, NNAN) | (fcmp uno X, Y) --> fcmp uno X, Y +    // (fcmp uno X, NNAN) | (fcmp uno Y, X) --> fcmp uno Y, X +    if ((isKnownNeverNaN(LHS0) && (LHS1 == RHS0 || LHS1 == RHS1)) || +        (isKnownNeverNaN(LHS1) && (LHS0 == RHS0 || LHS0 == RHS1))) +      return RHS; + +    // (fcmp ord X, Y) & (fcmp ord NNAN, X) --> fcmp ord X, Y +    // (fcmp ord Y, X) & (fcmp ord NNAN, X) --> fcmp ord Y, X +    // (fcmp ord X, Y) & (fcmp ord X, NNAN) --> fcmp ord X, Y +    // (fcmp ord Y, X) & (fcmp ord X, NNAN) --> fcmp ord Y, X +    // (fcmp uno X, Y) | (fcmp uno NNAN, X) --> fcmp uno X, Y +    // (fcmp uno Y, X) | (fcmp uno NNAN, X) --> fcmp uno Y, X +    // (fcmp uno X, Y) | (fcmp uno X, NNAN) --> fcmp uno X, Y +    // (fcmp uno Y, X) | (fcmp uno X, NNAN) --> fcmp uno Y, X +    if ((isKnownNeverNaN(RHS0) && (RHS1 == LHS0 || RHS1 == LHS1)) || +        (isKnownNeverNaN(RHS1) && (RHS0 == LHS0 || RHS0 == LHS1))) +      return LHS; +  } + +  return nullptr; +} + +static Value *simplifyAndOrOfCmps(Value *Op0, Value *Op1, bool IsAnd) { +  // Look through casts of the 'and' operands to find compares. +  auto *Cast0 = dyn_cast<CastInst>(Op0); +  auto *Cast1 = dyn_cast<CastInst>(Op1); +  if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && +      Cast0->getSrcTy() == Cast1->getSrcTy()) { +    Op0 = Cast0->getOperand(0); +    Op1 = Cast1->getOperand(0); +  } + +  Value *V = nullptr; +  auto *ICmp0 = dyn_cast<ICmpInst>(Op0); +  auto *ICmp1 = dyn_cast<ICmpInst>(Op1); +  if (ICmp0 && ICmp1) +    V = IsAnd ? 
simplifyAndOfICmps(ICmp0, ICmp1) : +                simplifyOrOfICmps(ICmp0, ICmp1); + +  auto *FCmp0 = dyn_cast<FCmpInst>(Op0); +  auto *FCmp1 = dyn_cast<FCmpInst>(Op1); +  if (FCmp0 && FCmp1) +    V = simplifyAndOrOfFCmps(FCmp0, FCmp1, IsAnd); + +  if (!V) +    return nullptr; +  if (!Cast0) +    return V; + +  // If we looked through casts, we can only handle a constant simplification +  // because we are not allowed to create a cast instruction here. +  if (auto *C = dyn_cast<Constant>(V)) +    return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType()); + +  return nullptr; +} + +/// Given operands for an And, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                              unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::And, Op0, Op1, Q)) +    return C; + +  // X & undef -> 0 +  if (match(Op1, m_Undef())) +    return Constant::getNullValue(Op0->getType()); + +  // X & X = X +  if (Op0 == Op1) +    return Op0; + +  // X & 0 = 0 +  if (match(Op1, m_Zero())) +    return Constant::getNullValue(Op0->getType()); + +  // X & -1 = X +  if (match(Op1, m_AllOnes())) +    return Op0; + +  // A & ~A  =  ~A & A  =  0 +  if (match(Op0, m_Not(m_Specific(Op1))) || +      match(Op1, m_Not(m_Specific(Op0)))) +    return Constant::getNullValue(Op0->getType()); + +  // (A | ?) & A = A +  if (match(Op0, m_c_Or(m_Specific(Op1), m_Value()))) +    return Op1; + +  // A & (A | ?) = A +  if (match(Op1, m_c_Or(m_Specific(Op0), m_Value()))) +    return Op0; + +  // A mask that only clears known zeros of a shifted value is a no-op. +  Value *X; +  const APInt *Mask; +  const APInt *ShAmt; +  if (match(Op1, m_APInt(Mask))) { +    // If all bits in the inverted and shifted mask are clear: +    // and (shl X, ShAmt), Mask --> shl X, ShAmt +    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShAmt))) && +        (~(*Mask)).lshr(*ShAmt).isNullValue()) +      return Op0; + +    // If all bits in the inverted and shifted mask are clear: +    // and (lshr X, ShAmt), Mask --> lshr X, ShAmt +    if (match(Op0, m_LShr(m_Value(X), m_APInt(ShAmt))) && +        (~(*Mask)).shl(*ShAmt).isNullValue()) +      return Op0; +  } + +  // A & (-A) = A if A is a power of two or zero. +  if (match(Op0, m_Neg(m_Specific(Op1))) || +      match(Op1, m_Neg(m_Specific(Op0)))) { +    if (isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, +                               Q.DT)) +      return Op0; +    if (isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, +                               Q.DT)) +      return Op1; +  } + +  if (Value *V = simplifyAndOrOfCmps(Op0, Op1, true)) +    return V; + +  // Try some generic simplifications for associative operations. +  if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, +                                          MaxRecurse)) +    return V; + +  // And distributes over Or.  Try some generic simplifications based on this. +  if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Or, +                             Q, MaxRecurse)) +    return V; + +  // And distributes over Xor.  Try some generic simplifications based on this. 
+  if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Xor, +                             Q, MaxRecurse)) +    return V; + +  // If the operation is with the result of a select instruction, check whether +  // operating on either branch of the select always yields the same value. +  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +    if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, +                                         MaxRecurse)) +      return V; + +  // If the operation is with the result of a phi instruction, check whether +  // operating on all incoming values of the phi always yields the same value. +  if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) +    if (Value *V = ThreadBinOpOverPHI(Instruction::And, Op0, Op1, Q, +                                      MaxRecurse)) +      return V; + +  // Assuming the effective width of Y is not larger than A, i.e. all bits +  // from X and Y are disjoint in (X << A) | Y, +  // if the mask of this AND op covers all bits of X or Y, while it covers +  // no bits from the other, we can bypass this AND op. E.g., +  // ((X << A) | Y) & Mask -> Y, +  //     if Mask = ((1 << effective_width_of(Y)) - 1) +  // ((X << A) | Y) & Mask -> X << A, +  //     if Mask = ((1 << effective_width_of(X)) - 1) << A +  // SimplifyDemandedBits in InstCombine can optimize the general case. +  // This pattern aims to help other passes for a common case. +  Value *Y, *XShifted; +  if (match(Op1, m_APInt(Mask)) && +      match(Op0, m_c_Or(m_CombineAnd(m_NUWShl(m_Value(X), m_APInt(ShAmt)), +                                     m_Value(XShifted)), +                        m_Value(Y)))) { +    const unsigned Width = Op0->getType()->getScalarSizeInBits(); +    const unsigned ShftCnt = ShAmt->getLimitedValue(Width); +    const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); +    if (EffWidthY <= ShftCnt) { +      const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, +                                                Q.DT); +      const unsigned EffWidthX = Width - XKnown.countMinLeadingZeros(); +      const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY); +      const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt; +      // If the mask is extracting all bits from X or Y as is, we can skip +      // this AND op. +      if (EffBitsY.isSubsetOf(*Mask) && !EffBitsX.intersects(*Mask)) +        return Y; +      if (EffBitsX.isSubsetOf(*Mask) && !EffBitsY.intersects(*Mask)) +        return XShifted; +    } +  } + +  return nullptr; +} + +Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifyAndInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for an Or, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                             unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::Or, Op0, Op1, Q)) +    return C; + +  // X | undef -> -1 +  // X | -1 = -1 +  // Do not return Op1 because it may contain undef elements if it's a vector. 
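+  // Editorial note (not from the upstream sources): for a hypothetical
+  //   %r = or <2 x i8> %x, <i8 -1, i8 undef>
+  // returning Op1 would claim the second lane is undef, but that lane must
+  // still keep whatever bits are set in %x; the all-ones splat returned below
+  // is always a correct choice.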
+  if (match(Op1, m_Undef()) || match(Op1, m_AllOnes()))
+    return Constant::getAllOnesValue(Op0->getType());
+
+  // X | X = X
+  // X | 0 = X
+  if (Op0 == Op1 || match(Op1, m_Zero()))
+    return Op0;
+
+  // A | ~A  =  ~A | A  =  -1
+  if (match(Op0, m_Not(m_Specific(Op1))) ||
+      match(Op1, m_Not(m_Specific(Op0))))
+    return Constant::getAllOnesValue(Op0->getType());
+
+  // (A & ?) | A = A
+  if (match(Op0, m_c_And(m_Specific(Op1), m_Value())))
+    return Op1;
+
+  // A | (A & ?) = A
+  if (match(Op1, m_c_And(m_Specific(Op0), m_Value())))
+    return Op0;
+
+  // ~(A & ?) | A = -1
+  if (match(Op0, m_Not(m_c_And(m_Specific(Op1), m_Value()))))
+    return Constant::getAllOnesValue(Op1->getType());
+
+  // A | ~(A & ?) = -1
+  if (match(Op1, m_Not(m_c_And(m_Specific(Op0), m_Value()))))
+    return Constant::getAllOnesValue(Op0->getType());
+
+  Value *A, *B;
+  // (A & ~B) | (A ^ B) -> (A ^ B)
+  // (~B & A) | (A ^ B) -> (A ^ B)
+  // (A & ~B) | (B ^ A) -> (B ^ A)
+  // (~B & A) | (B ^ A) -> (B ^ A)
+  if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
+      (match(Op0, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) ||
+       match(Op0, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))))
+    return Op1;
+
+  // Commute the 'or' operands.
+  // (A ^ B) | (A & ~B) -> (A ^ B)
+  // (A ^ B) | (~B & A) -> (A ^ B)
+  // (B ^ A) | (A & ~B) -> (B ^ A)
+  // (B ^ A) | (~B & A) -> (B ^ A)
+  if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+      (match(Op1, m_c_And(m_Specific(A), m_Not(m_Specific(B)))) ||
+       match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B)))))
+    return Op0;
+
+  // (A & B) | (~A ^ B) -> (~A ^ B)
+  // (B & A) | (~A ^ B) -> (~A ^ B)
+  // (A & B) | (B ^ ~A) -> (B ^ ~A)
+  // (B & A) | (B ^ ~A) -> (B ^ ~A)
+  if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+      (match(Op1, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) ||
+       match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))
+    return Op1;
+
+  // (~A ^ B) | (A & B) -> (~A ^ B)
+  // (~A ^ B) | (B & A) -> (~A ^ B)
+  // (B ^ ~A) | (A & B) -> (B ^ ~A)
+  // (B ^ ~A) | (B & A) -> (B ^ ~A)
+  if (match(Op1, m_And(m_Value(A), m_Value(B))) &&
+      (match(Op0, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) ||
+       match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))
+    return Op0;
+
+  if (Value *V = simplifyAndOrOfCmps(Op0, Op1, false))
+    return V;
+
+  // Try some generic simplifications for associative operations.
+  if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q,
+                                          MaxRecurse))
+    return V;
+
+  // Or distributes over And.  Try some generic simplifications based on this.
+  if (Value *V = ExpandBinOp(Instruction::Or, Op0, Op1, Instruction::And, Q,
+                             MaxRecurse))
+    return V;
+
+  // If the operation is with the result of a select instruction, check whether
+  // operating on either branch of the select always yields the same value.
+  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
+    if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q,
+                                         MaxRecurse))
+      return V;
+
+  // (A & C1)|(B & C2)
+  const APInt *C1, *C2;
+  if (match(Op0, m_And(m_Value(A), m_APInt(C1))) &&
+      match(Op1, m_And(m_Value(B), m_APInt(C2)))) {
+    if (*C1 == ~*C2) {
+      // (A & C1)|(B & C2)
+      // If we have: ((V + N) & C1) | (V & C2)
+      // .. and C2 = ~C1 and C2 is 0+1+ and (N & C2) == 0
+      // replace with V+N.
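+      // Editorial illustration (not from the upstream sources) with C1 = -256,
+      // C2 = 255 and N = 256 (so N & C2 == 0), in hypothetical IR:
+      //   %a = add i32 %v, 256
+      //   %h = and i32 %a, -256
+      //   %l = and i32 %v, 255
+      //   %o = or i32 %h, %l    ; simplifies to %a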
+      Value *N; +      if (C2->isMask() && // C2 == 0+1+ +          match(A, m_c_Add(m_Specific(B), m_Value(N)))) { +        // Add commutes, try both ways. +        if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +          return A; +      } +      // Or commutes, try both ways. +      if (C1->isMask() && +          match(B, m_c_Add(m_Specific(A), m_Value(N)))) { +        // Add commutes, try both ways. +        if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +          return B; +      } +    } +  } + +  // If the operation is with the result of a phi instruction, check whether +  // operating on all incoming values of the phi always yields the same value. +  if (isa<PHINode>(Op0) || isa<PHINode>(Op1)) +    if (Value *V = ThreadBinOpOverPHI(Instruction::Or, Op0, Op1, Q, MaxRecurse)) +      return V; + +  return nullptr; +} + +Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifyOrInst(Op0, Op1, Q, RecursionLimit); +} + +/// Given operands for a Xor, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, +                              unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::Xor, Op0, Op1, Q)) +    return C; + +  // A ^ undef -> undef +  if (match(Op1, m_Undef())) +    return Op1; + +  // A ^ 0 = A +  if (match(Op1, m_Zero())) +    return Op0; + +  // A ^ A = 0 +  if (Op0 == Op1) +    return Constant::getNullValue(Op0->getType()); + +  // A ^ ~A  =  ~A ^ A  =  -1 +  if (match(Op0, m_Not(m_Specific(Op1))) || +      match(Op1, m_Not(m_Specific(Op0)))) +    return Constant::getAllOnesValue(Op0->getType()); + +  // Try some generic simplifications for associative operations. +  if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q, +                                          MaxRecurse)) +    return V; + +  // Threading Xor over selects and phi nodes is pointless, so don't bother. +  // Threading over the select in "A ^ select(cond, B, C)" means evaluating +  // "A^B" and "A^C" and seeing if they are equal; but they are equal if and +  // only if B and C are equal.  If B and C are equal then (since we assume +  // that operands have already been simplified) "select(cond, B, C)" should +  // have been simplified to the common value of B and C already.  Analysing +  // "A^B" and "A^C" thus gains nothing, but costs compile time.  Similarly +  // for threading over phi nodes. + +  return nullptr; +} + +Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { +  return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit); +} + + +static Type *GetCompareTy(Value *Op) { +  return CmpInst::makeCmpResultType(Op->getType()); +} + +/// Rummage around inside V looking for something equivalent to the comparison +/// "LHS Pred RHS". Return such a value if found, otherwise return null. +/// Helper function for analyzing max/min idioms. 
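+// Editorial note (not part of the upstream sources): in a hypothetical max
+// idiom such as
+//   %c = icmp sgt i32 %x, %y
+//   %m = select i1 %c, i32 %x, i32 %y
+// asking %m for a condition equivalent to "%x sgt %y" (or the swapped
+// "%y slt %x") returns %c.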
+static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred,
+                                         Value *LHS, Value *RHS) {
+  SelectInst *SI = dyn_cast<SelectInst>(V);
+  if (!SI)
+    return nullptr;
+  CmpInst *Cmp = dyn_cast<CmpInst>(SI->getCondition());
+  if (!Cmp)
+    return nullptr;
+  Value *CmpLHS = Cmp->getOperand(0), *CmpRHS = Cmp->getOperand(1);
+  if (Pred == Cmp->getPredicate() && LHS == CmpLHS && RHS == CmpRHS)
+    return Cmp;
+  if (Pred == CmpInst::getSwappedPredicate(Cmp->getPredicate()) &&
+      LHS == CmpRHS && RHS == CmpLHS)
+    return Cmp;
+  return nullptr;
+}
+
+// A significant optimization not implemented here is assuming that alloca
+// addresses are not equal to incoming argument values. They don't *alias*,
+// as we say, but that doesn't mean they aren't equal, so we take a
+// conservative approach.
+//
+// This is inspired in part by C++11 5.10p1:
+//   "Two pointers of the same type compare equal if and only if they are both
+//    null, both point to the same function, or both represent the same
+//    address."
+//
+// This is pretty permissive.
+//
+// It's also partly due to C11 6.5.9p6:
+//   "Two pointers compare equal if and only if both are null pointers, both are
+//    pointers to the same object (including a pointer to an object and a
+//    subobject at its beginning) or function, both are pointers to one past the
+//    last element of the same array object, or one is a pointer to one past the
+//    end of one array object and the other is a pointer to the start of a
+//    different array object that happens to immediately follow the first array
+//    object in the address space."
+//
+// C11's version is more restrictive, however there's no reason why an argument
+// couldn't be a one-past-the-end value for a stack object in the caller and be
+// equal to the beginning of a stack object in the callee.
+//
+// If the C and C++ standards are ever made sufficiently restrictive in this
+// area, it may be possible to update LLVM's semantics accordingly and reinstate
+// this optimization.
+static Constant *
+computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,
+                   const DominatorTree *DT, CmpInst::Predicate Pred,
+                   AssumptionCache *AC, const Instruction *CxtI,
+                   Value *LHS, Value *RHS) {
+  // First, skip past any trivial no-ops.
+  LHS = LHS->stripPointerCasts();
+  RHS = RHS->stripPointerCasts();
+
+  // A non-null pointer is not equal to a null pointer.
+  if (llvm::isKnownNonZero(LHS, DL) && isa<ConstantPointerNull>(RHS) &&
+      (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE))
+    return ConstantInt::get(GetCompareTy(LHS),
+                            !CmpInst::isTrueWhenEqual(Pred));
+
+  // We can only fold certain predicates on pointer comparisons.
+  switch (Pred) {
+  default:
+    return nullptr;
+
+    // Equality comparisons are easy to fold.
+  case CmpInst::ICMP_EQ:
+  case CmpInst::ICMP_NE:
+    break;
+
+    // We can only handle unsigned relational comparisons because 'inbounds' on
+    // a GEP only protects against unsigned wrapping.
+  case CmpInst::ICMP_UGT:
+  case CmpInst::ICMP_UGE:
+  case CmpInst::ICMP_ULT:
+  case CmpInst::ICMP_ULE:
+    // However, we have to switch them to their signed variants to handle
+    // negative indices from the base pointer.
+    Pred = ICmpInst::getSignedPredicate(Pred);
+    break;
+  }
+
+  // Strip off any constant offsets so that we can reason about them. 
+  // It's tempting to use getUnderlyingObject or even just stripInBoundsOffsets +  // here and compare base addresses like AliasAnalysis does, however there are +  // numerous hazards. AliasAnalysis and its utilities rely on special rules +  // governing loads and stores which don't apply to icmps. Also, AliasAnalysis +  // doesn't need to guarantee pointer inequality when it says NoAlias. +  Constant *LHSOffset = stripAndComputeConstantOffsets(DL, LHS); +  Constant *RHSOffset = stripAndComputeConstantOffsets(DL, RHS); + +  // If LHS and RHS are related via constant offsets to the same base +  // value, we can replace it with an icmp which just compares the offsets. +  if (LHS == RHS) +    return ConstantExpr::getICmp(Pred, LHSOffset, RHSOffset); + +  // Various optimizations for (in)equality comparisons. +  if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) { +    // Different non-empty allocations that exist at the same time have +    // different addresses (if the program can tell). Global variables always +    // exist, so they always exist during the lifetime of each other and all +    // allocas. Two different allocas usually have different addresses... +    // +    // However, if there's an @llvm.stackrestore dynamically in between two +    // allocas, they may have the same address. It's tempting to reduce the +    // scope of the problem by only looking at *static* allocas here. That would +    // cover the majority of allocas while significantly reducing the likelihood +    // of having an @llvm.stackrestore pop up in the middle. However, it's not +    // actually impossible for an @llvm.stackrestore to pop up in the middle of +    // an entry block. Also, if we have a block that's not attached to a +    // function, we can't tell if it's "static" under the current definition. +    // Theoretically, this problem could be fixed by creating a new kind of +    // instruction kind specifically for static allocas. Such a new instruction +    // could be required to be at the top of the entry block, thus preventing it +    // from being subject to a @llvm.stackrestore. Instcombine could even +    // convert regular allocas into these special allocas. It'd be nifty. +    // However, until then, this problem remains open. +    // +    // So, we'll assume that two non-empty allocas have different addresses +    // for now. +    // +    // With all that, if the offsets are within the bounds of their allocations +    // (and not one-past-the-end! so we can't use inbounds!), and their +    // allocations aren't the same, the pointers are not equal. +    // +    // Note that it's not necessary to check for LHS being a global variable +    // address, due to canonicalization and constant folding. 
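+    // Editorial illustration (not from the upstream sources): for two distinct
+    // non-empty allocas, e.g.
+    //   %a = alloca i32
+    //   %b = alloca i32
+    //   %c = icmp eq i32* %a, %b
+    // the checks below fold %c to false.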
+    if (isa<AllocaInst>(LHS) &&
+        (isa<AllocaInst>(RHS) || isa<GlobalVariable>(RHS))) {
+      ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);
+      ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);
+      uint64_t LHSSize, RHSSize;
+      ObjectSizeOpts Opts;
+      Opts.NullIsUnknownSize =
+          NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction());
+      if (LHSOffsetCI && RHSOffsetCI &&
+          getObjectSize(LHS, LHSSize, DL, TLI, Opts) &&
+          getObjectSize(RHS, RHSSize, DL, TLI, Opts)) {
+        const APInt &LHSOffsetValue = LHSOffsetCI->getValue();
+        const APInt &RHSOffsetValue = RHSOffsetCI->getValue();
+        if (!LHSOffsetValue.isNegative() &&
+            !RHSOffsetValue.isNegative() &&
+            LHSOffsetValue.ult(LHSSize) &&
+            RHSOffsetValue.ult(RHSSize)) {
+          return ConstantInt::get(GetCompareTy(LHS),
+                                  !CmpInst::isTrueWhenEqual(Pred));
+        }
+      }
+
+      // Repeat the above check but this time without depending on DataLayout
+      // or being able to compute a precise size.
+      if (!cast<PointerType>(LHS->getType())->isEmptyTy() &&
+          !cast<PointerType>(RHS->getType())->isEmptyTy() &&
+          LHSOffset->isNullValue() &&
+          RHSOffset->isNullValue())
+        return ConstantInt::get(GetCompareTy(LHS),
+                                !CmpInst::isTrueWhenEqual(Pred));
+    }
+
+    // Even if a non-inbounds GEP occurs along the path we can still optimize
+    // equality comparisons concerning the result. We avoid walking the whole
+    // chain again by starting where the last calls to
+    // stripAndComputeConstantOffsets left off and accumulate the offsets.
+    Constant *LHSNoBound = stripAndComputeConstantOffsets(DL, LHS, true);
+    Constant *RHSNoBound = stripAndComputeConstantOffsets(DL, RHS, true);
+    if (LHS == RHS)
+      return ConstantExpr::getICmp(Pred,
+                                   ConstantExpr::getAdd(LHSOffset, LHSNoBound),
+                                   ConstantExpr::getAdd(RHSOffset, RHSNoBound));
+
+    // If one side of the equality comparison must come from a noalias call
+    // (meaning a system memory allocation function), and the other side must
+    // come from a pointer that cannot overlap with dynamically-allocated
+    // memory within the lifetime of the current function (allocas, byval
+    // arguments, globals), then determine the comparison result here.
+    SmallVector<Value *, 8> LHSUObjs, RHSUObjs;
+    GetUnderlyingObjects(LHS, LHSUObjs, DL);
+    GetUnderlyingObjects(RHS, RHSUObjs, DL);
+
+    // Is the set of underlying objects all noalias calls?
+    auto IsNAC = [](ArrayRef<Value *> Objects) {
+      return all_of(Objects, isNoAliasCall);
+    };
+
+    // Is the set of underlying objects all things which must be disjoint from
+    // noalias calls? For allocas, we consider only static ones (dynamic
+    // allocas might be transformed into calls to malloc not simultaneously
+    // live with the compared-to allocation). For globals, we exclude symbols
+    // that might be resolved lazily to symbols in another dynamically-loaded
+    // library (and, thus, could be malloc'ed by the implementation).
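+    //
+    // For instance (hypothetical IR), a noalias allocation compared against an
+    // internal global can be folded to "not equal" by the check below:
+    //   @g = internal global i32 0
+    //   %m = call noalias i8* @malloc(i64 4)
+    //   %p = bitcast i8* %m to i32*
+    //   %cmp = icmp eq i32* %p, @g  ; folds to false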
+    auto IsAllocDisjoint = [](ArrayRef<Value *> Objects) { +      return all_of(Objects, [](Value *V) { +        if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) +          return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); +        if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) +          return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() || +                  GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) && +                 !GV->isThreadLocal(); +        if (const Argument *A = dyn_cast<Argument>(V)) +          return A->hasByValAttr(); +        return false; +      }); +    }; + +    if ((IsNAC(LHSUObjs) && IsAllocDisjoint(RHSUObjs)) || +        (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs))) +        return ConstantInt::get(GetCompareTy(LHS), +                                !CmpInst::isTrueWhenEqual(Pred)); + +    // Fold comparisons for non-escaping pointer even if the allocation call +    // cannot be elided. We cannot fold malloc comparison to null. Also, the +    // dynamic allocation call could be either of the operands. +    Value *MI = nullptr; +    if (isAllocLikeFn(LHS, TLI) && +        llvm::isKnownNonZero(RHS, DL, 0, nullptr, CxtI, DT)) +      MI = LHS; +    else if (isAllocLikeFn(RHS, TLI) && +             llvm::isKnownNonZero(LHS, DL, 0, nullptr, CxtI, DT)) +      MI = RHS; +    // FIXME: We should also fold the compare when the pointer escapes, but the +    // compare dominates the pointer escape +    if (MI && !PointerMayBeCaptured(MI, true, true)) +      return ConstantInt::get(GetCompareTy(LHS), +                              CmpInst::isFalseWhenEqual(Pred)); +  } + +  // Otherwise, fail. +  return nullptr; +} + +/// Fold an icmp when its operands have i1 scalar type. +static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS, +                                  Value *RHS, const SimplifyQuery &Q) { +  Type *ITy = GetCompareTy(LHS); // The return type. +  Type *OpTy = LHS->getType();   // The operand type. +  if (!OpTy->isIntOrIntVectorTy(1)) +    return nullptr; + +  // A boolean compared to true/false can be simplified in 14 out of the 20 +  // (10 predicates * 2 constants) possible combinations. Cases not handled here +  // require a 'not' of the LHS, so those must be transformed in InstCombine. 
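+  //
+  // A few illustrative folds (the operand names are hypothetical):
+  //   icmp ne  i1 %x, false --> %x
+  //   icmp ugt i1 %x, true  --> false
+  //   icmp sge i1 %x, %y    --> true, when %x is known to imply %y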
+  if (match(RHS, m_Zero())) { +    switch (Pred) { +    case CmpInst::ICMP_NE:  // X !=  0 -> X +    case CmpInst::ICMP_UGT: // X >u  0 -> X +    case CmpInst::ICMP_SLT: // X <s  0 -> X +      return LHS; + +    case CmpInst::ICMP_ULT: // X <u  0 -> false +    case CmpInst::ICMP_SGT: // X >s  0 -> false +      return getFalse(ITy); + +    case CmpInst::ICMP_UGE: // X >=u 0 -> true +    case CmpInst::ICMP_SLE: // X <=s 0 -> true +      return getTrue(ITy); + +    default: break; +    } +  } else if (match(RHS, m_One())) { +    switch (Pred) { +    case CmpInst::ICMP_EQ:  // X ==   1 -> X +    case CmpInst::ICMP_UGE: // X >=u  1 -> X +    case CmpInst::ICMP_SLE: // X <=s -1 -> X +      return LHS; + +    case CmpInst::ICMP_UGT: // X >u   1 -> false +    case CmpInst::ICMP_SLT: // X <s  -1 -> false +      return getFalse(ITy); + +    case CmpInst::ICMP_ULE: // X <=u  1 -> true +    case CmpInst::ICMP_SGE: // X >=s -1 -> true +      return getTrue(ITy); + +    default: break; +    } +  } + +  switch (Pred) { +  default: +    break; +  case ICmpInst::ICMP_UGE: +    if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) +      return getTrue(ITy); +    break; +  case ICmpInst::ICMP_SGE: +    /// For signed comparison, the values for an i1 are 0 and -1 +    /// respectively. This maps into a truth table of: +    /// LHS | RHS | LHS >=s RHS   | LHS implies RHS +    ///  0  |  0  |  1 (0 >= 0)   |  1 +    ///  0  |  1  |  1 (0 >= -1)  |  1 +    ///  1  |  0  |  0 (-1 >= 0)  |  0 +    ///  1  |  1  |  1 (-1 >= -1) |  1 +    if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) +      return getTrue(ITy); +    break; +  case ICmpInst::ICMP_ULE: +    if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) +      return getTrue(ITy); +    break; +  } + +  return nullptr; +} + +/// Try hard to fold icmp with zero RHS because this is a common case. +static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, +                                   Value *RHS, const SimplifyQuery &Q) { +  if (!match(RHS, m_Zero())) +    return nullptr; + +  Type *ITy = GetCompareTy(LHS); // The return type. 
+  switch (Pred) { +  default: +    llvm_unreachable("Unknown ICmp predicate!"); +  case ICmpInst::ICMP_ULT: +    return getFalse(ITy); +  case ICmpInst::ICMP_UGE: +    return getTrue(ITy); +  case ICmpInst::ICMP_EQ: +  case ICmpInst::ICMP_ULE: +    if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +      return getFalse(ITy); +    break; +  case ICmpInst::ICMP_NE: +  case ICmpInst::ICMP_UGT: +    if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +      return getTrue(ITy); +    break; +  case ICmpInst::ICMP_SLT: { +    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    if (LHSKnown.isNegative()) +      return getTrue(ITy); +    if (LHSKnown.isNonNegative()) +      return getFalse(ITy); +    break; +  } +  case ICmpInst::ICMP_SLE: { +    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    if (LHSKnown.isNegative()) +      return getTrue(ITy); +    if (LHSKnown.isNonNegative() && +        isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +      return getFalse(ITy); +    break; +  } +  case ICmpInst::ICMP_SGE: { +    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    if (LHSKnown.isNegative()) +      return getFalse(ITy); +    if (LHSKnown.isNonNegative()) +      return getTrue(ITy); +    break; +  } +  case ICmpInst::ICMP_SGT: { +    KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +    if (LHSKnown.isNegative()) +      return getFalse(ITy); +    if (LHSKnown.isNonNegative() && +        isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +      return getTrue(ITy); +    break; +  } +  } + +  return nullptr; +} + +/// Many binary operators with a constant operand have an easy-to-compute +/// range of outputs. This can be used to fold a comparison to always true or +/// always false. +static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { +  unsigned Width = Lower.getBitWidth(); +  const APInt *C; +  switch (BO.getOpcode()) { +  case Instruction::Add: +    if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { +      // FIXME: If we have both nuw and nsw, we should reduce the range further. +      if (BO.hasNoUnsignedWrap()) { +        // 'add nuw x, C' produces [C, UINT_MAX]. +        Lower = *C; +      } else if (BO.hasNoSignedWrap()) { +        if (C->isNegative()) { +          // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. +          Lower = APInt::getSignedMinValue(Width); +          Upper = APInt::getSignedMaxValue(Width) + *C + 1; +        } else { +          // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX]. +          Lower = APInt::getSignedMinValue(Width) + *C; +          Upper = APInt::getSignedMaxValue(Width) + 1; +        } +      } +    } +    break; + +  case Instruction::And: +    if (match(BO.getOperand(1), m_APInt(C))) +      // 'and x, C' produces [0, C]. +      Upper = *C + 1; +    break; + +  case Instruction::Or: +    if (match(BO.getOperand(1), m_APInt(C))) +      // 'or x, C' produces [C, UINT_MAX]. +      Lower = *C; +    break; + +  case Instruction::AShr: +    if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { +      // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C]. 
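+      // As a concrete (illustrative) instance, for i8 and C == 2 this yields
+      // the half-open range [-32, 32), i.e. the values -32 through 31.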
+      Lower = APInt::getSignedMinValue(Width).ashr(*C); +      Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; +    } else if (match(BO.getOperand(0), m_APInt(C))) { +      unsigned ShiftAmount = Width - 1; +      if (!C->isNullValue() && BO.isExact()) +        ShiftAmount = C->countTrailingZeros(); +      if (C->isNegative()) { +        // 'ashr C, x' produces [C, C >> (Width-1)] +        Lower = *C; +        Upper = C->ashr(ShiftAmount) + 1; +      } else { +        // 'ashr C, x' produces [C >> (Width-1), C] +        Lower = C->ashr(ShiftAmount); +        Upper = *C + 1; +      } +    } +    break; + +  case Instruction::LShr: +    if (match(BO.getOperand(1), m_APInt(C)) && C->ult(Width)) { +      // 'lshr x, C' produces [0, UINT_MAX >> C]. +      Upper = APInt::getAllOnesValue(Width).lshr(*C) + 1; +    } else if (match(BO.getOperand(0), m_APInt(C))) { +      // 'lshr C, x' produces [C >> (Width-1), C]. +      unsigned ShiftAmount = Width - 1; +      if (!C->isNullValue() && BO.isExact()) +        ShiftAmount = C->countTrailingZeros(); +      Lower = C->lshr(ShiftAmount); +      Upper = *C + 1; +    } +    break; + +  case Instruction::Shl: +    if (match(BO.getOperand(0), m_APInt(C))) { +      if (BO.hasNoUnsignedWrap()) { +        // 'shl nuw C, x' produces [C, C << CLZ(C)] +        Lower = *C; +        Upper = Lower.shl(Lower.countLeadingZeros()) + 1; +      } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw? +        if (C->isNegative()) { +          // 'shl nsw C, x' produces [C << CLO(C)-1, C] +          unsigned ShiftAmount = C->countLeadingOnes() - 1; +          Lower = C->shl(ShiftAmount); +          Upper = *C + 1; +        } else { +          // 'shl nsw C, x' produces [C, C << CLZ(C)-1] +          unsigned ShiftAmount = C->countLeadingZeros() - 1; +          Lower = *C; +          Upper = C->shl(ShiftAmount) + 1; +        } +      } +    } +    break; + +  case Instruction::SDiv: +    if (match(BO.getOperand(1), m_APInt(C))) { +      APInt IntMin = APInt::getSignedMinValue(Width); +      APInt IntMax = APInt::getSignedMaxValue(Width); +      if (C->isAllOnesValue()) { +        // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] +        //    where C != -1 and C != 0 and C != 1 +        Lower = IntMin + 1; +        Upper = IntMax + 1; +      } else if (C->countLeadingZeros() < Width - 1) { +        // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C] +        //    where C != -1 and C != 0 and C != 1 +        Lower = IntMin.sdiv(*C); +        Upper = IntMax.sdiv(*C); +        if (Lower.sgt(Upper)) +          std::swap(Lower, Upper); +        Upper = Upper + 1; +        assert(Upper != Lower && "Upper part of range has wrapped!"); +      } +    } else if (match(BO.getOperand(0), m_APInt(C))) { +      if (C->isMinSignedValue()) { +        // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2]. +        Lower = *C; +        Upper = Lower.lshr(1) + 1; +      } else { +        // 'sdiv C, x' produces [-|C|, |C|]. +        Upper = C->abs() + 1; +        Lower = (-Upper) + 1; +      } +    } +    break; + +  case Instruction::UDiv: +    if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { +      // 'udiv x, C' produces [0, UINT_MAX / C]. +      Upper = APInt::getMaxValue(Width).udiv(*C) + 1; +    } else if (match(BO.getOperand(0), m_APInt(C))) { +      // 'udiv C, x' produces [0, C]. +      Upper = *C + 1; +    } +    break; + +  case Instruction::SRem: +    if (match(BO.getOperand(1), m_APInt(C))) { +      // 'srem x, C' produces (-|C|, |C|). 
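+      // For example (illustrative), 'srem x, 8' and 'srem x, -8' both lie in
+      // [-7, 7], so Lower = -7 and Upper = 8.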
+      Upper = C->abs();
+      Lower = (-Upper) + 1;
+    }
+    break;
+
+  case Instruction::URem:
+    if (match(BO.getOperand(1), m_APInt(C)))
+      // 'urem x, C' produces [0, C).
+      Upper = *C;
+    break;
+
+  default:
+    break;
+  }
+}
+
+static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
+                                       Value *RHS) {
+  Type *ITy = GetCompareTy(RHS); // The return type.
+
+  Value *X;
+  // Sign-bit checks can be optimized to true/false after unsigned
+  // floating-point casts:
+  // icmp slt (bitcast (uitofp X)),  0 --> false
+  // icmp sgt (bitcast (uitofp X)), -1 --> true
+  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+    if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
+      return ConstantInt::getFalse(ITy);
+    if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
+      return ConstantInt::getTrue(ITy);
+  }
+
+  const APInt *C;
+  if (!match(RHS, m_APInt(C)))
+    return nullptr;
+
+  // Rule out tautological comparisons (e.g., ult 0 or uge 0).
+  ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C);
+  if (RHS_CR.isEmptySet())
+    return ConstantInt::getFalse(ITy);
+  if (RHS_CR.isFullSet())
+    return ConstantInt::getTrue(ITy);
+
+  // Find the range of possible values for binary operators.
+  unsigned Width = C->getBitWidth();
+  APInt Lower = APInt(Width, 0);
+  APInt Upper = APInt(Width, 0);
+  if (auto *BO = dyn_cast<BinaryOperator>(LHS))
+    setLimitsForBinOp(*BO, Lower, Upper);
+
+  ConstantRange LHS_CR =
+      Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true);
+
+  if (auto *I = dyn_cast<Instruction>(LHS))
+    if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
+      LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges));
+
+  if (!LHS_CR.isFullSet()) {
+    if (RHS_CR.contains(LHS_CR))
+      return ConstantInt::getTrue(ITy);
+    if (RHS_CR.inverse().contains(LHS_CR))
+      return ConstantInt::getFalse(ITy);
+  }
+
+  return nullptr;
+}
+
+/// TODO: A large part of this logic is duplicated in InstCombine's
+/// foldICmpBinOp(). We should be able to share that and avoid the code
+/// duplication.
+static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
+                                    Value *RHS, const SimplifyQuery &Q,
+                                    unsigned MaxRecurse) {
+  Type *ITy = GetCompareTy(LHS); // The return type.
+
+  BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);
+  BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);
+  if (MaxRecurse && (LBO || RBO)) {
+    // Analyze the case when either LHS or RHS is an add instruction.
+    Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+    // LHS = A + B (or A and B are null); RHS = C + D (or C and D are null).
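+    //
+    // As a hypothetical illustration of the add handling below:
+    //   icmp eq (add i32 %x, 1), %x
+    // reduces to asking whether 'icmp eq i32 1, 0' simplifies; it folds to
+    // false, so the original compare folds to false as well.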
+    bool NoLHSWrapProblem = false, NoRHSWrapProblem = false; +    if (LBO && LBO->getOpcode() == Instruction::Add) { +      A = LBO->getOperand(0); +      B = LBO->getOperand(1); +      NoLHSWrapProblem = +          ICmpInst::isEquality(Pred) || +          (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || +          (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); +    } +    if (RBO && RBO->getOpcode() == Instruction::Add) { +      C = RBO->getOperand(0); +      D = RBO->getOperand(1); +      NoRHSWrapProblem = +          ICmpInst::isEquality(Pred) || +          (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || +          (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); +    } + +    // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. +    if ((A == RHS || B == RHS) && NoLHSWrapProblem) +      if (Value *V = SimplifyICmpInst(Pred, A == RHS ? B : A, +                                      Constant::getNullValue(RHS->getType()), Q, +                                      MaxRecurse - 1)) +        return V; + +    // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. +    if ((C == LHS || D == LHS) && NoRHSWrapProblem) +      if (Value *V = +              SimplifyICmpInst(Pred, Constant::getNullValue(LHS->getType()), +                               C == LHS ? D : C, Q, MaxRecurse - 1)) +        return V; + +    // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. +    if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem && +        NoRHSWrapProblem) { +      // Determine Y and Z in the form icmp (X+Y), (X+Z). +      Value *Y, *Z; +      if (A == C) { +        // C + B == C + D  ->  B == D +        Y = B; +        Z = D; +      } else if (A == D) { +        // D + B == C + D  ->  B == C +        Y = B; +        Z = C; +      } else if (B == C) { +        // A + C == C + D  ->  A == D +        Y = A; +        Z = D; +      } else { +        assert(B == D); +        // A + D == C + D  ->  A == C +        Y = A; +        Z = C; +      } +      if (Value *V = SimplifyICmpInst(Pred, Y, Z, Q, MaxRecurse - 1)) +        return V; +    } +  } + +  { +    Value *Y = nullptr; +    // icmp pred (or X, Y), X +    if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { +      if (Pred == ICmpInst::ICMP_ULT) +        return getFalse(ITy); +      if (Pred == ICmpInst::ICMP_UGE) +        return getTrue(ITy); + +      if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { +        KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +        KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +        if (RHSKnown.isNonNegative() && YKnown.isNegative()) +          return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); +        if (RHSKnown.isNegative() || YKnown.isNonNegative()) +          return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); +      } +    } +    // icmp pred X, (or X, Y) +    if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) { +      if (Pred == ICmpInst::ICMP_ULE) +        return getTrue(ITy); +      if (Pred == ICmpInst::ICMP_UGT) +        return getFalse(ITy); + +      if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { +        KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +        KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +        if (LHSKnown.isNonNegative() && YKnown.isNegative()) +          return Pred == ICmpInst::ICMP_SGT ? 
getTrue(ITy) : getFalse(ITy); +        if (LHSKnown.isNegative() || YKnown.isNonNegative()) +          return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy); +      } +    } +  } + +  // icmp pred (and X, Y), X +  if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { +    if (Pred == ICmpInst::ICMP_UGT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_ULE) +      return getTrue(ITy); +  } +  // icmp pred X, (and X, Y) +  if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) { +    if (Pred == ICmpInst::ICMP_UGE) +      return getTrue(ITy); +    if (Pred == ICmpInst::ICMP_ULT) +      return getFalse(ITy); +  } + +  // 0 - (zext X) pred C +  if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) { +    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { +      if (RHSC->getValue().isStrictlyPositive()) { +        if (Pred == ICmpInst::ICMP_SLT) +          return ConstantInt::getTrue(RHSC->getContext()); +        if (Pred == ICmpInst::ICMP_SGE) +          return ConstantInt::getFalse(RHSC->getContext()); +        if (Pred == ICmpInst::ICMP_EQ) +          return ConstantInt::getFalse(RHSC->getContext()); +        if (Pred == ICmpInst::ICMP_NE) +          return ConstantInt::getTrue(RHSC->getContext()); +      } +      if (RHSC->getValue().isNonNegative()) { +        if (Pred == ICmpInst::ICMP_SLE) +          return ConstantInt::getTrue(RHSC->getContext()); +        if (Pred == ICmpInst::ICMP_SGT) +          return ConstantInt::getFalse(RHSC->getContext()); +      } +    } +  } + +  // icmp pred (urem X, Y), Y +  if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { +    switch (Pred) { +    default: +      break; +    case ICmpInst::ICMP_SGT: +    case ICmpInst::ICMP_SGE: { +      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (!Known.isNonNegative()) +        break; +      LLVM_FALLTHROUGH; +    } +    case ICmpInst::ICMP_EQ: +    case ICmpInst::ICMP_UGT: +    case ICmpInst::ICMP_UGE: +      return getFalse(ITy); +    case ICmpInst::ICMP_SLT: +    case ICmpInst::ICMP_SLE: { +      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (!Known.isNonNegative()) +        break; +      LLVM_FALLTHROUGH; +    } +    case ICmpInst::ICMP_NE: +    case ICmpInst::ICMP_ULT: +    case ICmpInst::ICMP_ULE: +      return getTrue(ITy); +    } +  } + +  // icmp pred X, (urem Y, X) +  if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { +    switch (Pred) { +    default: +      break; +    case ICmpInst::ICMP_SGT: +    case ICmpInst::ICMP_SGE: { +      KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (!Known.isNonNegative()) +        break; +      LLVM_FALLTHROUGH; +    } +    case ICmpInst::ICMP_NE: +    case ICmpInst::ICMP_UGT: +    case ICmpInst::ICMP_UGE: +      return getTrue(ITy); +    case ICmpInst::ICMP_SLT: +    case ICmpInst::ICMP_SLE: { +      KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (!Known.isNonNegative()) +        break; +      LLVM_FALLTHROUGH; +    } +    case ICmpInst::ICMP_EQ: +    case ICmpInst::ICMP_ULT: +    case ICmpInst::ICMP_ULE: +      return getFalse(ITy); +    } +  } + +  // x >> y <=u x +  // x udiv y <=u x. 
+  if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || +              match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) { +    // icmp pred (X op Y), X +    if (Pred == ICmpInst::ICMP_UGT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_ULE) +      return getTrue(ITy); +  } + +  // x >=u x >> y +  // x >=u x udiv y. +  if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) || +              match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) { +    // icmp pred X, (X op Y) +    if (Pred == ICmpInst::ICMP_ULT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_UGE) +      return getTrue(ITy); +  } + +  // handle: +  //   CI2 << X == CI +  //   CI2 << X != CI +  // +  //   where CI2 is a power of 2 and CI isn't +  if (auto *CI = dyn_cast<ConstantInt>(RHS)) { +    const APInt *CI2Val, *CIVal = &CI->getValue(); +    if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) && +        CI2Val->isPowerOf2()) { +      if (!CIVal->isPowerOf2()) { +        // CI2 << X can equal zero in some circumstances, +        // this simplification is unsafe if CI is zero. +        // +        // We know it is safe if: +        // - The shift is nsw, we can't shift out the one bit. +        // - The shift is nuw, we can't shift out the one bit. +        // - CI2 is one +        // - CI isn't zero +        if (LBO->hasNoSignedWrap() || LBO->hasNoUnsignedWrap() || +            CI2Val->isOneValue() || !CI->isZero()) { +          if (Pred == ICmpInst::ICMP_EQ) +            return ConstantInt::getFalse(RHS->getContext()); +          if (Pred == ICmpInst::ICMP_NE) +            return ConstantInt::getTrue(RHS->getContext()); +        } +      } +      if (CIVal->isSignMask() && CI2Val->isOneValue()) { +        if (Pred == ICmpInst::ICMP_UGT) +          return ConstantInt::getFalse(RHS->getContext()); +        if (Pred == ICmpInst::ICMP_ULE) +          return ConstantInt::getTrue(RHS->getContext()); +      } +    } +  } + +  if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && +      LBO->getOperand(1) == RBO->getOperand(1)) { +    switch (LBO->getOpcode()) { +    default: +      break; +    case Instruction::UDiv: +    case Instruction::LShr: +      if (ICmpInst::isSigned(Pred) || !LBO->isExact() || !RBO->isExact()) +        break; +      if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), +                                      RBO->getOperand(0), Q, MaxRecurse - 1)) +          return V; +      break; +    case Instruction::SDiv: +      if (!ICmpInst::isEquality(Pred) || !LBO->isExact() || !RBO->isExact()) +        break; +      if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), +                                      RBO->getOperand(0), Q, MaxRecurse - 1)) +        return V; +      break; +    case Instruction::AShr: +      if (!LBO->isExact() || !RBO->isExact()) +        break; +      if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), +                                      RBO->getOperand(0), Q, MaxRecurse - 1)) +        return V; +      break; +    case Instruction::Shl: { +      bool NUW = LBO->hasNoUnsignedWrap() && RBO->hasNoUnsignedWrap(); +      bool NSW = LBO->hasNoSignedWrap() && RBO->hasNoSignedWrap(); +      if (!NUW && !NSW) +        break; +      if (!NSW && ICmpInst::isSigned(Pred)) +        break; +      if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), +                                      RBO->getOperand(0), Q, MaxRecurse - 1)) +        return V; +      break; +    } +    } +  } +  return nullptr; +} + +/// Simplify 
integer comparisons where at least one operand of the compare +/// matches an integer min/max idiom. +static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, +                                     Value *RHS, const SimplifyQuery &Q, +                                     unsigned MaxRecurse) { +  Type *ITy = GetCompareTy(LHS); // The return type. +  Value *A, *B; +  CmpInst::Predicate P = CmpInst::BAD_ICMP_PREDICATE; +  CmpInst::Predicate EqP; // Chosen so that "A == max/min(A,B)" iff "A EqP B". + +  // Signed variants on "max(a,b)>=a -> true". +  if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { +    if (A != RHS) +      std::swap(A, B);       // smax(A, B) pred A. +    EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". +    // We analyze this as smax(A, B) pred A. +    P = Pred; +  } else if (match(RHS, m_SMax(m_Value(A), m_Value(B))) && +             (A == LHS || B == LHS)) { +    if (A != LHS) +      std::swap(A, B);       // A pred smax(A, B). +    EqP = CmpInst::ICMP_SGE; // "A == smax(A, B)" iff "A sge B". +    // We analyze this as smax(A, B) swapped-pred A. +    P = CmpInst::getSwappedPredicate(Pred); +  } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && +             (A == RHS || B == RHS)) { +    if (A != RHS) +      std::swap(A, B);       // smin(A, B) pred A. +    EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". +    // We analyze this as smax(-A, -B) swapped-pred -A. +    // Note that we do not need to actually form -A or -B thanks to EqP. +    P = CmpInst::getSwappedPredicate(Pred); +  } else if (match(RHS, m_SMin(m_Value(A), m_Value(B))) && +             (A == LHS || B == LHS)) { +    if (A != LHS) +      std::swap(A, B);       // A pred smin(A, B). +    EqP = CmpInst::ICMP_SLE; // "A == smin(A, B)" iff "A sle B". +    // We analyze this as smax(-A, -B) pred -A. +    // Note that we do not need to actually form -A or -B thanks to EqP. +    P = Pred; +  } +  if (P != CmpInst::BAD_ICMP_PREDICATE) { +    // Cases correspond to "max(A, B) p A". +    switch (P) { +    default: +      break; +    case CmpInst::ICMP_EQ: +    case CmpInst::ICMP_SLE: +      // Equivalent to "A EqP B".  This may be the same as the condition tested +      // in the max/min; if so, we can just return that. +      if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) +        return V; +      if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) +        return V; +      // Otherwise, see if "A EqP B" simplifies. +      if (MaxRecurse) +        if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) +          return V; +      break; +    case CmpInst::ICMP_NE: +    case CmpInst::ICMP_SGT: { +      CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); +      // Equivalent to "A InvEqP B".  This may be the same as the condition +      // tested in the max/min; if so, we can just return that. +      if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) +        return V; +      if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) +        return V; +      // Otherwise, see if "A InvEqP B" simplifies. +      if (MaxRecurse) +        if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) +          return V; +      break; +    } +    case CmpInst::ICMP_SGE: +      // Always true. +      return getTrue(ITy); +    case CmpInst::ICMP_SLT: +      // Always false. +      return getFalse(ITy); +    } +  } + +  // Unsigned variants on "max(a,b)>=a -> true". 
+  P = CmpInst::BAD_ICMP_PREDICATE; +  if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && (A == RHS || B == RHS)) { +    if (A != RHS) +      std::swap(A, B);       // umax(A, B) pred A. +    EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". +    // We analyze this as umax(A, B) pred A. +    P = Pred; +  } else if (match(RHS, m_UMax(m_Value(A), m_Value(B))) && +             (A == LHS || B == LHS)) { +    if (A != LHS) +      std::swap(A, B);       // A pred umax(A, B). +    EqP = CmpInst::ICMP_UGE; // "A == umax(A, B)" iff "A uge B". +    // We analyze this as umax(A, B) swapped-pred A. +    P = CmpInst::getSwappedPredicate(Pred); +  } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && +             (A == RHS || B == RHS)) { +    if (A != RHS) +      std::swap(A, B);       // umin(A, B) pred A. +    EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". +    // We analyze this as umax(-A, -B) swapped-pred -A. +    // Note that we do not need to actually form -A or -B thanks to EqP. +    P = CmpInst::getSwappedPredicate(Pred); +  } else if (match(RHS, m_UMin(m_Value(A), m_Value(B))) && +             (A == LHS || B == LHS)) { +    if (A != LHS) +      std::swap(A, B);       // A pred umin(A, B). +    EqP = CmpInst::ICMP_ULE; // "A == umin(A, B)" iff "A ule B". +    // We analyze this as umax(-A, -B) pred -A. +    // Note that we do not need to actually form -A or -B thanks to EqP. +    P = Pred; +  } +  if (P != CmpInst::BAD_ICMP_PREDICATE) { +    // Cases correspond to "max(A, B) p A". +    switch (P) { +    default: +      break; +    case CmpInst::ICMP_EQ: +    case CmpInst::ICMP_ULE: +      // Equivalent to "A EqP B".  This may be the same as the condition tested +      // in the max/min; if so, we can just return that. +      if (Value *V = ExtractEquivalentCondition(LHS, EqP, A, B)) +        return V; +      if (Value *V = ExtractEquivalentCondition(RHS, EqP, A, B)) +        return V; +      // Otherwise, see if "A EqP B" simplifies. +      if (MaxRecurse) +        if (Value *V = SimplifyICmpInst(EqP, A, B, Q, MaxRecurse - 1)) +          return V; +      break; +    case CmpInst::ICMP_NE: +    case CmpInst::ICMP_UGT: { +      CmpInst::Predicate InvEqP = CmpInst::getInversePredicate(EqP); +      // Equivalent to "A InvEqP B".  This may be the same as the condition +      // tested in the max/min; if so, we can just return that. +      if (Value *V = ExtractEquivalentCondition(LHS, InvEqP, A, B)) +        return V; +      if (Value *V = ExtractEquivalentCondition(RHS, InvEqP, A, B)) +        return V; +      // Otherwise, see if "A InvEqP B" simplifies. +      if (MaxRecurse) +        if (Value *V = SimplifyICmpInst(InvEqP, A, B, Q, MaxRecurse - 1)) +          return V; +      break; +    } +    case CmpInst::ICMP_UGE: +      // Always true. +      return getTrue(ITy); +    case CmpInst::ICMP_ULT: +      // Always false. +      return getFalse(ITy); +    } +  } + +  // Variants on "max(x,y) >= min(x,z)". +  Value *C, *D; +  if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && +      match(RHS, m_SMin(m_Value(C), m_Value(D))) && +      (A == C || A == D || B == C || B == D)) { +    // max(x, ?) pred min(x, ?). +    if (Pred == CmpInst::ICMP_SGE) +      // Always true. +      return getTrue(ITy); +    if (Pred == CmpInst::ICMP_SLT) +      // Always false. 
+      return getFalse(ITy); +  } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && +             match(RHS, m_SMax(m_Value(C), m_Value(D))) && +             (A == C || A == D || B == C || B == D)) { +    // min(x, ?) pred max(x, ?). +    if (Pred == CmpInst::ICMP_SLE) +      // Always true. +      return getTrue(ITy); +    if (Pred == CmpInst::ICMP_SGT) +      // Always false. +      return getFalse(ITy); +  } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && +             match(RHS, m_UMin(m_Value(C), m_Value(D))) && +             (A == C || A == D || B == C || B == D)) { +    // max(x, ?) pred min(x, ?). +    if (Pred == CmpInst::ICMP_UGE) +      // Always true. +      return getTrue(ITy); +    if (Pred == CmpInst::ICMP_ULT) +      // Always false. +      return getFalse(ITy); +  } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && +             match(RHS, m_UMax(m_Value(C), m_Value(D))) && +             (A == C || A == D || B == C || B == D)) { +    // min(x, ?) pred max(x, ?). +    if (Pred == CmpInst::ICMP_ULE) +      // Always true. +      return getTrue(ITy); +    if (Pred == CmpInst::ICMP_UGT) +      // Always false. +      return getFalse(ITy); +  } + +  return nullptr; +} + +/// Given operands for an ICmpInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                               const SimplifyQuery &Q, unsigned MaxRecurse) { +  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; +  assert(CmpInst::isIntPredicate(Pred) && "Not an integer compare!"); + +  if (Constant *CLHS = dyn_cast<Constant>(LHS)) { +    if (Constant *CRHS = dyn_cast<Constant>(RHS)) +      return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); + +    // If we have a constant, make sure it is on the RHS. +    std::swap(LHS, RHS); +    Pred = CmpInst::getSwappedPredicate(Pred); +  } + +  Type *ITy = GetCompareTy(LHS); // The return type. + +  // icmp X, X -> true/false +  // icmp X, undef -> true/false because undef could be X. +  if (LHS == RHS || isa<UndefValue>(RHS)) +    return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); + +  if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) +    return V; + +  if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q)) +    return V; + +  if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS)) +    return V; + +  // If both operands have range metadata, use the metadata +  // to simplify the comparison. 
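+  //
+  // For instance (hypothetical IR and metadata), with provably disjoint ranges:
+  //   %a = load i32, i32* %p, !range !0   ; !0 = !{i32 0, i32 10}
+  //   %b = load i32, i32* %q, !range !1   ; !1 = !{i32 10, i32 20}
+  //   %cmp = icmp ult i32 %a, %b          ; folds to true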
+  if (isa<Instruction>(RHS) && isa<Instruction>(LHS)) { +    auto RHS_Instr = cast<Instruction>(RHS); +    auto LHS_Instr = cast<Instruction>(LHS); + +    if (RHS_Instr->getMetadata(LLVMContext::MD_range) && +        LHS_Instr->getMetadata(LLVMContext::MD_range)) { +      auto RHS_CR = getConstantRangeFromMetadata( +          *RHS_Instr->getMetadata(LLVMContext::MD_range)); +      auto LHS_CR = getConstantRangeFromMetadata( +          *LHS_Instr->getMetadata(LLVMContext::MD_range)); + +      auto Satisfied_CR = ConstantRange::makeSatisfyingICmpRegion(Pred, RHS_CR); +      if (Satisfied_CR.contains(LHS_CR)) +        return ConstantInt::getTrue(RHS->getContext()); + +      auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( +                CmpInst::getInversePredicate(Pred), RHS_CR); +      if (InversedSatisfied_CR.contains(LHS_CR)) +        return ConstantInt::getFalse(RHS->getContext()); +    } +  } + +  // Compare of cast, for example (zext X) != 0 -> X != 0 +  if (isa<CastInst>(LHS) && (isa<Constant>(RHS) || isa<CastInst>(RHS))) { +    Instruction *LI = cast<CastInst>(LHS); +    Value *SrcOp = LI->getOperand(0); +    Type *SrcTy = SrcOp->getType(); +    Type *DstTy = LI->getType(); + +    // Turn icmp (ptrtoint x), (ptrtoint/constant) into a compare of the input +    // if the integer type is the same size as the pointer type. +    if (MaxRecurse && isa<PtrToIntInst>(LI) && +        Q.DL.getTypeSizeInBits(SrcTy) == DstTy->getPrimitiveSizeInBits()) { +      if (Constant *RHSC = dyn_cast<Constant>(RHS)) { +        // Transfer the cast to the constant. +        if (Value *V = SimplifyICmpInst(Pred, SrcOp, +                                        ConstantExpr::getIntToPtr(RHSC, SrcTy), +                                        Q, MaxRecurse-1)) +          return V; +      } else if (PtrToIntInst *RI = dyn_cast<PtrToIntInst>(RHS)) { +        if (RI->getOperand(0)->getType() == SrcTy) +          // Compare without the cast. +          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), +                                          Q, MaxRecurse-1)) +            return V; +      } +    } + +    if (isa<ZExtInst>(LHS)) { +      // Turn icmp (zext X), (zext Y) into a compare of X and Y if they have the +      // same type. +      if (ZExtInst *RI = dyn_cast<ZExtInst>(RHS)) { +        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) +          // Compare X and Y.  Note that signed predicates become unsigned. +          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), +                                          SrcOp, RI->getOperand(0), Q, +                                          MaxRecurse-1)) +            return V; +      } +      // Turn icmp (zext X), Cst into a compare of X and Cst if Cst is extended +      // too.  If not, then try to deduce the result of the comparison. +      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { +        // Compute the constant that would happen if we truncated to SrcTy then +        // reextended to DstTy. +        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy); +        Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); + +        // If the re-extended constant didn't change then this is effectively +        // also a case of comparing two zero-extended values. 
+        if (RExt == CI && MaxRecurse)
+          if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred),
+                                        SrcOp, Trunc, Q, MaxRecurse-1))
+            return V;
+
+        // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit
+        // there.  Use this to work out the result of the comparison.
+        if (RExt != CI) {
+          switch (Pred) {
+          default: llvm_unreachable("Unknown ICmp predicate!");
+          // LHS <u RHS.
+          case ICmpInst::ICMP_EQ:
+          case ICmpInst::ICMP_UGT:
+          case ICmpInst::ICMP_UGE:
+            return ConstantInt::getFalse(CI->getContext());
+
+          case ICmpInst::ICMP_NE:
+          case ICmpInst::ICMP_ULT:
+          case ICmpInst::ICMP_ULE:
+            return ConstantInt::getTrue(CI->getContext());
+
+          // LHS is non-negative.  If RHS is negative then LHS >s RHS.  If RHS
+          // is non-negative then LHS <s RHS.
+          case ICmpInst::ICMP_SGT:
+          case ICmpInst::ICMP_SGE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getTrue(CI->getContext()) :
+              ConstantInt::getFalse(CI->getContext());
+
+          case ICmpInst::ICMP_SLT:
+          case ICmpInst::ICMP_SLE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getFalse(CI->getContext()) :
+              ConstantInt::getTrue(CI->getContext());
+          }
+        }
+      }
+    }
+
+    if (isa<SExtInst>(LHS)) {
+      // Turn icmp (sext X), (sext Y) into a compare of X and Y if they have the
+      // same type.
+      if (SExtInst *RI = dyn_cast<SExtInst>(RHS)) {
+        if (MaxRecurse && SrcTy == RI->getOperand(0)->getType())
+          // Compare X and Y.  Note that the predicate does not change.
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0),
+                                          Q, MaxRecurse-1))
+            return V;
+      }
+      // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
+      // too.  If not, then try to deduce the result of the comparison.
+      else if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
+        // Compute the constant that would happen if we truncated to SrcTy then
+        // reextended to DstTy.
+        Constant *Trunc = ConstantExpr::getTrunc(CI, SrcTy);
+        Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
+
+        // If the re-extended constant didn't change then this is effectively
+        // also a case of comparing two sign-extended values.
+        if (RExt == CI && MaxRecurse)
+          if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1))
+            return V;
+
+        // Otherwise the upper bits of LHS are all equal, while RHS has varying
+        // bits there.  Use this to work out the result of the comparison.
+        if (RExt != CI) {
+          switch (Pred) {
+          default: llvm_unreachable("Unknown ICmp predicate!");
+          case ICmpInst::ICMP_EQ:
+            return ConstantInt::getFalse(CI->getContext());
+          case ICmpInst::ICMP_NE:
+            return ConstantInt::getTrue(CI->getContext());
+
+          // If RHS is non-negative then LHS <s RHS.  If RHS is negative then
+          // LHS >s RHS.
+          case ICmpInst::ICMP_SGT:
+          case ICmpInst::ICMP_SGE:
+            return CI->getValue().isNegative() ?
+              ConstantInt::getTrue(CI->getContext()) : +              ConstantInt::getFalse(CI->getContext()); +          case ICmpInst::ICMP_SLT: +          case ICmpInst::ICMP_SLE: +            return CI->getValue().isNegative() ? +              ConstantInt::getFalse(CI->getContext()) : +              ConstantInt::getTrue(CI->getContext()); + +          // If LHS is non-negative then LHS <u RHS.  If LHS is negative then +          // LHS >u RHS. +          case ICmpInst::ICMP_UGT: +          case ICmpInst::ICMP_UGE: +            // Comparison is true iff the LHS <s 0. +            if (MaxRecurse) +              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SLT, SrcOp, +                                              Constant::getNullValue(SrcTy), +                                              Q, MaxRecurse-1)) +                return V; +            break; +          case ICmpInst::ICMP_ULT: +          case ICmpInst::ICMP_ULE: +            // Comparison is true iff the LHS >=s 0. +            if (MaxRecurse) +              if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, +                                              Constant::getNullValue(SrcTy), +                                              Q, MaxRecurse-1)) +                return V; +            break; +          } +        } +      } +    } +  } + +  // icmp eq|ne X, Y -> false|true if X != Y +  if (ICmpInst::isEquality(Pred) && +      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { +    return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy); +  } + +  if (Value *V = simplifyICmpWithBinOp(Pred, LHS, RHS, Q, MaxRecurse)) +    return V; + +  if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse)) +    return V; + +  // Simplify comparisons of related pointers using a powerful, recursive +  // GEP-walk when we have target data available.. +  if (LHS->getType()->isPointerTy()) +    if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, LHS, +                                     RHS)) +      return C; +  if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS)) +    if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS)) +      if (Q.DL.getTypeSizeInBits(CLHS->getPointerOperandType()) == +              Q.DL.getTypeSizeInBits(CLHS->getType()) && +          Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) == +              Q.DL.getTypeSizeInBits(CRHS->getType())) +        if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, +                                         CLHS->getPointerOperand(), +                                         CRHS->getPointerOperand())) +          return C; + +  if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) { +    if (GEPOperator *GRHS = dyn_cast<GEPOperator>(RHS)) { +      if (GLHS->getPointerOperand() == GRHS->getPointerOperand() && +          GLHS->hasAllConstantIndices() && GRHS->hasAllConstantIndices() && +          (ICmpInst::isEquality(Pred) || +           (GLHS->isInBounds() && GRHS->isInBounds() && +            Pred == ICmpInst::getSignedPredicate(Pred)))) { +        // The bases are equal and the indices are constant.  Build a constant +        // expression GEP with the same indices and a null base pointer to see +        // what constant folding can make out of it. 
+        Constant *Null = Constant::getNullValue(GLHS->getPointerOperandType()); +        SmallVector<Value *, 4> IndicesLHS(GLHS->idx_begin(), GLHS->idx_end()); +        Constant *NewLHS = ConstantExpr::getGetElementPtr( +            GLHS->getSourceElementType(), Null, IndicesLHS); + +        SmallVector<Value *, 4> IndicesRHS(GRHS->idx_begin(), GRHS->idx_end()); +        Constant *NewRHS = ConstantExpr::getGetElementPtr( +            GLHS->getSourceElementType(), Null, IndicesRHS); +        return ConstantExpr::getICmp(Pred, NewLHS, NewRHS); +      } +    } +  } + +  // If the comparison is with the result of a select instruction, check whether +  // comparing with either branch of the select always yields the same value. +  if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) +    if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) +      return V; + +  // If the comparison is with the result of a phi instruction, check whether +  // doing the compare with each incoming phi value yields a common result. +  if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) +    if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) +      return V; + +  return nullptr; +} + +Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                              const SimplifyQuery &Q) { +  return ::SimplifyICmpInst(Predicate, LHS, RHS, Q, RecursionLimit); +} + +/// Given operands for an FCmpInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                               FastMathFlags FMF, const SimplifyQuery &Q, +                               unsigned MaxRecurse) { +  CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate; +  assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!"); + +  if (Constant *CLHS = dyn_cast<Constant>(LHS)) { +    if (Constant *CRHS = dyn_cast<Constant>(RHS)) +      return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI); + +    // If we have a constant, make sure it is on the RHS. +    std::swap(LHS, RHS); +    Pred = CmpInst::getSwappedPredicate(Pred); +  } + +  // Fold trivial predicates. +  Type *RetTy = GetCompareTy(LHS); +  if (Pred == FCmpInst::FCMP_FALSE) +    return getFalse(RetTy); +  if (Pred == FCmpInst::FCMP_TRUE) +    return getTrue(RetTy); + +  // UNO/ORD predicates can be trivially folded if NaNs are ignored. +  if (FMF.noNaNs()) { +    if (Pred == FCmpInst::FCMP_UNO) +      return getFalse(RetTy); +    if (Pred == FCmpInst::FCMP_ORD) +      return getTrue(RetTy); +  } + +  // NaN is unordered; NaN is not ordered. +  assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) && +         "Comparison must be either ordered or unordered"); +  if (match(RHS, m_NaN())) +    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); + +  // fcmp pred x, undef  and  fcmp pred undef, x +  // fold to true if unordered, false if ordered +  if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) { +    // Choosing NaN for the undef will always make unordered comparison succeed +    // and ordered comparison fail. +    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); +  } + +  // fcmp x,x -> true/false.  Not all compares are foldable. +  if (LHS == RHS) { +    if (CmpInst::isTrueWhenEqual(Pred)) +      return getTrue(RetTy); +    if (CmpInst::isFalseWhenEqual(Pred)) +      return getFalse(RetTy); +  } + +  // Handle fcmp with constant RHS. 
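+  //
+  // Illustrative cases handled below (the operand name is hypothetical):
+  //   fcmp olt double %x, 0xFFF0000000000000  ; x <o -inf, folds to false
+  //   fcmp uge double %x, 0xFFF0000000000000  ; x >=u -inf, folds to true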
+  const APFloat *C; +  if (match(RHS, m_APFloat(C))) { +    // Check whether the constant is an infinity. +    if (C->isInfinity()) { +      if (C->isNegative()) { +        switch (Pred) { +        case FCmpInst::FCMP_OLT: +          // No value is ordered and less than negative infinity. +          return getFalse(RetTy); +        case FCmpInst::FCMP_UGE: +          // All values are unordered with or at least negative infinity. +          return getTrue(RetTy); +        default: +          break; +        } +      } else { +        switch (Pred) { +        case FCmpInst::FCMP_OGT: +          // No value is ordered and greater than infinity. +          return getFalse(RetTy); +        case FCmpInst::FCMP_ULE: +          // All values are unordered with and at most infinity. +          return getTrue(RetTy); +        default: +          break; +        } +      } +    } +    if (C->isZero()) { +      switch (Pred) { +      case FCmpInst::FCMP_UGE: +        if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) +          return getTrue(RetTy); +        break; +      case FCmpInst::FCMP_OLT: +        // X < 0 +        if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) +          return getFalse(RetTy); +        break; +      default: +        break; +      } +    } else if (C->isNegative()) { +      assert(!C->isNaN() && "Unexpected NaN constant!"); +      // TODO: We can catch more cases by using a range check rather than +      //       relying on CannotBeOrderedLessThanZero. +      switch (Pred) { +      case FCmpInst::FCMP_UGE: +      case FCmpInst::FCMP_UGT: +      case FCmpInst::FCMP_UNE: +        // (X >= 0) implies (X > C) when (C < 0) +        if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) +          return getTrue(RetTy); +        break; +      case FCmpInst::FCMP_OEQ: +      case FCmpInst::FCMP_OLE: +      case FCmpInst::FCMP_OLT: +        // (X >= 0) implies !(X < C) when (C < 0) +        if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) +          return getFalse(RetTy); +        break; +      default: +        break; +      } +    } +  } + +  // If the comparison is with the result of a select instruction, check whether +  // comparing with either branch of the select always yields the same value. +  if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) +    if (Value *V = ThreadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse)) +      return V; + +  // If the comparison is with the result of a phi instruction, check whether +  // doing the compare with each incoming phi value yields a common result. +  if (isa<PHINode>(LHS) || isa<PHINode>(RHS)) +    if (Value *V = ThreadCmpOverPHI(Pred, LHS, RHS, Q, MaxRecurse)) +      return V; + +  return nullptr; +} + +Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                              FastMathFlags FMF, const SimplifyQuery &Q) { +  return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); +} + +/// See if V simplifies when its operand Op is replaced with RepOp. +static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, +                                           const SimplifyQuery &Q, +                                           unsigned MaxRecurse) { +  // Trivial replacement. +  if (V == Op) +    return RepOp; + +  // We cannot replace a constant, and shouldn't even try. +  if (isa<Constant>(Op)) +    return nullptr; + +  auto *I = dyn_cast<Instruction>(V); +  if (!I) +    return nullptr; + +  // If this is a binary operator, try to simplify it with the replaced op. 
+  if (auto *B = dyn_cast<BinaryOperator>(I)) { +    // Consider: +    //   %cmp = icmp eq i32 %x, 2147483647 +    //   %add = add nsw i32 %x, 1 +    //   %sel = select i1 %cmp, i32 -2147483648, i32 %add +    // +    // We can't replace %sel with %add unless we strip away the flags. +    if (isa<OverflowingBinaryOperator>(B)) +      if (B->hasNoSignedWrap() || B->hasNoUnsignedWrap()) +        return nullptr; +    if (isa<PossiblyExactOperator>(B)) +      if (B->isExact()) +        return nullptr; + +    if (MaxRecurse) { +      if (B->getOperand(0) == Op) +        return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q, +                             MaxRecurse - 1); +      if (B->getOperand(1) == Op) +        return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, Q, +                             MaxRecurse - 1); +    } +  } + +  // Same for CmpInsts. +  if (CmpInst *C = dyn_cast<CmpInst>(I)) { +    if (MaxRecurse) { +      if (C->getOperand(0) == Op) +        return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), Q, +                               MaxRecurse - 1); +      if (C->getOperand(1) == Op) +        return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, Q, +                               MaxRecurse - 1); +    } +  } + +  // Same for GEPs. +  if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { +    if (MaxRecurse) { +      SmallVector<Value *, 8> NewOps(GEP->getNumOperands()); +      transform(GEP->operands(), NewOps.begin(), +                [&](Value *V) { return V == Op ? RepOp : V; }); +      return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q, +                             MaxRecurse - 1); +    } +  } + +  // TODO: We could hand off more cases to instsimplify here. + +  // If all operands are constant after substituting Op for RepOp then we can +  // constant fold the instruction. +  if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) { +    // Build a list of all constant operands. +    SmallVector<Constant *, 8> ConstOps; +    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { +      if (I->getOperand(i) == Op) +        ConstOps.push_back(CRepOp); +      else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i))) +        ConstOps.push_back(COp); +      else +        break; +    } + +    // All operands were constants, fold it. +    if (ConstOps.size() == I->getNumOperands()) { +      if (CmpInst *C = dyn_cast<CmpInst>(I)) +        return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], +                                               ConstOps[1], Q.DL, Q.TLI); + +      if (LoadInst *LI = dyn_cast<LoadInst>(I)) +        if (!LI->isVolatile()) +          return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); + +      return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); +    } +  } + +  return nullptr; +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison where one operand of the compare is a constant. +static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, +                                    const APInt *Y, bool TrueWhenUnset) { +  const APInt *C; + +  // (X & Y) == 0 ? X & ~Y : X  --> X +  // (X & Y) != 0 ? X & ~Y : X  --> X & ~Y +  if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && +      *Y == ~*C) +    return TrueWhenUnset ? FalseVal : TrueVal; + +  // (X & Y) == 0 ? X : X & ~Y  --> X & ~Y +  // (X & Y) != 0 ? 
X : X & ~Y  --> X +  if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && +      *Y == ~*C) +    return TrueWhenUnset ? FalseVal : TrueVal; + +  if (Y->isPowerOf2()) { +    // (X & Y) == 0 ? X | Y : X  --> X | Y +    // (X & Y) != 0 ? X | Y : X  --> X +    if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && +        *Y == *C) +      return TrueWhenUnset ? TrueVal : FalseVal; + +    // (X & Y) == 0 ? X : X | Y  --> X +    // (X & Y) != 0 ? X : X | Y  --> X | Y +    if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && +        *Y == *C) +      return TrueWhenUnset ? TrueVal : FalseVal; +  } + +  return nullptr; +} + +/// An alternative way to test if a bit is set or not uses sgt/slt instead of +/// eq/ne. +static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, +                                           ICmpInst::Predicate Pred, +                                           Value *TrueVal, Value *FalseVal) { +  Value *X; +  APInt Mask; +  if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask)) +    return nullptr; + +  return simplifySelectBitTest(TrueVal, FalseVal, X, &Mask, +                               Pred == ICmpInst::ICMP_EQ); +} + +/// Try to simplify a select instruction when its condition operand is an +/// integer comparison. +static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, +                                         Value *FalseVal, const SimplifyQuery &Q, +                                         unsigned MaxRecurse) { +  ICmpInst::Predicate Pred; +  Value *CmpLHS, *CmpRHS; +  if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) +    return nullptr; + +  if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) { +    Value *X; +    const APInt *Y; +    if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y)))) +      if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, +                                           Pred == ICmpInst::ICMP_EQ)) +        return V; +  } + +  // Check for other compares that behave like bit test. +  if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, +                                              TrueVal, FalseVal)) +    return V; + +  // If we have an equality comparison, then we know the value in one of the +  // arms of the select. See if substituting this value into the arm and +  // simplifying the result yields the same value as the other arm. +  if (Pred == ICmpInst::ICMP_EQ) { +    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == +            TrueVal || +        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == +            TrueVal) +      return FalseVal; +    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == +            FalseVal || +        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == +            FalseVal) +      return FalseVal; +  } else if (Pred == ICmpInst::ICMP_NE) { +    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == +            FalseVal || +        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == +            FalseVal) +      return TrueVal; +    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == +            TrueVal || +        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == +            TrueVal) +      return TrueVal; +  } + +  return nullptr; +} + +/// Given operands for a SelectInst, see if we can fold the result. 
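+/// For example, "select i1 %c, i32 %x, i32 %x" folds to %x regardless of the
+/// condition, and a constant condition simply selects the corresponding arm.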
+/// If not, this returns null. +static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, +                                 const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (auto *CondC = dyn_cast<Constant>(Cond)) { +    if (auto *TrueC = dyn_cast<Constant>(TrueVal)) +      if (auto *FalseC = dyn_cast<Constant>(FalseVal)) +        return ConstantFoldSelectInstruction(CondC, TrueC, FalseC); + +    // select undef, X, Y -> X or Y +    if (isa<UndefValue>(CondC)) +      return isa<Constant>(FalseVal) ? FalseVal : TrueVal; + +    // TODO: Vector constants with undef elements don't simplify. + +    // select true, X, Y  -> X +    if (CondC->isAllOnesValue()) +      return TrueVal; +    // select false, X, Y -> Y +    if (CondC->isNullValue()) +      return FalseVal; +  } + +  // select ?, X, X -> X +  if (TrueVal == FalseVal) +    return TrueVal; + +  if (isa<UndefValue>(TrueVal))   // select ?, undef, X -> X +    return FalseVal; +  if (isa<UndefValue>(FalseVal))   // select ?, X, undef -> X +    return TrueVal; + +  if (Value *V = +          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse)) +    return V; + +  if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal)) +    return V; + +  return nullptr; +} + +Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, +                                const SimplifyQuery &Q) { +  return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Q, RecursionLimit); +} + +/// Given operands for an GetElementPtrInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, +                              const SimplifyQuery &Q, unsigned) { +  // The type of the GEP pointer operand. +  unsigned AS = +      cast<PointerType>(Ops[0]->getType()->getScalarType())->getAddressSpace(); + +  // getelementptr P -> P. +  if (Ops.size() == 1) +    return Ops[0]; + +  // Compute the (pointer) type returned by the GEP instruction. +  Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Ops.slice(1)); +  Type *GEPTy = PointerType::get(LastType, AS); +  if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) +    GEPTy = VectorType::get(GEPTy, VT->getNumElements()); +  else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType())) +    GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + +  if (isa<UndefValue>(Ops[0])) +    return UndefValue::get(GEPTy); + +  if (Ops.size() == 2) { +    // getelementptr P, 0 -> P. +    if (match(Ops[1], m_Zero()) && Ops[0]->getType() == GEPTy) +      return Ops[0]; + +    Type *Ty = SrcTy; +    if (Ty->isSized()) { +      Value *P; +      uint64_t C; +      uint64_t TyAllocSize = Q.DL.getTypeAllocSize(Ty); +      // getelementptr P, N -> P if P points to a type of zero size. +      if (TyAllocSize == 0 && Ops[0]->getType() == GEPTy) +        return Ops[0]; + +      // The following transforms are only safe if the ptrtoint cast +      // doesn't truncate the pointers. +      if (Ops[1]->getType()->getScalarSizeInBits() == +          Q.DL.getIndexSizeInBits(AS)) { +        auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * { +          if (match(P, m_Zero())) +            return Constant::getNullValue(GEPTy); +          Value *Temp; +          if (match(P, m_PtrToInt(m_Value(Temp)))) +            if (Temp->getType() == GEPTy) +              return Temp; +          return nullptr; +        }; + +        // getelementptr V, (sub P, V) -> P if P points to a type of size 1. 
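+        // Illustrative IR, assuming i8* operands and 64-bit pointers so the
+        // element size is 1 and the index width matches:
+        //   %vi = ptrtoint i8* %v to i64
+        //   %pi = ptrtoint i8* %p to i64
+        //   %d  = sub i64 %pi, %vi
+        //   %g  = getelementptr i8, i8* %v, i64 %d   ; simplifies to %p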
+        if (TyAllocSize == 1 && +            match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) +          if (Value *R = PtrToIntOrZero(P)) +            return R; + +        // getelementptr V, (ashr (sub P, V), C) -> Q +        // if P points to a type of size 1 << C. +        if (match(Ops[1], +                  m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), +                         m_ConstantInt(C))) && +            TyAllocSize == 1ULL << C) +          if (Value *R = PtrToIntOrZero(P)) +            return R; + +        // getelementptr V, (sdiv (sub P, V), C) -> Q +        // if P points to a type of size C. +        if (match(Ops[1], +                  m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), +                         m_SpecificInt(TyAllocSize)))) +          if (Value *R = PtrToIntOrZero(P)) +            return R; +      } +    } +  } + +  if (Q.DL.getTypeAllocSize(LastType) == 1 && +      all_of(Ops.slice(1).drop_back(1), +             [](Value *Idx) { return match(Idx, m_Zero()); })) { +    unsigned IdxWidth = +        Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); +    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) { +      APInt BasePtrOffset(IdxWidth, 0); +      Value *StrippedBasePtr = +          Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL, +                                                            BasePtrOffset); + +      // gep (gep V, C), (sub 0, V) -> C +      if (match(Ops.back(), +                m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) { +        auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset); +        return ConstantExpr::getIntToPtr(CI, GEPTy); +      } +      // gep (gep V, C), (xor V, -1) -> C-1 +      if (match(Ops.back(), +                m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) { +        auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1); +        return ConstantExpr::getIntToPtr(CI, GEPTy); +      } +    } +  } + +  // Check to see if this is constant foldable. +  if (!all_of(Ops, [](Value *V) { return isa<Constant>(V); })) +    return nullptr; + +  auto *CE = ConstantExpr::getGetElementPtr(SrcTy, cast<Constant>(Ops[0]), +                                            Ops.slice(1)); +  if (auto *CEFolded = ConstantFoldConstant(CE, Q.DL)) +    return CEFolded; +  return CE; +} + +Value *llvm::SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, +                             const SimplifyQuery &Q) { +  return ::SimplifyGEPInst(SrcTy, Ops, Q, RecursionLimit); +} + +/// Given operands for an InsertValueInst, see if we can fold the result. +/// If not, this returns null. 
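+/// For example, inserting undef leaves the aggregate unchanged, and
+/// reinserting a value just extracted from the same indices yields the
+/// original aggregate.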
+static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, +                                      ArrayRef<unsigned> Idxs, const SimplifyQuery &Q, +                                      unsigned) { +  if (Constant *CAgg = dyn_cast<Constant>(Agg)) +    if (Constant *CVal = dyn_cast<Constant>(Val)) +      return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs); + +  // insertvalue x, undef, n -> x +  if (match(Val, m_Undef())) +    return Agg; + +  // insertvalue x, (extractvalue y, n), n +  if (ExtractValueInst *EV = dyn_cast<ExtractValueInst>(Val)) +    if (EV->getAggregateOperand()->getType() == Agg->getType() && +        EV->getIndices() == Idxs) { +      // insertvalue undef, (extractvalue y, n), n -> y +      if (match(Agg, m_Undef())) +        return EV->getAggregateOperand(); + +      // insertvalue y, (extractvalue y, n), n -> y +      if (Agg == EV->getAggregateOperand()) +        return Agg; +    } + +  return nullptr; +} + +Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, +                                     ArrayRef<unsigned> Idxs, +                                     const SimplifyQuery &Q) { +  return ::SimplifyInsertValueInst(Agg, Val, Idxs, Q, RecursionLimit); +} + +Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, +                                       const SimplifyQuery &Q) { +  // Try to constant fold. +  auto *VecC = dyn_cast<Constant>(Vec); +  auto *ValC = dyn_cast<Constant>(Val); +  auto *IdxC = dyn_cast<Constant>(Idx); +  if (VecC && ValC && IdxC) +    return ConstantFoldInsertElementInstruction(VecC, ValC, IdxC); + +  // Fold into undef if index is out of bounds. +  if (auto *CI = dyn_cast<ConstantInt>(Idx)) { +    uint64_t NumElements = cast<VectorType>(Vec->getType())->getNumElements(); +    if (CI->uge(NumElements)) +      return UndefValue::get(Vec->getType()); +  } + +  // If index is undef, it might be out of bounds (see above case) +  if (isa<UndefValue>(Idx)) +    return UndefValue::get(Vec->getType()); + +  return nullptr; +} + +/// Given operands for an ExtractValueInst, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, +                                       const SimplifyQuery &, unsigned) { +  if (auto *CAgg = dyn_cast<Constant>(Agg)) +    return ConstantFoldExtractValueInstruction(CAgg, Idxs); + +  // extractvalue x, (insertvalue y, elt, n), n -> elt +  unsigned NumIdxs = Idxs.size(); +  for (auto *IVI = dyn_cast<InsertValueInst>(Agg); IVI != nullptr; +       IVI = dyn_cast<InsertValueInst>(IVI->getAggregateOperand())) { +    ArrayRef<unsigned> InsertValueIdxs = IVI->getIndices(); +    unsigned NumInsertValueIdxs = InsertValueIdxs.size(); +    unsigned NumCommonIdxs = std::min(NumInsertValueIdxs, NumIdxs); +    if (InsertValueIdxs.slice(0, NumCommonIdxs) == +        Idxs.slice(0, NumCommonIdxs)) { +      if (NumIdxs == NumInsertValueIdxs) +        return IVI->getInsertedValueOperand(); +      break; +    } +  } + +  return nullptr; +} + +Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, +                                      const SimplifyQuery &Q) { +  return ::SimplifyExtractValueInst(Agg, Idxs, Q, RecursionLimit); +} + +/// Given operands for an ExtractElementInst, see if we can fold the result. +/// If not, this returns null. 
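+/// Folds handled below include extracting from a constant or splat vector,
+/// recovering a scalar that was previously inserted at the same index, and
+/// treating a known out-of-bounds (or undef) index as undef.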
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &, +                                         unsigned) { +  if (auto *CVec = dyn_cast<Constant>(Vec)) { +    if (auto *CIdx = dyn_cast<Constant>(Idx)) +      return ConstantFoldExtractElementInstruction(CVec, CIdx); + +    // The index is not relevant if our vector is a splat. +    if (auto *Splat = CVec->getSplatValue()) +      return Splat; + +    if (isa<UndefValue>(Vec)) +      return UndefValue::get(Vec->getType()->getVectorElementType()); +  } + +  // If extracting a specified index from the vector, see if we can recursively +  // find a previously computed scalar that was inserted into the vector. +  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) { +    if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements())) +      // definitely out of bounds, thus undefined result +      return UndefValue::get(Vec->getType()->getVectorElementType()); +    if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue())) +      return Elt; +  } + +  // An undef extract index can be arbitrarily chosen to be an out-of-range +  // index value, which would result in the instruction being undef. +  if (isa<UndefValue>(Idx)) +    return UndefValue::get(Vec->getType()->getVectorElementType()); + +  return nullptr; +} + +Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, +                                        const SimplifyQuery &Q) { +  return ::SimplifyExtractElementInst(Vec, Idx, Q, RecursionLimit); +} + +/// See if we can fold the given phi. If not, returns null. +static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { +  // If all of the PHI's incoming values are the same then replace the PHI node +  // with the common value. +  Value *CommonValue = nullptr; +  bool HasUndefInput = false; +  for (Value *Incoming : PN->incoming_values()) { +    // If the incoming value is the phi node itself, it can safely be skipped. +    if (Incoming == PN) continue; +    if (isa<UndefValue>(Incoming)) { +      // Remember that we saw an undef value, but otherwise ignore them. +      HasUndefInput = true; +      continue; +    } +    if (CommonValue && Incoming != CommonValue) +      return nullptr;  // Not the same, bail out. +    CommonValue = Incoming; +  } + +  // If CommonValue is null then all of the incoming values were either undef or +  // equal to the phi node itself. +  if (!CommonValue) +    return UndefValue::get(PN->getType()); + +  // If we have a PHI node like phi(X, undef, X), where X is defined by some +  // instruction, we cannot return X as the result of the PHI node unless it +  // dominates the PHI block. +  if (HasUndefInput) +    return valueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr; + +  return CommonValue; +} + +static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, +                               Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (auto *C = dyn_cast<Constant>(Op)) +    return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); + +  if (auto *CI = dyn_cast<CastInst>(Op)) { +    auto *Src = CI->getOperand(0); +    Type *SrcTy = Src->getType(); +    Type *MidTy = CI->getType(); +    Type *DstTy = Ty; +    if (Src->getType() == Ty) { +      auto FirstOp = static_cast<Instruction::CastOps>(CI->getOpcode()); +      auto SecondOp = static_cast<Instruction::CastOps>(CastOpc); +      Type *SrcIntPtrTy = +          SrcTy->isPtrOrPtrVectorTy() ? 
Q.DL.getIntPtrType(SrcTy) : nullptr; +      Type *MidIntPtrTy = +          MidTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(MidTy) : nullptr; +      Type *DstIntPtrTy = +          DstTy->isPtrOrPtrVectorTy() ? Q.DL.getIntPtrType(DstTy) : nullptr; +      if (CastInst::isEliminableCastPair(FirstOp, SecondOp, SrcTy, MidTy, DstTy, +                                         SrcIntPtrTy, MidIntPtrTy, +                                         DstIntPtrTy) == Instruction::BitCast) +        return Src; +    } +  } + +  // bitcast x -> x +  if (CastOpc == Instruction::BitCast) +    if (Op->getType() == Ty) +      return Op; + +  return nullptr; +} + +Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, +                              const SimplifyQuery &Q) { +  return ::SimplifyCastInst(CastOpc, Op, Ty, Q, RecursionLimit); +} + +/// For the given destination element of a shuffle, peek through shuffles to +/// match a root vector source operand that contains that element in the same +/// vector lane (ie, the same mask index), so we can eliminate the shuffle(s). +static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, +                                   int MaskVal, Value *RootVec, +                                   unsigned MaxRecurse) { +  if (!MaxRecurse--) +    return nullptr; + +  // Bail out if any mask value is undefined. That kind of shuffle may be +  // simplified further based on demanded bits or other folds. +  if (MaskVal == -1) +    return nullptr; + +  // The mask value chooses which source operand we need to look at next. +  int InVecNumElts = Op0->getType()->getVectorNumElements(); +  int RootElt = MaskVal; +  Value *SourceOp = Op0; +  if (MaskVal >= InVecNumElts) { +    RootElt = MaskVal - InVecNumElts; +    SourceOp = Op1; +  } + +  // If the source operand is a shuffle itself, look through it to find the +  // matching root vector. +  if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) { +    return foldIdentityShuffles( +        DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1), +        SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse); +  } + +  // TODO: Look through bitcasts? What if the bitcast changes the vector element +  // size? + +  // The source operand is not a shuffle. Initialize the root vector value for +  // this shuffle if that has not been done yet. +  if (!RootVec) +    RootVec = SourceOp; + +  // Give up as soon as a source operand does not match the existing root value. +  if (RootVec != SourceOp) +    return nullptr; + +  // The element must be coming from the same lane in the source vector +  // (although it may have crossed lanes in intermediate shuffles). +  if (RootElt != DestElt) +    return nullptr; + +  return RootVec; +} + +static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, +                                        Type *RetTy, const SimplifyQuery &Q, +                                        unsigned MaxRecurse) { +  if (isa<UndefValue>(Mask)) +    return UndefValue::get(RetTy); + +  Type *InVecTy = Op0->getType(); +  unsigned MaskNumElts = Mask->getType()->getVectorNumElements(); +  unsigned InVecNumElts = InVecTy->getVectorNumElements(); + +  SmallVector<int, 32> Indices; +  ShuffleVectorInst::getShuffleMask(Mask, Indices); +  assert(MaskNumElts == Indices.size() && +         "Size of Indices not same as number of mask elements?"); + +  // Canonicalization: If mask does not select elements from an input vector, +  // replace that input vector with undef. 
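+  // Illustrative IR:
+  //   shufflevector <2 x i32> %a, <2 x i32> %b, <2 x i32> <i32 0, i32 1>
+  // selects no element of %b, so %b can be replaced with undef.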
+  bool MaskSelects0 = false, MaskSelects1 = false; +  for (unsigned i = 0; i != MaskNumElts; ++i) { +    if (Indices[i] == -1) +      continue; +    if ((unsigned)Indices[i] < InVecNumElts) +      MaskSelects0 = true; +    else +      MaskSelects1 = true; +  } +  if (!MaskSelects0) +    Op0 = UndefValue::get(InVecTy); +  if (!MaskSelects1) +    Op1 = UndefValue::get(InVecTy); + +  auto *Op0Const = dyn_cast<Constant>(Op0); +  auto *Op1Const = dyn_cast<Constant>(Op1); + +  // If all operands are constant, constant fold the shuffle. +  if (Op0Const && Op1Const) +    return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); + +  // Canonicalization: if only one input vector is constant, it shall be the +  // second one. +  if (Op0Const && !Op1Const) { +    std::swap(Op0, Op1); +    ShuffleVectorInst::commuteShuffleMask(Indices, InVecNumElts); +  } + +  // A shuffle of a splat is always the splat itself. Legal if the shuffle's +  // value type is same as the input vectors' type. +  if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0)) +    if (isa<UndefValue>(Op1) && RetTy == InVecTy && +        OpShuf->getMask()->getSplatValue()) +      return Op0; + +  // Don't fold a shuffle with undef mask elements. This may get folded in a +  // better way using demanded bits or other analysis. +  // TODO: Should we allow this? +  if (find(Indices, -1) != Indices.end()) +    return nullptr; + +  // Check if every element of this shuffle can be mapped back to the +  // corresponding element of a single root vector. If so, we don't need this +  // shuffle. This handles simple identity shuffles as well as chains of +  // shuffles that may widen/narrow and/or move elements across lanes and back. +  Value *RootVec = nullptr; +  for (unsigned i = 0; i != MaskNumElts; ++i) { +    // Note that recursion is limited for each vector element, so if any element +    // exceeds the limit, this will fail to simplify. +    RootVec = +        foldIdentityShuffles(i, Op0, Op1, Indices[i], RootVec, MaxRecurse); + +    // We can't replace a widening/narrowing shuffle with one of its operands. +    if (!RootVec || RootVec->getType() != RetTy) +      return nullptr; +  } +  return RootVec; +} + +/// Given operands for a ShuffleVectorInst, fold the result or return null. +Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, +                                       Type *RetTy, const SimplifyQuery &Q) { +  return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); +} + +static Constant *propagateNaN(Constant *In) { +  // If the input is a vector with undef elements, just return a default NaN. +  if (!In->isNaN()) +    return ConstantFP::getNaN(In->getType()); + +  // Propagate the existing NaN constant when possible. +  // TODO: Should we quiet a signaling NaN? +  return In; +} + +static Constant *simplifyFPBinop(Value *Op0, Value *Op1) { +  if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) +    return ConstantFP::getNaN(Op0->getType()); + +  if (match(Op0, m_NaN())) +    return propagateNaN(cast<Constant>(Op0)); +  if (match(Op1, m_NaN())) +    return propagateNaN(cast<Constant>(Op1)); + +  return nullptr; +} + +/// Given operands for an FAdd, see if we can fold the result.  If not, this +/// returns null. 
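+/// For example, "fadd %x, -0.0" folds to %x unconditionally, while
+/// "fadd %x, 0.0" additionally needs nsz or a proof that %x is not -0.0.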
+static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                               const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) +    return C; + +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; + +  // fadd X, -0 ==> X +  if (match(Op1, m_NegZeroFP())) +    return Op0; + +  // fadd X, 0 ==> X, when we know X is not -0 +  if (match(Op1, m_PosZeroFP()) && +      (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) +    return Op0; + +  // With nnan: (+/-0.0 - X) + X --> 0.0 (and commuted variant) +  // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN. +  // Negative zeros are allowed because we always end up with positive zero: +  // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 +  // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 +  // X =  0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 +  // X =  0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 +  if (FMF.noNaNs() && (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || +                       match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0))))) +    return ConstantFP::getNullValue(Op0->getType()); + +  return nullptr; +} + +/// Given operands for an FSub, see if we can fold the result.  If not, this +/// returns null. +static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                               const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) +    return C; + +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; + +  // fsub X, +0 ==> X +  if (match(Op1, m_PosZeroFP())) +    return Op0; + +  // fsub X, -0 ==> X, when we know X is not -0 +  if (match(Op1, m_NegZeroFP()) && +      (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) +    return Op0; + +  // fsub -0.0, (fsub -0.0, X) ==> X +  Value *X; +  if (match(Op0, m_NegZeroFP()) && +      match(Op1, m_FSub(m_NegZeroFP(), m_Value(X)))) +    return X; + +  // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. +  if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && +      match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X)))) +    return X; + +  // fsub nnan x, x ==> 0.0 +  if (FMF.noNaNs() && Op0 == Op1) +    return Constant::getNullValue(Op0->getType()); + +  return nullptr; +} + +/// Given the operands for an FMul, see if we can fold the result +static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                               const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q)) +    return C; + +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; + +  // fmul X, 1.0 ==> X +  if (match(Op1, m_FPOne())) +    return Op0; + +  // fmul nnan nsz X, 0 ==> 0 +  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP())) +    return ConstantFP::getNullValue(Op0->getType()); + +  // sqrt(X) * sqrt(X) --> X, if we can: +  // 1. Remove the intermediate rounding (reassociate). +  // 2. Ignore non-zero negative numbers because sqrt would produce NAN. +  // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0. 
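+  // e.g. with "reassoc nnan nsz" on the multiply:
+  //   %s = call double @llvm.sqrt.f64(double %x)
+  //   %m = fmul reassoc nnan nsz double %s, %s   ; simplifies to %x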
+  Value *X; +  if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && +      FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros()) +    return X; + +  return nullptr; +} + +Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                              const SimplifyQuery &Q) { +  return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); +} + + +Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                              const SimplifyQuery &Q) { +  return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                              const SimplifyQuery &Q) { +  return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                               const SimplifyQuery &Q, unsigned) { +  if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) +    return C; + +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; + +  // X / 1.0 -> X +  if (match(Op1, m_FPOne())) +    return Op0; + +  // 0 / X -> 0 +  // Requires that NaNs are off (X could be zero) and signed zeroes are +  // ignored (X could be positive or negative, so the output sign is unknown). +  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP())) +    return ConstantFP::getNullValue(Op0->getType()); + +  if (FMF.noNaNs()) { +    // X / X -> 1.0 is legal when NaNs are ignored. +    // We can ignore infinities because INF/INF is NaN. +    if (Op0 == Op1) +      return ConstantFP::get(Op0->getType(), 1.0); + +    // (X * Y) / Y --> X if we can reassociate to the above form. +    Value *X; +    if (FMF.allowReassoc() && match(Op0, m_c_FMul(m_Value(X), m_Specific(Op1)))) +      return X; + +    // -X /  X -> -1.0 and +    //  X / -X -> -1.0 are legal when NaNs are ignored. +    // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored. +    if ((BinaryOperator::isFNeg(Op0, /*IgnoreZeroSign=*/true) && +         BinaryOperator::getFNegArgument(Op0) == Op1) || +        (BinaryOperator::isFNeg(Op1, /*IgnoreZeroSign=*/true) && +         BinaryOperator::getFNegArgument(Op1) == Op0)) +      return ConstantFP::get(Op0->getType(), -1.0); +  } + +  return nullptr; +} + +Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                              const SimplifyQuery &Q) { +  return ::SimplifyFDivInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                               const SimplifyQuery &Q, unsigned) { +  if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) +    return C; + +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; + +  // Unlike fdiv, the result of frem always matches the sign of the dividend. +  // The constant match may include undef elements in a vector, so return a full +  // zero constant as the result. 
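+  // Illustrative folds (guarded by nnan below, because the result would be
+  // NaN rather than zero when %x is NaN or zero):
+  //   frem  0.0, %x -->  0.0
+  //   frem -0.0, %x --> -0.0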
+  if (FMF.noNaNs()) { +    // +0 % X -> 0 +    if (match(Op0, m_PosZeroFP())) +      return ConstantFP::getNullValue(Op0->getType()); +    // -0 % X -> -0 +    if (match(Op0, m_NegZeroFP())) +      return ConstantFP::getNegativeZero(Op0->getType()); +  } + +  return nullptr; +} + +Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, +                              const SimplifyQuery &Q) { +  return ::SimplifyFRemInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +//=== Helper functions for higher up the class hierarchy. + +/// Given operands for a BinaryOperator, see if we can fold the result. +/// If not, this returns null. +static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, +                            const SimplifyQuery &Q, unsigned MaxRecurse) { +  switch (Opcode) { +  case Instruction::Add: +    return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); +  case Instruction::Sub: +    return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); +  case Instruction::Mul: +    return SimplifyMulInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::SDiv: +    return SimplifySDivInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::UDiv: +    return SimplifyUDivInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::SRem: +    return SimplifySRemInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::URem: +    return SimplifyURemInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::Shl: +    return SimplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse); +  case Instruction::LShr: +    return SimplifyLShrInst(LHS, RHS, false, Q, MaxRecurse); +  case Instruction::AShr: +    return SimplifyAShrInst(LHS, RHS, false, Q, MaxRecurse); +  case Instruction::And: +    return SimplifyAndInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::Or: +    return SimplifyOrInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::Xor: +    return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); +  case Instruction::FAdd: +    return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); +  case Instruction::FSub: +    return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); +  case Instruction::FMul: +    return SimplifyFMulInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); +  case Instruction::FDiv: +    return SimplifyFDivInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); +  case Instruction::FRem: +    return SimplifyFRemInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); +  default: +    llvm_unreachable("Unexpected opcode"); +  } +} + +/// Given operands for a BinaryOperator, see if we can fold the result. +/// If not, this returns null. +/// In contrast to SimplifyBinOp, try to use FastMathFlag when folding the +/// result. In case we don't need FastMathFlags, simply fall to SimplifyBinOp. 
+static Value *SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, +                              const FastMathFlags &FMF, const SimplifyQuery &Q, +                              unsigned MaxRecurse) { +  switch (Opcode) { +  case Instruction::FAdd: +    return SimplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); +  case Instruction::FSub: +    return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); +  case Instruction::FMul: +    return SimplifyFMulInst(LHS, RHS, FMF, Q, MaxRecurse); +  case Instruction::FDiv: +    return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); +  default: +    return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); +  } +} + +Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, +                           const SimplifyQuery &Q) { +  return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); +} + +Value *llvm::SimplifyFPBinOp(unsigned Opcode, Value *LHS, Value *RHS, +                             FastMathFlags FMF, const SimplifyQuery &Q) { +  return ::SimplifyFPBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); +} + +/// Given operands for a CmpInst, see if we can fold the result. +static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                              const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate)) +    return SimplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse); +  return SimplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse); +} + +Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, +                             const SimplifyQuery &Q) { +  return ::SimplifyCmpInst(Predicate, LHS, RHS, Q, RecursionLimit); +} + +static bool IsIdempotent(Intrinsic::ID ID) { +  switch (ID) { +  default: return false; + +  // Unary idempotent: f(f(x)) = f(x) +  case Intrinsic::fabs: +  case Intrinsic::floor: +  case Intrinsic::ceil: +  case Intrinsic::trunc: +  case Intrinsic::rint: +  case Intrinsic::nearbyint: +  case Intrinsic::round: +  case Intrinsic::canonicalize: +    return true; +  } +} + +static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset, +                                   const DataLayout &DL) { +  GlobalValue *PtrSym; +  APInt PtrOffset; +  if (!IsConstantOffsetFromGlobal(Ptr, PtrSym, PtrOffset, DL)) +    return nullptr; + +  Type *Int8PtrTy = Type::getInt8PtrTy(Ptr->getContext()); +  Type *Int32Ty = Type::getInt32Ty(Ptr->getContext()); +  Type *Int32PtrTy = Int32Ty->getPointerTo(); +  Type *Int64Ty = Type::getInt64Ty(Ptr->getContext()); + +  auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset); +  if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64) +    return nullptr; + +  uint64_t OffsetInt = OffsetConstInt->getSExtValue(); +  if (OffsetInt % 4 != 0) +    return nullptr; + +  Constant *C = ConstantExpr::getGetElementPtr( +      Int32Ty, ConstantExpr::getBitCast(Ptr, Int32PtrTy), +      ConstantInt::get(Int64Ty, OffsetInt / 4)); +  Constant *Loaded = ConstantFoldLoadFromConstPtr(C, Int32Ty, DL); +  if (!Loaded) +    return nullptr; + +  auto *LoadedCE = dyn_cast<ConstantExpr>(Loaded); +  if (!LoadedCE) +    return nullptr; + +  if (LoadedCE->getOpcode() == Instruction::Trunc) { +    LoadedCE = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0)); +    if (!LoadedCE) +      return nullptr; +  } + +  if (LoadedCE->getOpcode() != Instruction::Sub) +    return nullptr; + +  auto *LoadedLHS = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0)); +  if (!LoadedLHS || LoadedLHS->getOpcode() != 
Instruction::PtrToInt) +    return nullptr; +  auto *LoadedLHSPtr = LoadedLHS->getOperand(0); + +  Constant *LoadedRHS = LoadedCE->getOperand(1); +  GlobalValue *LoadedRHSSym; +  APInt LoadedRHSOffset; +  if (!IsConstantOffsetFromGlobal(LoadedRHS, LoadedRHSSym, LoadedRHSOffset, +                                  DL) || +      PtrSym != LoadedRHSSym || PtrOffset != LoadedRHSOffset) +    return nullptr; + +  return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy); +} + +static bool maskIsAllZeroOrUndef(Value *Mask) { +  auto *ConstMask = dyn_cast<Constant>(Mask); +  if (!ConstMask) +    return false; +  if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask)) +    return true; +  for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; +       ++I) { +    if (auto *MaskElt = ConstMask->getAggregateElement(I)) +      if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt)) +        continue; +    return false; +  } +  return true; +} + +static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, +                                     const SimplifyQuery &Q) { +  // Idempotent functions return the same result when called repeatedly. +  Intrinsic::ID IID = F->getIntrinsicID(); +  if (IsIdempotent(IID)) +    if (auto *II = dyn_cast<IntrinsicInst>(Op0)) +      if (II->getIntrinsicID() == IID) +        return II; + +  Value *X; +  switch (IID) { +  case Intrinsic::fabs: +    if (SignBitMustBeZero(Op0, Q.TLI)) return Op0; +    break; +  case Intrinsic::bswap: +    // bswap(bswap(x)) -> x +    if (match(Op0, m_BSwap(m_Value(X)))) return X; +    break; +  case Intrinsic::bitreverse: +    // bitreverse(bitreverse(x)) -> x +    if (match(Op0, m_BitReverse(m_Value(X)))) return X; +    break; +  case Intrinsic::exp: +    // exp(log(x)) -> x +    if (Q.CxtI->hasAllowReassoc() && +        match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) return X; +    break; +  case Intrinsic::exp2: +    // exp2(log2(x)) -> x +    if (Q.CxtI->hasAllowReassoc() && +        match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) return X; +    break; +  case Intrinsic::log: +    // log(exp(x)) -> x +    if (Q.CxtI->hasAllowReassoc() && +        match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) return X; +    break; +  case Intrinsic::log2: +    // log2(exp2(x)) -> x +    if (Q.CxtI->hasAllowReassoc() && +        match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X)))) return X; +    break; +  default: +    break; +  } + +  return nullptr; +} + +static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, +                                      const SimplifyQuery &Q) { +  Intrinsic::ID IID = F->getIntrinsicID(); +  Type *ReturnType = F->getReturnType(); +  switch (IID) { +  case Intrinsic::usub_with_overflow: +  case Intrinsic::ssub_with_overflow: +    // X - X -> { 0, false } +    if (Op0 == Op1) +      return Constant::getNullValue(ReturnType); +    // X - undef -> undef +    // undef - X -> undef +    if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) +      return UndefValue::get(ReturnType); +    break; +  case Intrinsic::uadd_with_overflow: +  case Intrinsic::sadd_with_overflow: +    // X + undef -> undef +    if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) +      return UndefValue::get(ReturnType); +    break; +  case Intrinsic::umul_with_overflow: +  case Intrinsic::smul_with_overflow: +    // 0 * X -> { 0, false } +    // X * 0 -> { 0, false } +    if (match(Op0, m_Zero()) || match(Op1, m_Zero())) +      return Constant::getNullValue(ReturnType); +    // undef * X -> { 0, 
false } +    // X * undef -> { 0, false } +    if (match(Op0, m_Undef()) || match(Op1, m_Undef())) +      return Constant::getNullValue(ReturnType); +    break; +  case Intrinsic::load_relative: +    if (auto *C0 = dyn_cast<Constant>(Op0)) +      if (auto *C1 = dyn_cast<Constant>(Op1)) +        return SimplifyRelativeLoad(C0, C1, Q.DL); +    break; +  case Intrinsic::powi: +    if (auto *Power = dyn_cast<ConstantInt>(Op1)) { +      // powi(x, 0) -> 1.0 +      if (Power->isZero()) +        return ConstantFP::get(Op0->getType(), 1.0); +      // powi(x, 1) -> x +      if (Power->isOne()) +        return Op0; +    } +    break; +  case Intrinsic::maxnum: +  case Intrinsic::minnum: +    // If one argument is NaN, return the other argument. +    if (match(Op0, m_NaN())) return Op1; +    if (match(Op1, m_NaN())) return Op0; +    break; +  default: +    break; +  } + +  return nullptr; +} + +template <typename IterTy> +static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, +                                const SimplifyQuery &Q) { +  // Intrinsics with no operands have some kind of side effect. Don't simplify. +  unsigned NumOperands = std::distance(ArgBegin, ArgEnd); +  if (NumOperands == 0) +    return nullptr; + +  Intrinsic::ID IID = F->getIntrinsicID(); +  if (NumOperands == 1) +    return simplifyUnaryIntrinsic(F, ArgBegin[0], Q); + +  if (NumOperands == 2) +    return simplifyBinaryIntrinsic(F, ArgBegin[0], ArgBegin[1], Q); + +  // Handle intrinsics with 3 or more arguments. +  switch (IID) { +  case Intrinsic::masked_load: { +    Value *MaskArg = ArgBegin[2]; +    Value *PassthruArg = ArgBegin[3]; +    // If the mask is all zeros or undef, the "passthru" argument is the result. +    if (maskIsAllZeroOrUndef(MaskArg)) +      return PassthruArg; +    return nullptr; +  } +  case Intrinsic::fshl: +  case Intrinsic::fshr: { +    Value *ShAmtArg = ArgBegin[2]; +    const APInt *ShAmtC; +    if (match(ShAmtArg, m_APInt(ShAmtC))) { +      // If there's effectively no shift, return the 1st arg or 2nd arg. +      // TODO: For vectors, we could check each element of a non-splat constant. +      APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth()); +      if (ShAmtC->urem(BitWidth).isNullValue()) +        return ArgBegin[IID == Intrinsic::fshl ? 
0 : 1]; +    } +    return nullptr; +  } +  default: +    return nullptr; +  } +} + +template <typename IterTy> +static Value *SimplifyCall(ImmutableCallSite CS, Value *V, IterTy ArgBegin, +                           IterTy ArgEnd, const SimplifyQuery &Q, +                           unsigned MaxRecurse) { +  Type *Ty = V->getType(); +  if (PointerType *PTy = dyn_cast<PointerType>(Ty)) +    Ty = PTy->getElementType(); +  FunctionType *FTy = cast<FunctionType>(Ty); + +  // call undef -> undef +  // call null -> undef +  if (isa<UndefValue>(V) || isa<ConstantPointerNull>(V)) +    return UndefValue::get(FTy->getReturnType()); + +  Function *F = dyn_cast<Function>(V); +  if (!F) +    return nullptr; + +  if (F->isIntrinsic()) +    if (Value *Ret = simplifyIntrinsic(F, ArgBegin, ArgEnd, Q)) +      return Ret; + +  if (!canConstantFoldCallTo(CS, F)) +    return nullptr; + +  SmallVector<Constant *, 4> ConstantArgs; +  ConstantArgs.reserve(ArgEnd - ArgBegin); +  for (IterTy I = ArgBegin, E = ArgEnd; I != E; ++I) { +    Constant *C = dyn_cast<Constant>(*I); +    if (!C) +      return nullptr; +    ConstantArgs.push_back(C); +  } + +  return ConstantFoldCall(CS, F, ConstantArgs, Q.TLI); +} + +Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V, +                          User::op_iterator ArgBegin, User::op_iterator ArgEnd, +                          const SimplifyQuery &Q) { +  return ::SimplifyCall(CS, V, ArgBegin, ArgEnd, Q, RecursionLimit); +} + +Value *llvm::SimplifyCall(ImmutableCallSite CS, Value *V, +                          ArrayRef<Value *> Args, const SimplifyQuery &Q) { +  return ::SimplifyCall(CS, V, Args.begin(), Args.end(), Q, RecursionLimit); +} + +Value *llvm::SimplifyCall(ImmutableCallSite ICS, const SimplifyQuery &Q) { +  CallSite CS(const_cast<Instruction*>(ICS.getInstruction())); +  return ::SimplifyCall(CS, CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), +                        Q, RecursionLimit); +} + +/// See if we can compute a simplified version of this instruction. +/// If not, this returns null. + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, +                                 OptimizationRemarkEmitter *ORE) { +  const SimplifyQuery Q = SQ.CxtI ? 
SQ : SQ.getWithInstruction(I); +  Value *Result; + +  switch (I->getOpcode()) { +  default: +    Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); +    break; +  case Instruction::FAdd: +    Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), +                              I->getFastMathFlags(), Q); +    break; +  case Instruction::Add: +    Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), +                             cast<BinaryOperator>(I)->hasNoSignedWrap(), +                             cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); +    break; +  case Instruction::FSub: +    Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), +                              I->getFastMathFlags(), Q); +    break; +  case Instruction::Sub: +    Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), +                             cast<BinaryOperator>(I)->hasNoSignedWrap(), +                             cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); +    break; +  case Instruction::FMul: +    Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), +                              I->getFastMathFlags(), Q); +    break; +  case Instruction::Mul: +    Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::SDiv: +    Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::UDiv: +    Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::FDiv: +    Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), +                              I->getFastMathFlags(), Q); +    break; +  case Instruction::SRem: +    Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::URem: +    Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::FRem: +    Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), +                              I->getFastMathFlags(), Q); +    break; +  case Instruction::Shl: +    Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), +                             cast<BinaryOperator>(I)->hasNoSignedWrap(), +                             cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); +    break; +  case Instruction::LShr: +    Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), +                              cast<BinaryOperator>(I)->isExact(), Q); +    break; +  case Instruction::AShr: +    Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), +                              cast<BinaryOperator>(I)->isExact(), Q); +    break; +  case Instruction::And: +    Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::Or: +    Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::Xor: +    Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::ICmp: +    Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), +                              I->getOperand(0), I->getOperand(1), Q); +    break; +  case Instruction::FCmp: +    Result = +        SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0), +                         I->getOperand(1), I->getFastMathFlags(), Q); +    break; +  case Instruction::Select: +    Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), +                                I->getOperand(2), Q); +    break; +  case 
Instruction::GetElementPtr: { +    SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end()); +    Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), +                             Ops, Q); +    break; +  } +  case Instruction::InsertValue: { +    InsertValueInst *IV = cast<InsertValueInst>(I); +    Result = SimplifyInsertValueInst(IV->getAggregateOperand(), +                                     IV->getInsertedValueOperand(), +                                     IV->getIndices(), Q); +    break; +  } +  case Instruction::InsertElement: { +    auto *IE = cast<InsertElementInst>(I); +    Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1), +                                       IE->getOperand(2), Q); +    break; +  } +  case Instruction::ExtractValue: { +    auto *EVI = cast<ExtractValueInst>(I); +    Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), +                                      EVI->getIndices(), Q); +    break; +  } +  case Instruction::ExtractElement: { +    auto *EEI = cast<ExtractElementInst>(I); +    Result = SimplifyExtractElementInst(EEI->getVectorOperand(), +                                        EEI->getIndexOperand(), Q); +    break; +  } +  case Instruction::ShuffleVector: { +    auto *SVI = cast<ShuffleVectorInst>(I); +    Result = SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), +                                       SVI->getMask(), SVI->getType(), Q); +    break; +  } +  case Instruction::PHI: +    Result = SimplifyPHINode(cast<PHINode>(I), Q); +    break; +  case Instruction::Call: { +    CallSite CS(cast<CallInst>(I)); +    Result = SimplifyCall(CS, Q); +    break; +  } +#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: +#include "llvm/IR/Instruction.def" +#undef HANDLE_CAST_INST +    Result = +        SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); +    break; +  case Instruction::Alloca: +    // No simplifications for Alloca and it can't be constant folded. +    Result = nullptr; +    break; +  } + +  // In general, it is possible for computeKnownBits to determine all bits in a +  // value even when the operands are not all constants. +  if (!Result && I->getType()->isIntOrIntVectorTy()) { +    KnownBits Known = computeKnownBits(I, Q.DL, /*Depth*/ 0, Q.AC, I, Q.DT, ORE); +    if (Known.isConstant()) +      Result = ConstantInt::get(I->getType(), Known.getConstant()); +  } + +  /// If called on unreachable code, the above logic may report that the +  /// instruction simplified to itself.  Make life easier for users by +  /// detecting that case here, returning a safe value instead. +  return Result == I ? UndefValue::get(I->getType()) : Result; +} + +/// Implementation of recursive simplification through an instruction's +/// uses. +/// +/// This is the common implementation of the recursive simplification routines. +/// If we have a pre-simplified value in 'SimpleV', that is forcibly used to +/// replace the instruction 'I'. Otherwise, we simply add 'I' to the list of +/// instructions to process and attempt to simplify it using +/// InstructionSimplify. +/// +/// This routine returns 'true' only when *it* simplifies something. The passed +/// in simplified value does not count toward this. 
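+///
+/// Users of each replaced instruction are added to a worklist, so one seed
+/// simplification can cascade through dependent instructions.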
+static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, +                                              const TargetLibraryInfo *TLI, +                                              const DominatorTree *DT, +                                              AssumptionCache *AC) { +  bool Simplified = false; +  SmallSetVector<Instruction *, 8> Worklist; +  const DataLayout &DL = I->getModule()->getDataLayout(); + +  // If we have an explicit value to collapse to, do that round of the +  // simplification loop by hand initially. +  if (SimpleV) { +    for (User *U : I->users()) +      if (U != I) +        Worklist.insert(cast<Instruction>(U)); + +    // Replace the instruction with its simplified value. +    I->replaceAllUsesWith(SimpleV); + +    // Gracefully handle edge cases where the instruction is not wired into any +    // parent block. +    if (I->getParent() && !I->isEHPad() && !isa<TerminatorInst>(I) && +        !I->mayHaveSideEffects()) +      I->eraseFromParent(); +  } else { +    Worklist.insert(I); +  } + +  // Note that we must test the size on each iteration, the worklist can grow. +  for (unsigned Idx = 0; Idx != Worklist.size(); ++Idx) { +    I = Worklist[Idx]; + +    // See if this instruction simplifies. +    SimpleV = SimplifyInstruction(I, {DL, TLI, DT, AC}); +    if (!SimpleV) +      continue; + +    Simplified = true; + +    // Stash away all the uses of the old instruction so we can check them for +    // recursive simplifications after a RAUW. This is cheaper than checking all +    // uses of To on the recursive step in most cases. +    for (User *U : I->users()) +      Worklist.insert(cast<Instruction>(U)); + +    // Replace the instruction with its simplified value. +    I->replaceAllUsesWith(SimpleV); + +    // Gracefully handle edge cases where the instruction is not wired into any +    // parent block. +    if (I->getParent() && !I->isEHPad() && !isa<TerminatorInst>(I) && +        !I->mayHaveSideEffects()) +      I->eraseFromParent(); +  } +  return Simplified; +} + +bool llvm::recursivelySimplifyInstruction(Instruction *I, +                                          const TargetLibraryInfo *TLI, +                                          const DominatorTree *DT, +                                          AssumptionCache *AC) { +  return replaceAndRecursivelySimplifyImpl(I, nullptr, TLI, DT, AC); +} + +bool llvm::replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, +                                         const TargetLibraryInfo *TLI, +                                         const DominatorTree *DT, +                                         AssumptionCache *AC) { +  assert(I != SimpleV && "replaceAndRecursivelySimplify(X,X) is not valid!"); +  assert(SimpleV && "Must provide a simplified value."); +  return replaceAndRecursivelySimplifyImpl(I, SimpleV, TLI, DT, AC); +} + +namespace llvm { +const SimplifyQuery getBestSimplifyQuery(Pass &P, Function &F) { +  auto *DTWP = P.getAnalysisIfAvailable<DominatorTreeWrapperPass>(); +  auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; +  auto *TLIWP = P.getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); +  auto *TLI = TLIWP ? &TLIWP->getTLI() : nullptr; +  auto *ACWP = P.getAnalysisIfAvailable<AssumptionCacheTracker>(); +  auto *AC = ACWP ? 
&ACWP->getAssumptionCache(F) : nullptr; +  return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} + +const SimplifyQuery getBestSimplifyQuery(LoopStandardAnalysisResults &AR, +                                         const DataLayout &DL) { +  return {DL, &AR.TLI, &AR.DT, &AR.AC}; +} + +template <class T, class... TArgs> +const SimplifyQuery getBestSimplifyQuery(AnalysisManager<T, TArgs...> &AM, +                                         Function &F) { +  auto *DT = AM.template getCachedResult<DominatorTreeAnalysis>(F); +  auto *TLI = AM.template getCachedResult<TargetLibraryAnalysis>(F); +  auto *AC = AM.template getCachedResult<AssumptionAnalysis>(F); +  return {F.getParent()->getDataLayout(), TLI, DT, AC}; +} +template const SimplifyQuery getBestSimplifyQuery(AnalysisManager<Function> &, +                                                  Function &); +} diff --git a/contrib/llvm/lib/Analysis/Interval.cpp b/contrib/llvm/lib/Analysis/Interval.cpp new file mode 100644 index 000000000000..6d5de22cb93f --- /dev/null +++ b/contrib/llvm/lib/Analysis/Interval.cpp @@ -0,0 +1,52 @@ +//===- Interval.cpp - Interval class code ---------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the Interval class, which represents a +// partition of a control flow graph of some kind. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Interval.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +//===----------------------------------------------------------------------===// +// Interval Implementation +//===----------------------------------------------------------------------===// + +// isLoop - Find out if there is a back edge in this interval... +bool Interval::isLoop() const { +  // There is a loop in this interval iff one of the predecessors of the header +  // node lives in the interval. +  for (::pred_iterator I = ::pred_begin(HeaderNode), E = ::pred_end(HeaderNode); +       I != E; ++I) +    if (contains(*I)) +      return true; +  return false; +} + +void Interval::print(raw_ostream &OS) const { +  OS << "-------------------------------------------------------------\n" +       << "Interval Contents:\n"; + +  // Print out all of the basic blocks in the interval... +  for (const BasicBlock *Node : Nodes) +    OS << *Node << "\n"; + +  OS << "Interval Predecessors:\n"; +  for (const BasicBlock *Predecessor : Predecessors) +    OS << *Predecessor << "\n"; + +  OS << "Interval Successors:\n"; +  for (const BasicBlock *Successor : Successors) +    OS << *Successor << "\n"; +} diff --git a/contrib/llvm/lib/Analysis/IntervalPartition.cpp b/contrib/llvm/lib/Analysis/IntervalPartition.cpp new file mode 100644 index 000000000000..c777d91b67c6 --- /dev/null +++ b/contrib/llvm/lib/Analysis/IntervalPartition.cpp @@ -0,0 +1,114 @@ +//===- IntervalPartition.cpp - Interval Partition module code -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the IntervalPartition class, which +// calculates and represent the interval partition of a function. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IntervalPartition.h" +#include "llvm/Analysis/Interval.h" +#include "llvm/Analysis/IntervalIterator.h" +#include "llvm/Pass.h" +#include <cassert> +#include <utility> + +using namespace llvm; + +char IntervalPartition::ID = 0; + +INITIALIZE_PASS(IntervalPartition, "intervals", +                "Interval Partition Construction", true, true) + +//===----------------------------------------------------------------------===// +// IntervalPartition Implementation +//===----------------------------------------------------------------------===// + +// releaseMemory - Reset state back to before function was analyzed +void IntervalPartition::releaseMemory() { +  for (unsigned i = 0, e = Intervals.size(); i != e; ++i) +    delete Intervals[i]; +  IntervalMap.clear(); +  Intervals.clear(); +  RootInterval = nullptr; +} + +void IntervalPartition::print(raw_ostream &O, const Module*) const { +  for(unsigned i = 0, e = Intervals.size(); i != e; ++i) +    Intervals[i]->print(O); +} + +// addIntervalToPartition - Add an interval to the internal list of intervals, +// and then add mappings from all of the basic blocks in the interval to the +// interval itself (in the IntervalMap). +void IntervalPartition::addIntervalToPartition(Interval *I) { +  Intervals.push_back(I); + +  // Add mappings for all of the basic blocks in I to the IntervalPartition +  for (Interval::node_iterator It = I->Nodes.begin(), End = I->Nodes.end(); +       It != End; ++It) +    IntervalMap.insert(std::make_pair(*It, I)); +} + +// updatePredecessors - Interval generation only sets the successor fields of +// the interval data structures.  After interval generation is complete, +// run through all of the intervals and propagate successor info as +// predecessor info. +void IntervalPartition::updatePredecessors(Interval *Int) { +  BasicBlock *Header = Int->getHeaderNode(); +  for (BasicBlock *Successor : Int->Successors) +    getBlockInterval(Successor)->Predecessors.push_back(Header); +} + +// IntervalPartition ctor - Build the first level interval partition for the +// specified function... +bool IntervalPartition::runOnFunction(Function &F) { +  // Pass false to intervals_begin because we take ownership of it's memory +  function_interval_iterator I = intervals_begin(&F, false); +  assert(I != intervals_end(&F) && "No intervals in function!?!?!"); + +  addIntervalToPartition(RootInterval = *I); + +  ++I;  // After the first one... + +  // Add the rest of the intervals to the partition. +  for (function_interval_iterator E = intervals_end(&F); I != E; ++I) +    addIntervalToPartition(*I); + +  // Now that we know all of the successor information, propagate this to the +  // predecessors for each block. +  for (unsigned i = 0, e = Intervals.size(); i != e; ++i) +    updatePredecessors(Intervals[i]); +  return false; +} + +// IntervalPartition ctor - Build a reduced interval partition from an +// existing interval graph.  This takes an additional boolean parameter to +// distinguish it from a copy constructor.  Always pass in false for now. 
+IntervalPartition::IntervalPartition(IntervalPartition &IP, bool) +  : FunctionPass(ID) { +  assert(IP.getRootInterval() && "Cannot operate on empty IntervalPartitions!"); + +  // Pass false to intervals_begin because we take ownership of it's memory +  interval_part_interval_iterator I = intervals_begin(IP, false); +  assert(I != intervals_end(IP) && "No intervals in interval partition!?!?!"); + +  addIntervalToPartition(RootInterval = *I); + +  ++I;  // After the first one... + +  // Add the rest of the intervals to the partition. +  for (interval_part_interval_iterator E = intervals_end(IP); I != E; ++I) +    addIntervalToPartition(*I); + +  // Now that we know all of the successor information, propagate this to the +  // predecessors for each block. +  for (unsigned i = 0, e = Intervals.size(); i != e; ++i) +    updatePredecessors(Intervals[i]); +} diff --git a/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp b/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp new file mode 100644 index 000000000000..e7751d32aab3 --- /dev/null +++ b/contrib/llvm/lib/Analysis/IteratedDominanceFrontier.cpp @@ -0,0 +1,99 @@ +//===- IteratedDominanceFrontier.cpp - Compute IDF ------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Compute iterated dominance frontiers using a linear time algorithm. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include <queue> + +namespace llvm { +template <class NodeTy, bool IsPostDom> +void IDFCalculator<NodeTy, IsPostDom>::calculate( +    SmallVectorImpl<BasicBlock *> &PHIBlocks) { +  // Use a priority queue keyed on dominator tree level so that inserted nodes +  // are handled from the bottom of the dominator tree upwards. We also augment +  // the level with a DFS number to ensure that the blocks are ordered in a +  // deterministic way. +  typedef std::pair<DomTreeNode *, std::pair<unsigned, unsigned>> +      DomTreeNodePair; +  typedef std::priority_queue<DomTreeNodePair, SmallVector<DomTreeNodePair, 32>, +                              less_second> IDFPriorityQueue; +  IDFPriorityQueue PQ; + +  DT.updateDFSNumbers(); + +  for (BasicBlock *BB : *DefBlocks) { +    if (DomTreeNode *Node = DT.getNode(BB)) +      PQ.push({Node, std::make_pair(Node->getLevel(), Node->getDFSNumIn())}); +  } + +  SmallVector<DomTreeNode *, 32> Worklist; +  SmallPtrSet<DomTreeNode *, 32> VisitedPQ; +  SmallPtrSet<DomTreeNode *, 32> VisitedWorklist; + +  while (!PQ.empty()) { +    DomTreeNodePair RootPair = PQ.top(); +    PQ.pop(); +    DomTreeNode *Root = RootPair.first; +    unsigned RootLevel = RootPair.second.first; + +    // Walk all dominator tree children of Root, inspecting their CFG edges with +    // targets elsewhere on the dominator tree. Only targets whose level is at +    // most Root's level are added to the iterated dominance frontier of the +    // definition set. 
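+    // Illustrative sketch with a hypothetical diamond CFG A -> {B, C} -> D
+    // and definitions in B and C: B, C, and D all sit at level 1 of the
+    // dominator tree with A as their immediate dominator. Walking B's (and
+    // later C's) subtree finds the CFG edge into D, whose level does not
+    // exceed the root's, so D lands in PHIBlocks exactly once (the VisitedPQ
+    // set suppresses the duplicate from C) and is requeued to propagate the
+    // iterated part of the frontier.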
+ +    Worklist.clear(); +    Worklist.push_back(Root); +    VisitedWorklist.insert(Root); + +    while (!Worklist.empty()) { +      DomTreeNode *Node = Worklist.pop_back_val(); +      BasicBlock *BB = Node->getBlock(); +      // Succ is the successor in the direction we are calculating IDF, so it is +      // successor for IDF, and predecessor for Reverse IDF. +      for (auto *Succ : children<NodeTy>(BB)) { +        DomTreeNode *SuccNode = DT.getNode(Succ); + +        // Quickly skip all CFG edges that are also dominator tree edges instead +        // of catching them below. +        if (SuccNode->getIDom() == Node) +          continue; + +        const unsigned SuccLevel = SuccNode->getLevel(); +        if (SuccLevel > RootLevel) +          continue; + +        if (!VisitedPQ.insert(SuccNode).second) +          continue; + +        BasicBlock *SuccBB = SuccNode->getBlock(); +        if (useLiveIn && !LiveInBlocks->count(SuccBB)) +          continue; + +        PHIBlocks.emplace_back(SuccBB); +        if (!DefBlocks->count(SuccBB)) +          PQ.push(std::make_pair( +              SuccNode, std::make_pair(SuccLevel, SuccNode->getDFSNumIn()))); +      } + +      for (auto DomChild : *Node) { +        if (VisitedWorklist.insert(DomChild).second) +          Worklist.push_back(DomChild); +      } +    } +  } +} + +template class IDFCalculator<BasicBlock *, false>; +template class IDFCalculator<Inverse<BasicBlock *>, true>; +} diff --git a/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp b/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp new file mode 100644 index 000000000000..93c23bca96af --- /dev/null +++ b/contrib/llvm/lib/Analysis/LazyBlockFrequencyInfo.cpp @@ -0,0 +1,72 @@ +//===- LazyBlockFrequencyInfo.cpp - Lazy Block Frequency Analysis ---------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This is an alternative analysis pass to BlockFrequencyInfoWrapperPass.  The +// difference is that with this pass the block frequencies are not computed when +// the analysis pass is executed but rather when the BFI result is explicitly +// requested by the analysis client. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/LazyBranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +#define DEBUG_TYPE "lazy-block-freq" + +INITIALIZE_PASS_BEGIN(LazyBlockFrequencyInfoPass, DEBUG_TYPE, +                      "Lazy Block Frequency Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(LazyBPIPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LazyBlockFrequencyInfoPass, DEBUG_TYPE, +                    "Lazy Block Frequency Analysis", true, true) + +char LazyBlockFrequencyInfoPass::ID = 0; + +LazyBlockFrequencyInfoPass::LazyBlockFrequencyInfoPass() : FunctionPass(ID) { +  initializeLazyBlockFrequencyInfoPassPass(*PassRegistry::getPassRegistry()); +} + +void LazyBlockFrequencyInfoPass::print(raw_ostream &OS, const Module *) const { +  LBFI.getCalculated().print(OS); +} + +void LazyBlockFrequencyInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { +  LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AU); +  // We require DT so it's available when LI is available. 
The LI updating code +  // asserts that DT is also present so if we don't make sure that we have DT +  // here, that assert will trigger. +  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.setPreservesAll(); +} + +void LazyBlockFrequencyInfoPass::releaseMemory() { LBFI.releaseMemory(); } + +bool LazyBlockFrequencyInfoPass::runOnFunction(Function &F) { +  auto &BPIPass = getAnalysis<LazyBranchProbabilityInfoPass>(); +  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  LBFI.setAnalysis(&F, &BPIPass, &LI); +  return false; +} + +void LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AnalysisUsage &AU) { +  LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AU); +  AU.addRequired<LazyBlockFrequencyInfoPass>(); +  AU.addRequired<LoopInfoWrapperPass>(); +} + +void llvm::initializeLazyBFIPassPass(PassRegistry &Registry) { +  initializeLazyBPIPassPass(Registry); +  INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass); +  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); +} diff --git a/contrib/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp b/contrib/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp new file mode 100644 index 000000000000..429b78c3a47e --- /dev/null +++ b/contrib/llvm/lib/Analysis/LazyBranchProbabilityInfo.cpp @@ -0,0 +1,74 @@ +//===- LazyBranchProbabilityInfo.cpp - Lazy Branch Probability Analysis ---===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This is an alternative analysis pass to BranchProbabilityInfoWrapperPass. +// The difference is that with this pass the branch probabilities are not +// computed when the analysis pass is executed but rather when the BPI results +// is explicitly requested by the analysis client. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyBranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +#define DEBUG_TYPE "lazy-branch-prob" + +INITIALIZE_PASS_BEGIN(LazyBranchProbabilityInfoPass, DEBUG_TYPE, +                      "Lazy Branch Probability Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LazyBranchProbabilityInfoPass, DEBUG_TYPE, +                    "Lazy Branch Probability Analysis", true, true) + +char LazyBranchProbabilityInfoPass::ID = 0; + +LazyBranchProbabilityInfoPass::LazyBranchProbabilityInfoPass() +    : FunctionPass(ID) { +  initializeLazyBranchProbabilityInfoPassPass(*PassRegistry::getPassRegistry()); +} + +void LazyBranchProbabilityInfoPass::print(raw_ostream &OS, +                                          const Module *) const { +  LBPI->getCalculated().print(OS); +} + +void LazyBranchProbabilityInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { +  // We require DT so it's available when LI is available. The LI updating code +  // asserts that DT is also present so if we don't make sure that we have DT +  // here, that assert will trigger. 
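+  // Usage sketch for a hypothetical client pass: it forwards its analysis
+  // usage through getLazyBPIAnalysisUsage (defined below) and only asks this
+  // pass for the computed BranchProbabilityInfo on the paths that actually
+  // need probabilities, so the computation is skipped entirely for functions
+  // that never query it.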
+  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +  AU.setPreservesAll(); +} + +void LazyBranchProbabilityInfoPass::releaseMemory() { LBPI.reset(); } + +bool LazyBranchProbabilityInfoPass::runOnFunction(Function &F) { +  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +  LBPI = llvm::make_unique<LazyBranchProbabilityInfo>(&F, &LI, &TLI); +  return false; +} + +void LazyBranchProbabilityInfoPass::getLazyBPIAnalysisUsage(AnalysisUsage &AU) { +  AU.addRequired<LazyBranchProbabilityInfoPass>(); +  AU.addRequired<LoopInfoWrapperPass>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +} + +void llvm::initializeLazyBPIPassPass(PassRegistry &Registry) { +  INITIALIZE_PASS_DEPENDENCY(LazyBranchProbabilityInfoPass); +  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); +  INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass); +} diff --git a/contrib/llvm/lib/Analysis/LazyCallGraph.cpp b/contrib/llvm/lib/Analysis/LazyCallGraph.cpp new file mode 100644 index 000000000000..b1d585bfc683 --- /dev/null +++ b/contrib/llvm/lib/Analysis/LazyCallGraph.cpp @@ -0,0 +1,1805 @@ +//===- LazyCallGraph.cpp - Analysis of a Module's call graph --------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <iterator> +#include <string> +#include <tuple> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "lcg" + +void LazyCallGraph::EdgeSequence::insertEdgeInternal(Node &TargetN, +                                                     Edge::Kind EK) { +  EdgeIndexMap.insert({&TargetN, Edges.size()}); +  Edges.emplace_back(TargetN, EK); +} + +void LazyCallGraph::EdgeSequence::setEdgeKind(Node &TargetN, Edge::Kind EK) { +  Edges[EdgeIndexMap.find(&TargetN)->second].setKind(EK); +} + +bool LazyCallGraph::EdgeSequence::removeEdgeInternal(Node &TargetN) { +  auto IndexMapI = EdgeIndexMap.find(&TargetN); +  if (IndexMapI == EdgeIndexMap.end()) +    return false; + +  Edges[IndexMapI->second] = Edge(); +  EdgeIndexMap.erase(IndexMapI); +  return true; +} + +static void addEdge(SmallVectorImpl<LazyCallGraph::Edge> &Edges, +                    DenseMap<LazyCallGraph::Node *, int> &EdgeIndexMap, +                    LazyCallGraph::Node &N, LazyCallGraph::Edge::Kind EK) { +  if (!EdgeIndexMap.insert({&N, Edges.size()}).second) +    return; + +  LLVM_DEBUG(dbgs() << "    Added callable function: " << 
N.getName() << "\n"); +  Edges.emplace_back(LazyCallGraph::Edge(N, EK)); +} + +LazyCallGraph::EdgeSequence &LazyCallGraph::Node::populateSlow() { +  assert(!Edges && "Must not have already populated the edges for this node!"); + +  LLVM_DEBUG(dbgs() << "  Adding functions called by '" << getName() +                    << "' to the graph.\n"); + +  Edges = EdgeSequence(); + +  SmallVector<Constant *, 16> Worklist; +  SmallPtrSet<Function *, 4> Callees; +  SmallPtrSet<Constant *, 16> Visited; + +  // Find all the potential call graph edges in this function. We track both +  // actual call edges and indirect references to functions. The direct calls +  // are trivially added, but to accumulate the latter we walk the instructions +  // and add every operand which is a constant to the worklist to process +  // afterward. +  // +  // Note that we consider *any* function with a definition to be a viable +  // edge. Even if the function's definition is subject to replacement by +  // some other module (say, a weak definition) there may still be +  // optimizations which essentially speculate based on the definition and +  // a way to check that the specific definition is in fact the one being +  // used. For example, this could be done by moving the weak definition to +  // a strong (internal) definition and making the weak definition be an +  // alias. Then a test of the address of the weak function against the new +  // strong definition's address would be an effective way to determine the +  // safety of optimizing a direct call edge. +  for (BasicBlock &BB : *F) +    for (Instruction &I : BB) { +      if (auto CS = CallSite(&I)) +        if (Function *Callee = CS.getCalledFunction()) +          if (!Callee->isDeclaration()) +            if (Callees.insert(Callee).second) { +              Visited.insert(Callee); +              addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*Callee), +                      LazyCallGraph::Edge::Call); +            } + +      for (Value *Op : I.operand_values()) +        if (Constant *C = dyn_cast<Constant>(Op)) +          if (Visited.insert(C).second) +            Worklist.push_back(C); +    } + +  // We've collected all the constant (and thus potentially function or +  // function containing) operands to all of the instructions in the function. +  // Process them (recursively) collecting every function found. +  visitReferences(Worklist, Visited, [&](Function &F) { +    addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(F), +            LazyCallGraph::Edge::Ref); +  }); + +  // Add implicit reference edges to any defined libcall functions (if we +  // haven't found an explicit edge). +  for (auto *F : G->LibFunctions) +    if (!Visited.count(F)) +      addEdge(Edges->Edges, Edges->EdgeIndexMap, G->get(*F), +              LazyCallGraph::Edge::Ref); + +  return *Edges; +} + +void LazyCallGraph::Node::replaceFunction(Function &NewF) { +  assert(F != &NewF && "Must not replace a function with itself!"); +  F = &NewF; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::Node::dump() const { +  dbgs() << *this << '\n'; +} +#endif + +static bool isKnownLibFunction(Function &F, TargetLibraryInfo &TLI) { +  LibFunc LF; + +  // Either this is a normal library function or a "vectorizable" function. 
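+  // For example, a definition of "memcpy" with the usual prototype is
+  // recognized by TLI.getLibFunc, while entries supplied by a vector library
+  // mapping are caught by the isFunctionVectorizable check.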
+  return TLI.getLibFunc(F, LF) || TLI.isFunctionVectorizable(F.getName()); +} + +LazyCallGraph::LazyCallGraph(Module &M, TargetLibraryInfo &TLI) { +  LLVM_DEBUG(dbgs() << "Building CG for module: " << M.getModuleIdentifier() +                    << "\n"); +  for (Function &F : M) { +    if (F.isDeclaration()) +      continue; +    // If this function is a known lib function to LLVM then we want to +    // synthesize reference edges to it to model the fact that LLVM can turn +    // arbitrary code into a library function call. +    if (isKnownLibFunction(F, TLI)) +      LibFunctions.insert(&F); + +    if (F.hasLocalLinkage()) +      continue; + +    // External linkage defined functions have edges to them from other +    // modules. +    LLVM_DEBUG(dbgs() << "  Adding '" << F.getName() +                      << "' to entry set of the graph.\n"); +    addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), Edge::Ref); +  } + +  // Now add entry nodes for functions reachable via initializers to globals. +  SmallVector<Constant *, 16> Worklist; +  SmallPtrSet<Constant *, 16> Visited; +  for (GlobalVariable &GV : M.globals()) +    if (GV.hasInitializer()) +      if (Visited.insert(GV.getInitializer()).second) +        Worklist.push_back(GV.getInitializer()); + +  LLVM_DEBUG( +      dbgs() << "  Adding functions referenced by global initializers to the " +                "entry set.\n"); +  visitReferences(Worklist, Visited, [&](Function &F) { +    addEdge(EntryEdges.Edges, EntryEdges.EdgeIndexMap, get(F), +            LazyCallGraph::Edge::Ref); +  }); +} + +LazyCallGraph::LazyCallGraph(LazyCallGraph &&G) +    : BPA(std::move(G.BPA)), NodeMap(std::move(G.NodeMap)), +      EntryEdges(std::move(G.EntryEdges)), SCCBPA(std::move(G.SCCBPA)), +      SCCMap(std::move(G.SCCMap)), +      LibFunctions(std::move(G.LibFunctions)) { +  updateGraphPtrs(); +} + +LazyCallGraph &LazyCallGraph::operator=(LazyCallGraph &&G) { +  BPA = std::move(G.BPA); +  NodeMap = std::move(G.NodeMap); +  EntryEdges = std::move(G.EntryEdges); +  SCCBPA = std::move(G.SCCBPA); +  SCCMap = std::move(G.SCCMap); +  LibFunctions = std::move(G.LibFunctions); +  updateGraphPtrs(); +  return *this; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::SCC::dump() const { +  dbgs() << *this << '\n'; +} +#endif + +#ifndef NDEBUG +void LazyCallGraph::SCC::verify() { +  assert(OuterRefSCC && "Can't have a null RefSCC!"); +  assert(!Nodes.empty() && "Can't have an empty SCC!"); + +  for (Node *N : Nodes) { +    assert(N && "Can't have a null node!"); +    assert(OuterRefSCC->G->lookupSCC(*N) == this && +           "Node does not map to this SCC!"); +    assert(N->DFSNumber == -1 && +           "Must set DFS numbers to -1 when adding a node to an SCC!"); +    assert(N->LowLink == -1 && +           "Must set low link to -1 when adding a node to an SCC!"); +    for (Edge &E : **N) +      assert(E.getNode().isPopulated() && "Can't have an unpopulated node!"); +  } +} +#endif + +bool LazyCallGraph::SCC::isParentOf(const SCC &C) const { +  if (this == &C) +    return false; + +  for (Node &N : *this) +    for (Edge &E : N->calls()) +      if (OuterRefSCC->G->lookupSCC(E.getNode()) == &C) +        return true; + +  // No edges found. +  return false; +} + +bool LazyCallGraph::SCC::isAncestorOf(const SCC &TargetC) const { +  if (this == &TargetC) +    return false; + +  LazyCallGraph &G = *OuterRefSCC->G; + +  // Start with this SCC. 
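+  // This is a plain iterative reachability walk over the call graph of SCCs:
+  // pop an SCC, scan the call edges of its member nodes, and enqueue any
+  // callee SCC seen for the first time until TargetC is found or the
+  // worklist drains.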
+  SmallPtrSet<const SCC *, 16> Visited = {this}; +  SmallVector<const SCC *, 16> Worklist = {this}; + +  // Walk down the graph until we run out of edges or find a path to TargetC. +  do { +    const SCC &C = *Worklist.pop_back_val(); +    for (Node &N : C) +      for (Edge &E : N->calls()) { +        SCC *CalleeC = G.lookupSCC(E.getNode()); +        if (!CalleeC) +          continue; + +        // If the callee's SCC is the TargetC, we're done. +        if (CalleeC == &TargetC) +          return true; + +        // If this is the first time we've reached this SCC, put it on the +        // worklist to recurse through. +        if (Visited.insert(CalleeC).second) +          Worklist.push_back(CalleeC); +      } +  } while (!Worklist.empty()); + +  // No paths found. +  return false; +} + +LazyCallGraph::RefSCC::RefSCC(LazyCallGraph &G) : G(&G) {} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void LazyCallGraph::RefSCC::dump() const { +  dbgs() << *this << '\n'; +} +#endif + +#ifndef NDEBUG +void LazyCallGraph::RefSCC::verify() { +  assert(G && "Can't have a null graph!"); +  assert(!SCCs.empty() && "Can't have an empty SCC!"); + +  // Verify basic properties of the SCCs. +  SmallPtrSet<SCC *, 4> SCCSet; +  for (SCC *C : SCCs) { +    assert(C && "Can't have a null SCC!"); +    C->verify(); +    assert(&C->getOuterRefSCC() == this && +           "SCC doesn't think it is inside this RefSCC!"); +    bool Inserted = SCCSet.insert(C).second; +    assert(Inserted && "Found a duplicate SCC!"); +    auto IndexIt = SCCIndices.find(C); +    assert(IndexIt != SCCIndices.end() && +           "Found an SCC that doesn't have an index!"); +  } + +  // Check that our indices map correctly. +  for (auto &SCCIndexPair : SCCIndices) { +    SCC *C = SCCIndexPair.first; +    int i = SCCIndexPair.second; +    assert(C && "Can't have a null SCC in the indices!"); +    assert(SCCSet.count(C) && "Found an index for an SCC not in the RefSCC!"); +    assert(SCCs[i] == C && "Index doesn't point to SCC!"); +  } + +  // Check that the SCCs are in fact in post-order. +  for (int i = 0, Size = SCCs.size(); i < Size; ++i) { +    SCC &SourceSCC = *SCCs[i]; +    for (Node &N : SourceSCC) +      for (Edge &E : *N) { +        if (!E.isCall()) +          continue; +        SCC &TargetSCC = *G->lookupSCC(E.getNode()); +        if (&TargetSCC.getOuterRefSCC() == this) { +          assert(SCCIndices.find(&TargetSCC)->second <= i && +                 "Edge between SCCs violates post-order relationship."); +          continue; +        } +      } +  } +} +#endif + +bool LazyCallGraph::RefSCC::isParentOf(const RefSCC &RC) const { +  if (&RC == this) +    return false; + +  // Search all edges to see if this is a parent. +  for (SCC &C : *this) +    for (Node &N : C) +      for (Edge &E : *N) +        if (G->lookupRefSCC(E.getNode()) == &RC) +          return true; + +  return false; +} + +bool LazyCallGraph::RefSCC::isAncestorOf(const RefSCC &RC) const { +  if (&RC == this) +    return false; + +  // For each descendant of this RefSCC, see if one of its children is the +  // argument. If not, add that descendant to the worklist and continue +  // searching. 
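+  // As with the SCC variant above, this is a worklist-driven reachability
+  // walk, but it follows every edge (call and ref alike) and visits RefSCCs,
+  // so in the worst case it touches all edges of every RefSCC reachable from
+  // this one.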
+  SmallVector<const RefSCC *, 4> Worklist; +  SmallPtrSet<const RefSCC *, 4> Visited; +  Worklist.push_back(this); +  Visited.insert(this); +  do { +    const RefSCC &DescendantRC = *Worklist.pop_back_val(); +    for (SCC &C : DescendantRC) +      for (Node &N : C) +        for (Edge &E : *N) { +          auto *ChildRC = G->lookupRefSCC(E.getNode()); +          if (ChildRC == &RC) +            return true; +          if (!ChildRC || !Visited.insert(ChildRC).second) +            continue; +          Worklist.push_back(ChildRC); +        } +  } while (!Worklist.empty()); + +  return false; +} + +/// Generic helper that updates a postorder sequence of SCCs for a potentially +/// cycle-introducing edge insertion. +/// +/// A postorder sequence of SCCs of a directed graph has one fundamental +/// property: all deges in the DAG of SCCs point "up" the sequence. That is, +/// all edges in the SCC DAG point to prior SCCs in the sequence. +/// +/// This routine both updates a postorder sequence and uses that sequence to +/// compute the set of SCCs connected into a cycle. It should only be called to +/// insert a "downward" edge which will require changing the sequence to +/// restore it to a postorder. +/// +/// When inserting an edge from an earlier SCC to a later SCC in some postorder +/// sequence, all of the SCCs which may be impacted are in the closed range of +/// those two within the postorder sequence. The algorithm used here to restore +/// the state is as follows: +/// +/// 1) Starting from the source SCC, construct a set of SCCs which reach the +///    source SCC consisting of just the source SCC. Then scan toward the +///    target SCC in postorder and for each SCC, if it has an edge to an SCC +///    in the set, add it to the set. Otherwise, the source SCC is not +///    a successor, move it in the postorder sequence to immediately before +///    the source SCC, shifting the source SCC and all SCCs in the set one +///    position toward the target SCC. Stop scanning after processing the +///    target SCC. +/// 2) If the source SCC is now past the target SCC in the postorder sequence, +///    and thus the new edge will flow toward the start, we are done. +/// 3) Otherwise, starting from the target SCC, walk all edges which reach an +///    SCC between the source and the target, and add them to the set of +///    connected SCCs, then recurse through them. Once a complete set of the +///    SCCs the target connects to is known, hoist the remaining SCCs between +///    the source and the target to be above the target. Note that there is no +///    need to process the source SCC, it is already known to connect. +/// 4) At this point, all of the SCCs in the closed range between the source +///    SCC and the target SCC in the postorder sequence are connected, +///    including the target SCC and the source SCC. Inserting the edge from +///    the source SCC to the target SCC will form a cycle out of precisely +///    these SCCs. Thus we can merge all of the SCCs in this closed range into +///    a single SCC. +/// +/// This process has various important properties: +/// - Only mutates the SCCs when adding the edge actually changes the SCC +///   structure. +/// - Never mutates SCCs which are unaffected by the change. +/// - Updates the postorder sequence to correctly satisfy the postorder +///   constraint after the edge is inserted. +/// - Only reorders SCCs in the closed postorder sequence from the source to +///   the target, so easy to bound how much has changed even in the ordering. 
+/// - Big-O is the number of edges in the closed postorder range of SCCs from +///   source to target. +/// +/// This helper routine, in addition to updating the postorder sequence itself +/// will also update a map from SCCs to indices within that sequence. +/// +/// The sequence and the map must operate on pointers to the SCC type. +/// +/// Two callbacks must be provided. The first computes the subset of SCCs in +/// the postorder closed range from the source to the target which connect to +/// the source SCC via some (transitive) set of edges. The second computes the +/// subset of the same range which the target SCC connects to via some +/// (transitive) set of edges. Both callbacks should populate the set argument +/// provided. +template <typename SCCT, typename PostorderSequenceT, typename SCCIndexMapT, +          typename ComputeSourceConnectedSetCallableT, +          typename ComputeTargetConnectedSetCallableT> +static iterator_range<typename PostorderSequenceT::iterator> +updatePostorderSequenceForEdgeInsertion( +    SCCT &SourceSCC, SCCT &TargetSCC, PostorderSequenceT &SCCs, +    SCCIndexMapT &SCCIndices, +    ComputeSourceConnectedSetCallableT ComputeSourceConnectedSet, +    ComputeTargetConnectedSetCallableT ComputeTargetConnectedSet) { +  int SourceIdx = SCCIndices[&SourceSCC]; +  int TargetIdx = SCCIndices[&TargetSCC]; +  assert(SourceIdx < TargetIdx && "Cannot have equal indices here!"); + +  SmallPtrSet<SCCT *, 4> ConnectedSet; + +  // Compute the SCCs which (transitively) reach the source. +  ComputeSourceConnectedSet(ConnectedSet); + +  // Partition the SCCs in this part of the port-order sequence so only SCCs +  // connecting to the source remain between it and the target. This is +  // a benign partition as it preserves postorder. +  auto SourceI = std::stable_partition( +      SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx + 1, +      [&ConnectedSet](SCCT *C) { return !ConnectedSet.count(C); }); +  for (int i = SourceIdx, e = TargetIdx + 1; i < e; ++i) +    SCCIndices.find(SCCs[i])->second = i; + +  // If the target doesn't connect to the source, then we've corrected the +  // post-order and there are no cycles formed. +  if (!ConnectedSet.count(&TargetSCC)) { +    assert(SourceI > (SCCs.begin() + SourceIdx) && +           "Must have moved the source to fix the post-order."); +    assert(*std::prev(SourceI) == &TargetSCC && +           "Last SCC to move should have bene the target."); + +    // Return an empty range at the target SCC indicating there is nothing to +    // merge. +    return make_range(std::prev(SourceI), std::prev(SourceI)); +  } + +  assert(SCCs[TargetIdx] == &TargetSCC && +         "Should not have moved target if connected!"); +  SourceIdx = SourceI - SCCs.begin(); +  assert(SCCs[SourceIdx] == &SourceSCC && +         "Bad updated index computation for the source SCC!"); + + +  // See whether there are any remaining intervening SCCs between the source +  // and target. If so we need to make sure they all are reachable form the +  // target. +  if (SourceIdx + 1 < TargetIdx) { +    ConnectedSet.clear(); +    ComputeTargetConnectedSet(ConnectedSet); + +    // Partition SCCs so that only SCCs reached from the target remain between +    // the source and the target. This preserves postorder. 
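+    // Hypothetical sketch: for the postorder slice [S, A, B, T] and a new
+    // edge S -> T where A reaches S, T reaches A, and B connects to neither,
+    // the first partition produced [B, S, A, T]; this second partition keeps
+    // [A, T] in place, so the closed range {S, A, T} is the new cycle and
+    // the returned range selects {S, A} to be merged into T by the caller.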
+    auto TargetI = std::stable_partition( +        SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1, +        [&ConnectedSet](SCCT *C) { return ConnectedSet.count(C); }); +    for (int i = SourceIdx + 1, e = TargetIdx + 1; i < e; ++i) +      SCCIndices.find(SCCs[i])->second = i; +    TargetIdx = std::prev(TargetI) - SCCs.begin(); +    assert(SCCs[TargetIdx] == &TargetSCC && +           "Should always end with the target!"); +  } + +  // At this point, we know that connecting source to target forms a cycle +  // because target connects back to source, and we know that all of the SCCs +  // between the source and target in the postorder sequence participate in that +  // cycle. +  return make_range(SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx); +} + +bool +LazyCallGraph::RefSCC::switchInternalEdgeToCall( +    Node &SourceN, Node &TargetN, +    function_ref<void(ArrayRef<SCC *> MergeSCCs)> MergeCB) { +  assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!"); +  SmallVector<SCC *, 1> DeletedSCCs; + +#ifndef NDEBUG +  // In a debug build, verify the RefSCC is valid to start with and when this +  // routine finishes. +  verify(); +  auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + +  SCC &SourceSCC = *G->lookupSCC(SourceN); +  SCC &TargetSCC = *G->lookupSCC(TargetN); + +  // If the two nodes are already part of the same SCC, we're also done as +  // we've just added more connectivity. +  if (&SourceSCC == &TargetSCC) { +    SourceN->setEdgeKind(TargetN, Edge::Call); +    return false; // No new cycle. +  } + +  // At this point we leverage the postorder list of SCCs to detect when the +  // insertion of an edge changes the SCC structure in any way. +  // +  // First and foremost, we can eliminate the need for any changes when the +  // edge is toward the beginning of the postorder sequence because all edges +  // flow in that direction already. Thus adding a new one cannot form a cycle. +  int SourceIdx = SCCIndices[&SourceSCC]; +  int TargetIdx = SCCIndices[&TargetSCC]; +  if (TargetIdx < SourceIdx) { +    SourceN->setEdgeKind(TargetN, Edge::Call); +    return false; // No new cycle. +  } + +  // Compute the SCCs which (transitively) reach the source. +  auto ComputeSourceConnectedSet = [&](SmallPtrSetImpl<SCC *> &ConnectedSet) { +#ifndef NDEBUG +    // Check that the RefSCC is still valid before computing this as the +    // results will be nonsensical of we've broken its invariants. +    verify(); +#endif +    ConnectedSet.insert(&SourceSCC); +    auto IsConnected = [&](SCC &C) { +      for (Node &N : C) +        for (Edge &E : N->calls()) +          if (ConnectedSet.count(G->lookupSCC(E.getNode()))) +            return true; + +      return false; +    }; + +    for (SCC *C : +         make_range(SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1)) +      if (IsConnected(*C)) +        ConnectedSet.insert(C); +  }; + +  // Use a normal worklist to find which SCCs the target connects to. We still +  // bound the search based on the range in the postorder list we care about, +  // but because this is forward connectivity we just "recurse" through the +  // edges. +  auto ComputeTargetConnectedSet = [&](SmallPtrSetImpl<SCC *> &ConnectedSet) { +#ifndef NDEBUG +    // Check that the RefSCC is still valid before computing this as the +    // results will be nonsensical of we've broken its invariants. 
+    verify(); +#endif +    ConnectedSet.insert(&TargetSCC); +    SmallVector<SCC *, 4> Worklist; +    Worklist.push_back(&TargetSCC); +    do { +      SCC &C = *Worklist.pop_back_val(); +      for (Node &N : C) +        for (Edge &E : *N) { +          if (!E.isCall()) +            continue; +          SCC &EdgeC = *G->lookupSCC(E.getNode()); +          if (&EdgeC.getOuterRefSCC() != this) +            // Not in this RefSCC... +            continue; +          if (SCCIndices.find(&EdgeC)->second <= SourceIdx) +            // Not in the postorder sequence between source and target. +            continue; + +          if (ConnectedSet.insert(&EdgeC).second) +            Worklist.push_back(&EdgeC); +        } +    } while (!Worklist.empty()); +  }; + +  // Use a generic helper to update the postorder sequence of SCCs and return +  // a range of any SCCs connected into a cycle by inserting this edge. This +  // routine will also take care of updating the indices into the postorder +  // sequence. +  auto MergeRange = updatePostorderSequenceForEdgeInsertion( +      SourceSCC, TargetSCC, SCCs, SCCIndices, ComputeSourceConnectedSet, +      ComputeTargetConnectedSet); + +  // Run the user's callback on the merged SCCs before we actually merge them. +  if (MergeCB) +    MergeCB(makeArrayRef(MergeRange.begin(), MergeRange.end())); + +  // If the merge range is empty, then adding the edge didn't actually form any +  // new cycles. We're done. +  if (MergeRange.begin() == MergeRange.end()) { +    // Now that the SCC structure is finalized, flip the kind to call. +    SourceN->setEdgeKind(TargetN, Edge::Call); +    return false; // No new cycle. +  } + +#ifndef NDEBUG +  // Before merging, check that the RefSCC remains valid after all the +  // postorder updates. +  verify(); +#endif + +  // Otherwise we need to merge all of the SCCs in the cycle into a single +  // result SCC. +  // +  // NB: We merge into the target because all of these functions were already +  // reachable from the target, meaning any SCC-wide properties deduced about it +  // other than the set of functions within it will not have changed. +  for (SCC *C : MergeRange) { +    assert(C != &TargetSCC && +           "We merge *into* the target and shouldn't process it here!"); +    SCCIndices.erase(C); +    TargetSCC.Nodes.append(C->Nodes.begin(), C->Nodes.end()); +    for (Node *N : C->Nodes) +      G->SCCMap[N] = &TargetSCC; +    C->clear(); +    DeletedSCCs.push_back(C); +  } + +  // Erase the merged SCCs from the list and update the indices of the +  // remaining SCCs. +  int IndexOffset = MergeRange.end() - MergeRange.begin(); +  auto EraseEnd = SCCs.erase(MergeRange.begin(), MergeRange.end()); +  for (SCC *C : make_range(EraseEnd, SCCs.end())) +    SCCIndices[C] -= IndexOffset; + +  // Now that the SCC structure is finalized, flip the kind to call. +  SourceN->setEdgeKind(TargetN, Edge::Call); + +  // And we're done, but we did form a new cycle. +  return true; +} + +void LazyCallGraph::RefSCC::switchTrivialInternalEdgeToRef(Node &SourceN, +                                                           Node &TargetN) { +  assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); + +#ifndef NDEBUG +  // In a debug build, verify the RefSCC is valid to start with and when this +  // routine finishes. 
+  verify(); +  auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + +  assert(G->lookupRefSCC(SourceN) == this && +         "Source must be in this RefSCC."); +  assert(G->lookupRefSCC(TargetN) == this && +         "Target must be in this RefSCC."); +  assert(G->lookupSCC(SourceN) != G->lookupSCC(TargetN) && +         "Source and Target must be in separate SCCs for this to be trivial!"); + +  // Set the edge kind. +  SourceN->setEdgeKind(TargetN, Edge::Ref); +} + +iterator_range<LazyCallGraph::RefSCC::iterator> +LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, Node &TargetN) { +  assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); + +#ifndef NDEBUG +  // In a debug build, verify the RefSCC is valid to start with and when this +  // routine finishes. +  verify(); +  auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + +  assert(G->lookupRefSCC(SourceN) == this && +         "Source must be in this RefSCC."); +  assert(G->lookupRefSCC(TargetN) == this && +         "Target must be in this RefSCC."); + +  SCC &TargetSCC = *G->lookupSCC(TargetN); +  assert(G->lookupSCC(SourceN) == &TargetSCC && "Source and Target must be in " +                                                "the same SCC to require the " +                                                "full CG update."); + +  // Set the edge kind. +  SourceN->setEdgeKind(TargetN, Edge::Ref); + +  // Otherwise we are removing a call edge from a single SCC. This may break +  // the cycle. In order to compute the new set of SCCs, we need to do a small +  // DFS over the nodes within the SCC to form any sub-cycles that remain as +  // distinct SCCs and compute a postorder over the resulting SCCs. +  // +  // However, we specially handle the target node. The target node is known to +  // reach all other nodes in the original SCC by definition. This means that +  // we want the old SCC to be replaced with an SCC containing that node as it +  // will be the root of whatever SCC DAG results from the DFS. Assumptions +  // about an SCC such as the set of functions called will continue to hold, +  // etc. + +  SCC &OldSCC = TargetSCC; +  SmallVector<std::pair<Node *, EdgeSequence::call_iterator>, 16> DFSStack; +  SmallVector<Node *, 16> PendingSCCStack; +  SmallVector<SCC *, 4> NewSCCs; + +  // Prepare the nodes for a fresh DFS. +  SmallVector<Node *, 16> Worklist; +  Worklist.swap(OldSCC.Nodes); +  for (Node *N : Worklist) { +    N->DFSNumber = N->LowLink = 0; +    G->SCCMap.erase(N); +  } + +  // Force the target node to be in the old SCC. This also enables us to take +  // a very significant short-cut in the standard Tarjan walk to re-form SCCs +  // below: whenever we build an edge that reaches the target node, we know +  // that the target node eventually connects back to all other nodes in our +  // walk. As a consequence, we can detect and handle participants in that +  // cycle without walking all the edges that form this connection, and instead +  // by relying on the fundamental guarantee coming into this operation (all +  // nodes are reachable from the target due to previously forming an SCC). +  TargetN.DFSNumber = TargetN.LowLink = -1; +  OldSCC.Nodes.push_back(&TargetN); +  G->SCCMap[&TargetN] = &OldSCC; + +  // Scan down the stack and DFS across the call edges. 
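+  // What follows is an iterative formulation of Tarjan's SCC algorithm:
+  // DFSStack carries the current DFS path along with the next call edge to
+  // visit for each node, PendingSCCStack collects finished nodes until a
+  // root (LowLink == DFSNumber) is popped, and any node found to reach the
+  // target is folded straight into OldSCC via the short-cut described above.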
+  for (Node *RootN : Worklist) { +    assert(DFSStack.empty() && +           "Cannot begin a new root with a non-empty DFS stack!"); +    assert(PendingSCCStack.empty() && +           "Cannot begin a new root with pending nodes for an SCC!"); + +    // Skip any nodes we've already reached in the DFS. +    if (RootN->DFSNumber != 0) { +      assert(RootN->DFSNumber == -1 && +             "Shouldn't have any mid-DFS root nodes!"); +      continue; +    } + +    RootN->DFSNumber = RootN->LowLink = 1; +    int NextDFSNumber = 2; + +    DFSStack.push_back({RootN, (*RootN)->call_begin()}); +    do { +      Node *N; +      EdgeSequence::call_iterator I; +      std::tie(N, I) = DFSStack.pop_back_val(); +      auto E = (*N)->call_end(); +      while (I != E) { +        Node &ChildN = I->getNode(); +        if (ChildN.DFSNumber == 0) { +          // We haven't yet visited this child, so descend, pushing the current +          // node onto the stack. +          DFSStack.push_back({N, I}); + +          assert(!G->SCCMap.count(&ChildN) && +                 "Found a node with 0 DFS number but already in an SCC!"); +          ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; +          N = &ChildN; +          I = (*N)->call_begin(); +          E = (*N)->call_end(); +          continue; +        } + +        // Check for the child already being part of some component. +        if (ChildN.DFSNumber == -1) { +          if (G->lookupSCC(ChildN) == &OldSCC) { +            // If the child is part of the old SCC, we know that it can reach +            // every other node, so we have formed a cycle. Pull the entire DFS +            // and pending stacks into it. See the comment above about setting +            // up the old SCC for why we do this. +            int OldSize = OldSCC.size(); +            OldSCC.Nodes.push_back(N); +            OldSCC.Nodes.append(PendingSCCStack.begin(), PendingSCCStack.end()); +            PendingSCCStack.clear(); +            while (!DFSStack.empty()) +              OldSCC.Nodes.push_back(DFSStack.pop_back_val().first); +            for (Node &N : make_range(OldSCC.begin() + OldSize, OldSCC.end())) { +              N.DFSNumber = N.LowLink = -1; +              G->SCCMap[&N] = &OldSCC; +            } +            N = nullptr; +            break; +          } + +          // If the child has already been added to some child component, it +          // couldn't impact the low-link of this parent because it isn't +          // connected, and thus its low-link isn't relevant so skip it. +          ++I; +          continue; +        } + +        // Track the lowest linked child as the lowest link for this node. +        assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); +        if (ChildN.LowLink < N->LowLink) +          N->LowLink = ChildN.LowLink; + +        // Move to the next edge. +        ++I; +      } +      if (!N) +        // Cleared the DFS early, start another round. +        break; + +      // We've finished processing N and its descendants, put it on our pending +      // SCC stack to eventually get merged into an SCC of nodes. +      PendingSCCStack.push_back(N); + +      // If this node is linked to some lower entry, continue walking up the +      // stack. +      if (N->LowLink != N->DFSNumber) +        continue; + +      // Otherwise, we've completed an SCC. Append it to our post order list of +      // SCCs. 
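+      // For illustration: if PendingSCCStack holds nodes with DFS numbers
+      // [1, 3, 4] from bottom to top and the root's DFS number is 3, the
+      // scan below peels off the nodes numbered 3 and 4 as the new SCC and
+      // leaves the node numbered 1 pending for a later root.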
+      int RootDFSNumber = N->DFSNumber; +      // Find the range of the node stack by walking down until we pass the +      // root DFS number. +      auto SCCNodes = make_range( +          PendingSCCStack.rbegin(), +          find_if(reverse(PendingSCCStack), [RootDFSNumber](const Node *N) { +            return N->DFSNumber < RootDFSNumber; +          })); + +      // Form a new SCC out of these nodes and then clear them off our pending +      // stack. +      NewSCCs.push_back(G->createSCC(*this, SCCNodes)); +      for (Node &N : *NewSCCs.back()) { +        N.DFSNumber = N.LowLink = -1; +        G->SCCMap[&N] = NewSCCs.back(); +      } +      PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); +    } while (!DFSStack.empty()); +  } + +  // Insert the remaining SCCs before the old one. The old SCC can reach all +  // other SCCs we form because it contains the target node of the removed edge +  // of the old SCC. This means that we will have edges into all of the new +  // SCCs, which means the old one must come last for postorder. +  int OldIdx = SCCIndices[&OldSCC]; +  SCCs.insert(SCCs.begin() + OldIdx, NewSCCs.begin(), NewSCCs.end()); + +  // Update the mapping from SCC* to index to use the new SCC*s, and remove the +  // old SCC from the mapping. +  for (int Idx = OldIdx, Size = SCCs.size(); Idx < Size; ++Idx) +    SCCIndices[SCCs[Idx]] = Idx; + +  return make_range(SCCs.begin() + OldIdx, +                    SCCs.begin() + OldIdx + NewSCCs.size()); +} + +void LazyCallGraph::RefSCC::switchOutgoingEdgeToCall(Node &SourceN, +                                                     Node &TargetN) { +  assert(!(*SourceN)[TargetN].isCall() && "Must start with a ref edge!"); + +  assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); +  assert(G->lookupRefSCC(TargetN) != this && +         "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS +  assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && +         "Target must be a descendant of the Source."); +#endif + +  // Edges between RefSCCs are the same regardless of call or ref, so we can +  // just flip the edge here. +  SourceN->setEdgeKind(TargetN, Edge::Call); + +#ifndef NDEBUG +  // Check that the RefSCC is still valid. +  verify(); +#endif +} + +void LazyCallGraph::RefSCC::switchOutgoingEdgeToRef(Node &SourceN, +                                                    Node &TargetN) { +  assert((*SourceN)[TargetN].isCall() && "Must start with a call edge!"); + +  assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); +  assert(G->lookupRefSCC(TargetN) != this && +         "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS +  assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && +         "Target must be a descendant of the Source."); +#endif + +  // Edges between RefSCCs are the same regardless of call or ref, so we can +  // just flip the edge here. +  SourceN->setEdgeKind(TargetN, Edge::Ref); + +#ifndef NDEBUG +  // Check that the RefSCC is still valid. +  verify(); +#endif +} + +void LazyCallGraph::RefSCC::insertInternalRefEdge(Node &SourceN, +                                                  Node &TargetN) { +  assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); +  assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); + +  SourceN->insertEdgeInternal(TargetN, Edge::Ref); + +#ifndef NDEBUG +  // Check that the RefSCC is still valid. 
+  verify(); +#endif +} + +void LazyCallGraph::RefSCC::insertOutgoingEdge(Node &SourceN, Node &TargetN, +                                               Edge::Kind EK) { +  // First insert it into the caller. +  SourceN->insertEdgeInternal(TargetN, EK); + +  assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + +  assert(G->lookupRefSCC(TargetN) != this && +         "Target must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS +  assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && +         "Target must be a descendant of the Source."); +#endif + +#ifndef NDEBUG +  // Check that the RefSCC is still valid. +  verify(); +#endif +} + +SmallVector<LazyCallGraph::RefSCC *, 1> +LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { +  assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); +  RefSCC &SourceC = *G->lookupRefSCC(SourceN); +  assert(&SourceC != this && "Source must not be in this RefSCC."); +#ifdef EXPENSIVE_CHECKS +  assert(SourceC.isDescendantOf(*this) && +         "Source must be a descendant of the Target."); +#endif + +  SmallVector<RefSCC *, 1> DeletedRefSCCs; + +#ifndef NDEBUG +  // In a debug build, verify the RefSCC is valid to start with and when this +  // routine finishes. +  verify(); +  auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + +  int SourceIdx = G->RefSCCIndices[&SourceC]; +  int TargetIdx = G->RefSCCIndices[this]; +  assert(SourceIdx < TargetIdx && +         "Postorder list doesn't see edge as incoming!"); + +  // Compute the RefSCCs which (transitively) reach the source. We do this by +  // working backwards from the source using the parent set in each RefSCC, +  // skipping any RefSCCs that don't fall in the postorder range. This has the +  // advantage of walking the sparser parent edge (in high fan-out graphs) but +  // more importantly this removes examining all forward edges in all RefSCCs +  // within the postorder range which aren't in fact connected. Only connected +  // RefSCCs (and their edges) are visited here. +  auto ComputeSourceConnectedSet = [&](SmallPtrSetImpl<RefSCC *> &Set) { +    Set.insert(&SourceC); +    auto IsConnected = [&](RefSCC &RC) { +      for (SCC &C : RC) +        for (Node &N : C) +          for (Edge &E : *N) +            if (Set.count(G->lookupRefSCC(E.getNode()))) +              return true; + +      return false; +    }; + +    for (RefSCC *C : make_range(G->PostOrderRefSCCs.begin() + SourceIdx + 1, +                                G->PostOrderRefSCCs.begin() + TargetIdx + 1)) +      if (IsConnected(*C)) +        Set.insert(C); +  }; + +  // Use a normal worklist to find which SCCs the target connects to. We still +  // bound the search based on the range in the postorder list we care about, +  // but because this is forward connectivity we just "recurse" through the +  // edges. +  auto ComputeTargetConnectedSet = [&](SmallPtrSetImpl<RefSCC *> &Set) { +    Set.insert(this); +    SmallVector<RefSCC *, 4> Worklist; +    Worklist.push_back(this); +    do { +      RefSCC &RC = *Worklist.pop_back_val(); +      for (SCC &C : RC) +        for (Node &N : C) +          for (Edge &E : *N) { +            RefSCC &EdgeRC = *G->lookupRefSCC(E.getNode()); +            if (G->getRefSCCIndex(EdgeRC) <= SourceIdx) +              // Not in the postorder sequence between source and target. 
+              continue; + +            if (Set.insert(&EdgeRC).second) +              Worklist.push_back(&EdgeRC); +          } +    } while (!Worklist.empty()); +  }; + +  // Use a generic helper to update the postorder sequence of RefSCCs and return +  // a range of any RefSCCs connected into a cycle by inserting this edge. This +  // routine will also take care of updating the indices into the postorder +  // sequence. +  iterator_range<SmallVectorImpl<RefSCC *>::iterator> MergeRange = +      updatePostorderSequenceForEdgeInsertion( +          SourceC, *this, G->PostOrderRefSCCs, G->RefSCCIndices, +          ComputeSourceConnectedSet, ComputeTargetConnectedSet); + +  // Build a set so we can do fast tests for whether a RefSCC will end up as +  // part of the merged RefSCC. +  SmallPtrSet<RefSCC *, 16> MergeSet(MergeRange.begin(), MergeRange.end()); + +  // This RefSCC will always be part of that set, so just insert it here. +  MergeSet.insert(this); + +  // Now that we have identified all of the SCCs which need to be merged into +  // a connected set with the inserted edge, merge all of them into this SCC. +  SmallVector<SCC *, 16> MergedSCCs; +  int SCCIndex = 0; +  for (RefSCC *RC : MergeRange) { +    assert(RC != this && "We're merging into the target RefSCC, so it " +                         "shouldn't be in the range."); + +    // Walk the inner SCCs to update their up-pointer and walk all the edges to +    // update any parent sets. +    // FIXME: We should try to find a way to avoid this (rather expensive) edge +    // walk by updating the parent sets in some other manner. +    for (SCC &InnerC : *RC) { +      InnerC.OuterRefSCC = this; +      SCCIndices[&InnerC] = SCCIndex++; +      for (Node &N : InnerC) +        G->SCCMap[&N] = &InnerC; +    } + +    // Now merge in the SCCs. We can actually move here so try to reuse storage +    // the first time through. +    if (MergedSCCs.empty()) +      MergedSCCs = std::move(RC->SCCs); +    else +      MergedSCCs.append(RC->SCCs.begin(), RC->SCCs.end()); +    RC->SCCs.clear(); +    DeletedRefSCCs.push_back(RC); +  } + +  // Append our original SCCs to the merged list and move it into place. +  for (SCC &InnerC : *this) +    SCCIndices[&InnerC] = SCCIndex++; +  MergedSCCs.append(SCCs.begin(), SCCs.end()); +  SCCs = std::move(MergedSCCs); + +  // Remove the merged away RefSCCs from the post order sequence. +  for (RefSCC *RC : MergeRange) +    G->RefSCCIndices.erase(RC); +  int IndexOffset = MergeRange.end() - MergeRange.begin(); +  auto EraseEnd = +      G->PostOrderRefSCCs.erase(MergeRange.begin(), MergeRange.end()); +  for (RefSCC *RC : make_range(EraseEnd, G->PostOrderRefSCCs.end())) +    G->RefSCCIndices[RC] -= IndexOffset; + +  // At this point we have a merged RefSCC with a post-order SCCs list, just +  // connect the nodes to form the new edge. +  SourceN->insertEdgeInternal(TargetN, Edge::Ref); + +  // We return the list of SCCs which were merged so that callers can +  // invalidate any data they have associated with those SCCs. Note that these +  // SCCs are no longer in an interesting state (they are totally empty) but +  // the pointers will remain stable for the life of the graph itself. 
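+  // A typical caller is expected to walk this list and drop any cached
+  // analysis results or worklist entries keyed on the now-empty RefSCCs;
+  // only their pointer identity remains meaningful after the merge.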
+  return DeletedRefSCCs; +} + +void LazyCallGraph::RefSCC::removeOutgoingEdge(Node &SourceN, Node &TargetN) { +  assert(G->lookupRefSCC(SourceN) == this && +         "The source must be a member of this RefSCC."); +  assert(G->lookupRefSCC(TargetN) != this && +         "The target must not be a member of this RefSCC"); + +#ifndef NDEBUG +  // In a debug build, verify the RefSCC is valid to start with and when this +  // routine finishes. +  verify(); +  auto VerifyOnExit = make_scope_exit([&]() { verify(); }); +#endif + +  // First remove it from the node. +  bool Removed = SourceN->removeEdgeInternal(TargetN); +  (void)Removed; +  assert(Removed && "Target not in the edge set for this caller?"); +} + +SmallVector<LazyCallGraph::RefSCC *, 1> +LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, +                                             ArrayRef<Node *> TargetNs) { +  // We return a list of the resulting *new* RefSCCs in post-order. +  SmallVector<RefSCC *, 1> Result; + +#ifndef NDEBUG +  // In a debug build, verify the RefSCC is valid to start with and that either +  // we return an empty list of result RefSCCs and this RefSCC remains valid, +  // or we return new RefSCCs and this RefSCC is dead. +  verify(); +  auto VerifyOnExit = make_scope_exit([&]() { +    // If we didn't replace our RefSCC with new ones, check that this one +    // remains valid. +    if (G) +      verify(); +  }); +#endif + +  // First remove the actual edges. +  for (Node *TargetN : TargetNs) { +    assert(!(*SourceN)[*TargetN].isCall() && +           "Cannot remove a call edge, it must first be made a ref edge"); + +    bool Removed = SourceN->removeEdgeInternal(*TargetN); +    (void)Removed; +    assert(Removed && "Target not in the edge set for this caller?"); +  } + +  // Direct self references don't impact the ref graph at all. +  if (llvm::all_of(TargetNs, +                   [&](Node *TargetN) { return &SourceN == TargetN; })) +    return Result; + +  // If all targets are in the same SCC as the source, because no call edges +  // were removed there is no RefSCC structure change. +  SCC &SourceC = *G->lookupSCC(SourceN); +  if (llvm::all_of(TargetNs, [&](Node *TargetN) { +        return G->lookupSCC(*TargetN) == &SourceC; +      })) +    return Result; + +  // We build somewhat synthetic new RefSCCs by providing a postorder mapping +  // for each inner SCC. We store these inside the low-link field of the nodes +  // rather than associated with SCCs because this saves a round-trip through +  // the node->SCC map and in the common case, SCCs are small. We will verify +  // that we always give the same number to every node in the SCC such that +  // these are equivalent. +  int PostOrderNumber = 0; + +  // Reset all the other nodes to prepare for a DFS over them, and add them to +  // our worklist. +  SmallVector<Node *, 8> Worklist; +  for (SCC *C : SCCs) { +    for (Node &N : *C) +      N.DFSNumber = N.LowLink = 0; + +    Worklist.append(C->Nodes.begin(), C->Nodes.end()); +  } + +  // Track the number of nodes in this RefSCC so that we can quickly recognize +  // an important special case of the edge removal not breaking the cycle of +  // this RefSCC. 
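The DFS that follows is an iterative Tarjan walk: an explicit DFS stack of (node, resume-position) pairs plus a pending stack from which a completed SCC is sliced off once a root (LowLink == DFSNumber) is reached. Here is a compact self-contained version of the same pattern, with plain ints and an adjacency list standing in for Node and EdgeSequence; it is a sketch of the technique, not the LLVM implementation.

// Iterative Tarjan SCC with an explicit DFS stack and pending stack.
#include <cassert>
#include <utility>
#include <vector>

struct NodeState { int DFSNumber = 0, LowLink = 0; };

static std::vector<std::vector<int>>
findSCCs(const std::vector<std::vector<int>> &Adj) {
  const int N = (int)Adj.size();
  int NextDFS = 1;
  std::vector<NodeState> S(N);
  std::vector<std::pair<int, int>> DFSStack; // (node, next edge index)
  std::vector<int> Pending;                  // finished nodes awaiting an SCC
  std::vector<std::vector<int>> SCCs;

  for (int Root = 0; Root < N; ++Root) {
    if (S[Root].DFSNumber != 0)
      continue;                              // already visited or committed
    S[Root].DFSNumber = S[Root].LowLink = NextDFS++;
    DFSStack.push_back({Root, 0});
    while (!DFSStack.empty()) {
      int Cur = DFSStack.back().first;
      int I = DFSStack.back().second;
      DFSStack.pop_back();
      while (I < (int)Adj[Cur].size()) {
        int Child = Adj[Cur][I];
        if (S[Child].DFSNumber == 0) {
          // Descend. Re-examine this same edge when Cur resumes so the
          // child's final low-link is folded into Cur's.
          DFSStack.push_back({Cur, I});
          S[Child].DFSNumber = S[Child].LowLink = NextDFS++;
          Cur = Child;
          I = 0;
          continue;
        }
        // A low-link of -1 marks nodes already committed to an SCC.
        if (S[Child].LowLink >= 0 && S[Child].LowLink < S[Cur].LowLink)
          S[Cur].LowLink = S[Child].LowLink;
        ++I;
      }
      Pending.push_back(Cur);
      if (S[Cur].LowLink != S[Cur].DFSNumber)
        continue;                            // not an SCC root yet
      // Cur roots an SCC: slice off everything with a DFS number >= ours.
      int RootDFS = S[Cur].DFSNumber;
      std::vector<int> SCC;
      while (!Pending.empty() && S[Pending.back()].DFSNumber >= RootDFS) {
        S[Pending.back()].DFSNumber = S[Pending.back()].LowLink = -1;
        SCC.push_back(Pending.back());
        Pending.pop_back();
      }
      SCCs.push_back(SCC);
    }
  }
  return SCCs;
}

int main() {
  // 0 -> 1 -> 2 -> 0 forms one SCC; 3 only refers into that cycle.
  std::vector<std::vector<int>> Adj = {{1}, {2}, {0}, {0}};
  assert(findSCCs(Adj).size() == 2);
  return 0;
}

Note how the parent deliberately re-visits the edge it descended through, which is the same trick the code below relies on to propagate a child's low-link without recursion.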
+  const int NumRefSCCNodes = Worklist.size(); + +  SmallVector<std::pair<Node *, EdgeSequence::iterator>, 4> DFSStack; +  SmallVector<Node *, 4> PendingRefSCCStack; +  do { +    assert(DFSStack.empty() && +           "Cannot begin a new root with a non-empty DFS stack!"); +    assert(PendingRefSCCStack.empty() && +           "Cannot begin a new root with pending nodes for an SCC!"); + +    Node *RootN = Worklist.pop_back_val(); +    // Skip any nodes we've already reached in the DFS. +    if (RootN->DFSNumber != 0) { +      assert(RootN->DFSNumber == -1 && +             "Shouldn't have any mid-DFS root nodes!"); +      continue; +    } + +    RootN->DFSNumber = RootN->LowLink = 1; +    int NextDFSNumber = 2; + +    DFSStack.push_back({RootN, (*RootN)->begin()}); +    do { +      Node *N; +      EdgeSequence::iterator I; +      std::tie(N, I) = DFSStack.pop_back_val(); +      auto E = (*N)->end(); + +      assert(N->DFSNumber != 0 && "We should always assign a DFS number " +                                  "before processing a node."); + +      while (I != E) { +        Node &ChildN = I->getNode(); +        if (ChildN.DFSNumber == 0) { +          // Mark that we should start at this child when next this node is the +          // top of the stack. We don't start at the next child to ensure this +          // child's lowlink is reflected. +          DFSStack.push_back({N, I}); + +          // Continue, resetting to the child node. +          ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; +          N = &ChildN; +          I = ChildN->begin(); +          E = ChildN->end(); +          continue; +        } +        if (ChildN.DFSNumber == -1) { +          // If this child isn't currently in this RefSCC, no need to process +          // it. +          ++I; +          continue; +        } + +        // Track the lowest link of the children, if any are still in the stack. +        // Any child not on the stack will have a LowLink of -1. +        assert(ChildN.LowLink != 0 && +               "Low-link must not be zero with a non-zero DFS number."); +        if (ChildN.LowLink >= 0 && ChildN.LowLink < N->LowLink) +          N->LowLink = ChildN.LowLink; +        ++I; +      } + +      // We've finished processing N and its descendants, put it on our pending +      // stack to eventually get merged into a RefSCC. +      PendingRefSCCStack.push_back(N); + +      // If this node is linked to some lower entry, continue walking up the +      // stack. +      if (N->LowLink != N->DFSNumber) { +        assert(!DFSStack.empty() && +               "We never found a viable root for a RefSCC to pop off!"); +        continue; +      } + +      // Otherwise, form a new RefSCC from the top of the pending node stack. +      int RefSCCNumber = PostOrderNumber++; +      int RootDFSNumber = N->DFSNumber; + +      // Find the range of the node stack by walking down until we pass the +      // root DFS number. Update the DFS numbers and low link numbers in the +      // process to avoid re-walking this list where possible. +      auto StackRI = find_if(reverse(PendingRefSCCStack), [&](Node *N) { +        if (N->DFSNumber < RootDFSNumber) +          // We've found the bottom. +          return true; + +        // Update this node and keep scanning. +        N->DFSNumber = -1; +        // Save the post-order number in the lowlink field so that we can use +        // it to map SCCs into new RefSCCs after we finish the DFS. 
+        N->LowLink = RefSCCNumber; +        return false; +      }); +      auto RefSCCNodes = make_range(StackRI.base(), PendingRefSCCStack.end()); + +      // If we find a cycle containing all nodes originally in this RefSCC then +      // the removal hasn't changed the structure at all. This is an important +      // special case and we can directly exit the entire routine more +      // efficiently as soon as we discover it. +      if (llvm::size(RefSCCNodes) == NumRefSCCNodes) { +        // Clear out the low link field as we won't need it. +        for (Node *N : RefSCCNodes) +          N->LowLink = -1; +        // Return the empty result immediately. +        return Result; +      } + +      // We've already marked the nodes internally with the RefSCC number so +      // just clear them off the stack and continue. +      PendingRefSCCStack.erase(RefSCCNodes.begin(), PendingRefSCCStack.end()); +    } while (!DFSStack.empty()); + +    assert(DFSStack.empty() && "Didn't flush the entire DFS stack!"); +    assert(PendingRefSCCStack.empty() && "Didn't flush all pending nodes!"); +  } while (!Worklist.empty()); + +  assert(PostOrderNumber > 1 && +         "Should never finish the DFS when the existing RefSCC remains valid!"); + +  // Otherwise we create a collection of new RefSCC nodes and build +  // a radix-sort style map from postorder number to these new RefSCCs. We then +  // append SCCs to each of these RefSCCs in the order they occurred in the +  // original SCCs container. +  for (int i = 0; i < PostOrderNumber; ++i) +    Result.push_back(G->createRefSCC(*G)); + +  // Insert the resulting postorder sequence into the global graph postorder +  // sequence before the current RefSCC in that sequence, and then remove the +  // current one. +  // +  // FIXME: It'd be nice to change the APIs so that we returned an iterator +  // range over the global postorder sequence and generally use that sequence +  // rather than building a separate result vector here. +  int Idx = G->getRefSCCIndex(*this); +  G->PostOrderRefSCCs.erase(G->PostOrderRefSCCs.begin() + Idx); +  G->PostOrderRefSCCs.insert(G->PostOrderRefSCCs.begin() + Idx, Result.begin(), +                             Result.end()); +  for (int i : seq<int>(Idx, G->PostOrderRefSCCs.size())) +    G->RefSCCIndices[G->PostOrderRefSCCs[i]] = i; + +  for (SCC *C : SCCs) { +    // We store the SCC number in the node's low-link field above. +    int SCCNumber = C->begin()->LowLink; +    // Clear out all of the SCC's node's low-link fields now that we're done +    // using them as side-storage. +    for (Node &N : *C) { +      assert(N.LowLink == SCCNumber && +             "Cannot have different numbers for nodes in the same SCC!"); +      N.LowLink = -1; +    } + +    RefSCC &RC = *Result[SCCNumber]; +    int SCCIndex = RC.SCCs.size(); +    RC.SCCs.push_back(C); +    RC.SCCIndices[C] = SCCIndex; +    C->OuterRefSCC = &RC; +  } + +  // Now that we've moved things into the new RefSCCs, clear out our current +  // one. +  G = nullptr; +  SCCs.clear(); +  SCCIndices.clear(); + +#ifndef NDEBUG +  // Verify the new RefSCCs we've built. +  for (RefSCC *RC : Result) +    RC->verify(); +#endif + +  // Return the new list of SCCs. +  return Result; +} + +void LazyCallGraph::RefSCC::handleTrivialEdgeInsertion(Node &SourceN, +                                                       Node &TargetN) { +  // The only trivial case that requires any graph updates is when we add new +  // ref edge and may connect different RefSCCs along that path. 
This is only +  // because of the parents set. Every other part of the graph remains constant +  // after this edge insertion. +  assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); +  RefSCC &TargetRC = *G->lookupRefSCC(TargetN); +  if (&TargetRC == this) +    return; + +#ifdef EXPENSIVE_CHECKS +  assert(TargetRC.isDescendantOf(*this) && +         "Target must be a descendant of the Source."); +#endif +} + +void LazyCallGraph::RefSCC::insertTrivialCallEdge(Node &SourceN, +                                                  Node &TargetN) { +#ifndef NDEBUG +  // Check that the RefSCC is still valid when we finish. +  auto ExitVerifier = make_scope_exit([this] { verify(); }); + +#ifdef EXPENSIVE_CHECKS +  // Check that we aren't breaking some invariants of the SCC graph. Note that +  // this is quadratic in the number of edges in the call graph! +  SCC &SourceC = *G->lookupSCC(SourceN); +  SCC &TargetC = *G->lookupSCC(TargetN); +  if (&SourceC != &TargetC) +    assert(SourceC.isAncestorOf(TargetC) && +           "Call edge is not trivial in the SCC graph!"); +#endif // EXPENSIVE_CHECKS +#endif // NDEBUG + +  // First insert it into the source or find the existing edge. +  auto InsertResult = +      SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()}); +  if (!InsertResult.second) { +    // Already an edge, just update it. +    Edge &E = SourceN->Edges[InsertResult.first->second]; +    if (E.isCall()) +      return; // Nothing to do! +    E.setKind(Edge::Call); +  } else { +    // Create the new edge. +    SourceN->Edges.emplace_back(TargetN, Edge::Call); +  } + +  // Now that we have the edge, handle the graph fallout. +  handleTrivialEdgeInsertion(SourceN, TargetN); +} + +void LazyCallGraph::RefSCC::insertTrivialRefEdge(Node &SourceN, Node &TargetN) { +#ifndef NDEBUG +  // Check that the RefSCC is still valid when we finish. +  auto ExitVerifier = make_scope_exit([this] { verify(); }); + +#ifdef EXPENSIVE_CHECKS +  // Check that we aren't breaking some invariants of the RefSCC graph. +  RefSCC &SourceRC = *G->lookupRefSCC(SourceN); +  RefSCC &TargetRC = *G->lookupRefSCC(TargetN); +  if (&SourceRC != &TargetRC) +    assert(SourceRC.isAncestorOf(TargetRC) && +           "Ref edge is not trivial in the RefSCC graph!"); +#endif // EXPENSIVE_CHECKS +#endif // NDEBUG + +  // First insert it into the source or find the existing edge. +  auto InsertResult = +      SourceN->EdgeIndexMap.insert({&TargetN, SourceN->Edges.size()}); +  if (!InsertResult.second) +    // Already an edge, we're done. +    return; + +  // Create the new edge. +  SourceN->Edges.emplace_back(TargetN, Edge::Ref); + +  // Now that we have the edge, handle the graph fallout. +  handleTrivialEdgeInsertion(SourceN, TargetN); +} + +void LazyCallGraph::RefSCC::replaceNodeFunction(Node &N, Function &NewF) { +  Function &OldF = N.getFunction(); + +#ifndef NDEBUG +  // Check that the RefSCC is still valid when we finish. +  auto ExitVerifier = make_scope_exit([this] { verify(); }); + +  assert(G->lookupRefSCC(N) == this && +         "Cannot replace the function of a node outside this RefSCC."); + +  assert(G->NodeMap.find(&NewF) == G->NodeMap.end() && +         "Must not have already walked the new function!'"); + +  // It is important that this replacement not introduce graph changes so we +  // insist that the caller has already removed every use of the original +  // function and that all uses of the new function correspond to existing +  // edges in the graph. 
The common and expected way to use this is when +  // replacing the function itself in the IR without changing the call graph +  // shape and just updating the analysis based on that. +  assert(&OldF != &NewF && "Cannot replace a function with itself!"); +  assert(OldF.use_empty() && +         "Must have moved all uses from the old function to the new!"); +#endif + +  N.replaceFunction(NewF); + +  // Update various call graph maps. +  G->NodeMap.erase(&OldF); +  G->NodeMap[&NewF] = &N; +} + +void LazyCallGraph::insertEdge(Node &SourceN, Node &TargetN, Edge::Kind EK) { +  assert(SCCMap.empty() && +         "This method cannot be called after SCCs have been formed!"); + +  return SourceN->insertEdgeInternal(TargetN, EK); +} + +void LazyCallGraph::removeEdge(Node &SourceN, Node &TargetN) { +  assert(SCCMap.empty() && +         "This method cannot be called after SCCs have been formed!"); + +  bool Removed = SourceN->removeEdgeInternal(TargetN); +  (void)Removed; +  assert(Removed && "Target not in the edge set for this caller?"); +} + +void LazyCallGraph::removeDeadFunction(Function &F) { +  // FIXME: This is unnecessarily restrictive. We should be able to remove +  // functions which recursively call themselves. +  assert(F.use_empty() && +         "This routine should only be called on trivially dead functions!"); + +  // We shouldn't remove library functions as they are never really dead while +  // the call graph is in use -- every function definition refers to them. +  assert(!isLibFunction(F) && +         "Must not remove lib functions from the call graph!"); + +  auto NI = NodeMap.find(&F); +  if (NI == NodeMap.end()) +    // Not in the graph at all! +    return; + +  Node &N = *NI->second; +  NodeMap.erase(NI); + +  // Remove this from the entry edges if present. +  EntryEdges.removeEdgeInternal(N); + +  if (SCCMap.empty()) { +    // No SCCs have been formed, so removing this is fine and there is nothing +    // else necessary at this point but clearing out the node. +    N.clear(); +    return; +  } + +  // Cannot remove a function which has yet to be visited in the DFS walk, so +  // if we have a node at all then we must have an SCC and RefSCC. +  auto CI = SCCMap.find(&N); +  assert(CI != SCCMap.end() && +         "Tried to remove a node without an SCC after DFS walk started!"); +  SCC &C = *CI->second; +  SCCMap.erase(CI); +  RefSCC &RC = C.getOuterRefSCC(); + +  // This node must be the only member of its SCC as it has no callers, and +  // that SCC must be the only member of a RefSCC as it has no references. +  // Validate these properties first. +  assert(C.size() == 1 && "Dead functions must be in a singular SCC"); +  assert(RC.size() == 1 && "Dead functions must be in a singular RefSCC"); + +  auto RCIndexI = RefSCCIndices.find(&RC); +  int RCIndex = RCIndexI->second; +  PostOrderRefSCCs.erase(PostOrderRefSCCs.begin() + RCIndex); +  RefSCCIndices.erase(RCIndexI); +  for (int i = RCIndex, Size = PostOrderRefSCCs.size(); i < Size; ++i) +    RefSCCIndices[PostOrderRefSCCs[i]] = i; + +  // Finally clear out all the data structures from the node down through the +  // components. +  N.clear(); +  N.G = nullptr; +  N.F = nullptr; +  C.clear(); +  RC.clear(); +  RC.G = nullptr; + +  // Nothing to delete as all the objects are allocated in stable bump pointer +  // allocators. 
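The insertTrivialCallEdge and insertTrivialRefEdge routines above both rely on the same pattern: an edge vector paired with a map from target to index, where the map insertion is attempted speculatively with the would-be vector index. A minimal standalone sketch of that pattern, with made-up types in place of LLVM's Node and Edge:

// Insert a new call edge, or promote an existing ref edge in place.
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

struct Edge {
  int Target;
  enum Kind { Ref, Call } K;
};

struct Node {
  std::vector<Edge> Edges;
  std::unordered_map<int, std::size_t> EdgeIndexMap; // target -> index

  void insertOrPromoteCallEdge(int Target) {
    // Speculatively claim the next slot; if the target is already present,
    // insert() fails and hands back the existing index instead.
    auto InsertResult = EdgeIndexMap.insert({Target, Edges.size()});
    if (!InsertResult.second) {
      Edge &E = Edges[InsertResult.first->second];
      if (E.K == Edge::Call)
        return;             // Already a call edge: nothing to do.
      E.K = Edge::Call;     // Promote the existing ref edge.
      return;
    }
    Edges.push_back({Target, Edge::Call});
  }
};

int main() {
  Node N;
  N.Edges.push_back({7, Edge::Ref});
  N.EdgeIndexMap[7] = 0;
  N.insertOrPromoteCallEdge(7); // promotes the ref edge in place
  N.insertOrPromoteCallEdge(9); // appends a brand new call edge
  assert(N.Edges.size() == 2 && N.Edges[0].K == Edge::Call);
  return 0;
}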
+} + +LazyCallGraph::Node &LazyCallGraph::insertInto(Function &F, Node *&MappedN) { +  return *new (MappedN = BPA.Allocate()) Node(*this, F); +} + +void LazyCallGraph::updateGraphPtrs() { +  // Walk the node map to update their graph pointers. While this iterates in +  // an unstable order, the order has no effect so it remains correct. +  for (auto &FunctionNodePair : NodeMap) +    FunctionNodePair.second->G = this; + +  for (auto *RC : PostOrderRefSCCs) +    RC->G = this; +} + +template <typename RootsT, typename GetBeginT, typename GetEndT, +          typename GetNodeT, typename FormSCCCallbackT> +void LazyCallGraph::buildGenericSCCs(RootsT &&Roots, GetBeginT &&GetBegin, +                                     GetEndT &&GetEnd, GetNodeT &&GetNode, +                                     FormSCCCallbackT &&FormSCC) { +  using EdgeItT = decltype(GetBegin(std::declval<Node &>())); + +  SmallVector<std::pair<Node *, EdgeItT>, 16> DFSStack; +  SmallVector<Node *, 16> PendingSCCStack; + +  // Scan down the stack and DFS across the call edges. +  for (Node *RootN : Roots) { +    assert(DFSStack.empty() && +           "Cannot begin a new root with a non-empty DFS stack!"); +    assert(PendingSCCStack.empty() && +           "Cannot begin a new root with pending nodes for an SCC!"); + +    // Skip any nodes we've already reached in the DFS. +    if (RootN->DFSNumber != 0) { +      assert(RootN->DFSNumber == -1 && +             "Shouldn't have any mid-DFS root nodes!"); +      continue; +    } + +    RootN->DFSNumber = RootN->LowLink = 1; +    int NextDFSNumber = 2; + +    DFSStack.push_back({RootN, GetBegin(*RootN)}); +    do { +      Node *N; +      EdgeItT I; +      std::tie(N, I) = DFSStack.pop_back_val(); +      auto E = GetEnd(*N); +      while (I != E) { +        Node &ChildN = GetNode(I); +        if (ChildN.DFSNumber == 0) { +          // We haven't yet visited this child, so descend, pushing the current +          // node onto the stack. +          DFSStack.push_back({N, I}); + +          ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; +          N = &ChildN; +          I = GetBegin(*N); +          E = GetEnd(*N); +          continue; +        } + +        // If the child has already been added to some child component, it +        // couldn't impact the low-link of this parent because it isn't +        // connected, and thus its low-link isn't relevant so skip it. +        if (ChildN.DFSNumber == -1) { +          ++I; +          continue; +        } + +        // Track the lowest linked child as the lowest link for this node. +        assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); +        if (ChildN.LowLink < N->LowLink) +          N->LowLink = ChildN.LowLink; + +        // Move to the next edge. +        ++I; +      } + +      // We've finished processing N and its descendants, put it on our pending +      // SCC stack to eventually get merged into an SCC of nodes. +      PendingSCCStack.push_back(N); + +      // If this node is linked to some lower entry, continue walking up the +      // stack. +      if (N->LowLink != N->DFSNumber) +        continue; + +      // Otherwise, we've completed an SCC. Append it to our post order list of +      // SCCs. +      int RootDFSNumber = N->DFSNumber; +      // Find the range of the node stack by walking down until we pass the +      // root DFS number. 
+      auto SCCNodes = make_range( +          PendingSCCStack.rbegin(), +          find_if(reverse(PendingSCCStack), [RootDFSNumber](const Node *N) { +            return N->DFSNumber < RootDFSNumber; +          })); +      // Form a new SCC out of these nodes and then clear them off our pending +      // stack. +      FormSCC(SCCNodes); +      PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); +    } while (!DFSStack.empty()); +  } +} + +/// Build the internal SCCs for a RefSCC from a sequence of nodes. +/// +/// Appends the SCCs to the provided vector and updates the map with their +/// indices. Both the vector and map must be empty when passed into this +/// routine. +void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) { +  assert(RC.SCCs.empty() && "Already built SCCs!"); +  assert(RC.SCCIndices.empty() && "Already mapped SCC indices!"); + +  for (Node *N : Nodes) { +    assert(N->LowLink >= (*Nodes.begin())->LowLink && +           "We cannot have a low link in an SCC lower than its root on the " +           "stack!"); + +    // This node will go into the next RefSCC, clear out its DFS and low link +    // as we scan. +    N->DFSNumber = N->LowLink = 0; +  } + +  // Each RefSCC contains a DAG of the call SCCs. To build these, we do +  // a direct walk of the call edges using Tarjan's algorithm. We reuse the +  // internal storage as we won't need it for the outer graph's DFS any longer. +  buildGenericSCCs( +      Nodes, [](Node &N) { return N->call_begin(); }, +      [](Node &N) { return N->call_end(); }, +      [](EdgeSequence::call_iterator I) -> Node & { return I->getNode(); }, +      [this, &RC](node_stack_range Nodes) { +        RC.SCCs.push_back(createSCC(RC, Nodes)); +        for (Node &N : *RC.SCCs.back()) { +          N.DFSNumber = N.LowLink = -1; +          SCCMap[&N] = RC.SCCs.back(); +        } +      }); + +  // Wire up the SCC indices. +  for (int i = 0, Size = RC.SCCs.size(); i < Size; ++i) +    RC.SCCIndices[RC.SCCs[i]] = i; +} + +void LazyCallGraph::buildRefSCCs() { +  if (EntryEdges.empty() || !PostOrderRefSCCs.empty()) +    // RefSCCs are either non-existent or already built! +    return; + +  assert(RefSCCIndices.empty() && "Already mapped RefSCC indices!"); + +  SmallVector<Node *, 16> Roots; +  for (Edge &E : *this) +    Roots.push_back(&E.getNode()); + +  // The roots will be popped of a stack, so use reverse to get a less +  // surprising order. This doesn't change any of the semantics anywhere. +  std::reverse(Roots.begin(), Roots.end()); + +  buildGenericSCCs( +      Roots, +      [](Node &N) { +        // We need to populate each node as we begin to walk its edges. +        N.populate(); +        return N->begin(); +      }, +      [](Node &N) { return N->end(); }, +      [](EdgeSequence::iterator I) -> Node & { return I->getNode(); }, +      [this](node_stack_range Nodes) { +        RefSCC *NewRC = createRefSCC(*this); +        buildSCCs(*NewRC, Nodes); + +        // Push the new node into the postorder list and remember its position +        // in the index map. 
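buildGenericSCCs is one traversal engine parameterized by callables, so the same code can walk all ref edges when forming RefSCCs and only call edges when forming SCCs. The shape of that parameterization is sketched below with a toy graph; the types and names are invented for illustration and only the "one engine, several edge views" idea mirrors the code above.

// One generic traversal, instantiated with different neighbor views.
#include <iostream>
#include <vector>

struct Edge { int To; bool IsCall; };
using Graph = std::vector<std::vector<Edge>>; // node -> outgoing edges

template <typename GetNeighborsT, typename VisitT>
void depthFirst(const Graph &G, int Root, GetNeighborsT GetNeighbors,
                VisitT Visit, std::vector<bool> &Seen) {
  if (Seen[Root]) return;
  Seen[Root] = true;
  Visit(Root);
  for (int Next : GetNeighbors(G[Root]))
    depthFirst(G, Next, GetNeighbors, Visit, Seen);
}

int main() {
  Graph G = {{{1, true}, {2, false}}, {{2, true}}, {}};
  std::vector<bool> SeenAll(G.size()), SeenCalls(G.size());

  // View 1: follow every edge (the "ref graph" view).
  auto AllEdges = [](const std::vector<Edge> &Es) {
    std::vector<int> Out;
    for (const Edge &E : Es) Out.push_back(E.To);
    return Out;
  };
  // View 2: follow only call edges (the "call graph" view).
  auto CallEdges = [](const std::vector<Edge> &Es) {
    std::vector<int> Out;
    for (const Edge &E : Es) if (E.IsCall) Out.push_back(E.To);
    return Out;
  };

  depthFirst(G, 0, AllEdges,
             [](int N) { std::cout << "ref-reach " << N << "\n"; }, SeenAll);
  depthFirst(G, 0, CallEdges,
             [](int N) { std::cout << "call-reach " << N << "\n"; }, SeenCalls);
  return 0;
}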
+        bool Inserted = +            RefSCCIndices.insert({NewRC, PostOrderRefSCCs.size()}).second; +        (void)Inserted; +        assert(Inserted && "Cannot already have this RefSCC in the index map!"); +        PostOrderRefSCCs.push_back(NewRC); +#ifndef NDEBUG +        NewRC->verify(); +#endif +      }); +} + +AnalysisKey LazyCallGraphAnalysis::Key; + +LazyCallGraphPrinterPass::LazyCallGraphPrinterPass(raw_ostream &OS) : OS(OS) {} + +static void printNode(raw_ostream &OS, LazyCallGraph::Node &N) { +  OS << "  Edges in function: " << N.getFunction().getName() << "\n"; +  for (LazyCallGraph::Edge &E : N.populate()) +    OS << "    " << (E.isCall() ? "call" : "ref ") << " -> " +       << E.getFunction().getName() << "\n"; + +  OS << "\n"; +} + +static void printSCC(raw_ostream &OS, LazyCallGraph::SCC &C) { +  ptrdiff_t Size = size(C); +  OS << "    SCC with " << Size << " functions:\n"; + +  for (LazyCallGraph::Node &N : C) +    OS << "      " << N.getFunction().getName() << "\n"; +} + +static void printRefSCC(raw_ostream &OS, LazyCallGraph::RefSCC &C) { +  ptrdiff_t Size = size(C); +  OS << "  RefSCC with " << Size << " call SCCs:\n"; + +  for (LazyCallGraph::SCC &InnerC : C) +    printSCC(OS, InnerC); + +  OS << "\n"; +} + +PreservedAnalyses LazyCallGraphPrinterPass::run(Module &M, +                                                ModuleAnalysisManager &AM) { +  LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M); + +  OS << "Printing the call graph for module: " << M.getModuleIdentifier() +     << "\n\n"; + +  for (Function &F : M) +    printNode(OS, G.get(F)); + +  G.buildRefSCCs(); +  for (LazyCallGraph::RefSCC &C : G.postorder_ref_sccs()) +    printRefSCC(OS, C); + +  return PreservedAnalyses::all(); +} + +LazyCallGraphDOTPrinterPass::LazyCallGraphDOTPrinterPass(raw_ostream &OS) +    : OS(OS) {} + +static void printNodeDOT(raw_ostream &OS, LazyCallGraph::Node &N) { +  std::string Name = "\"" + DOT::EscapeString(N.getFunction().getName()) + "\""; + +  for (LazyCallGraph::Edge &E : N.populate()) { +    OS << "  " << Name << " -> \"" +       << DOT::EscapeString(E.getFunction().getName()) << "\""; +    if (!E.isCall()) // It is a ref edge. +      OS << " [style=dashed,label=\"ref\"]"; +    OS << ";\n"; +  } + +  OS << "\n"; +} + +PreservedAnalyses LazyCallGraphDOTPrinterPass::run(Module &M, +                                                   ModuleAnalysisManager &AM) { +  LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M); + +  OS << "digraph \"" << DOT::EscapeString(M.getModuleIdentifier()) << "\" {\n"; + +  for (Function &F : M) +    printNodeDOT(OS, G.get(F)); + +  OS << "}\n"; + +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/LazyValueInfo.cpp b/contrib/llvm/lib/Analysis/LazyValueInfo.cpp new file mode 100644 index 000000000000..ee0148e0d795 --- /dev/null +++ b/contrib/llvm/lib/Analysis/LazyValueInfo.cpp @@ -0,0 +1,1920 @@ +//===- LazyValueInfo.cpp - Value constraint analysis ------------*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the interface for lazy computation of value constraint +// information. 
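The DOT printer earlier in this chunk emits one edge per line and marks ref edges with a dashed style. The output shape it produces can be reproduced with a few lines of plain C++; the tiny edge model below is illustrative only and is not LLVM's printer.

// Emit a digraph with dashed, labeled ref edges, in the spirit of
// printNodeDOT above.
#include <iostream>
#include <string>
#include <vector>

struct DotEdge { std::string From, To; bool IsCall; };

static void emitDot(const std::vector<DotEdge> &Edges, std::ostream &OS) {
  OS << "digraph \"m.ll\" {\n";
  for (const DotEdge &E : Edges) {
    OS << "  \"" << E.From << "\" -> \"" << E.To << "\"";
    if (!E.IsCall)                 // ref edges are dashed and labeled
      OS << " [style=dashed,label=\"ref\"]";
    OS << ";\n";
  }
  OS << "}\n";
}

int main() {
  emitDot({{"main", "helper", true}, {"main", "callback", false}}, std::cout);
  return 0;
}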
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/LazyValueInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/ValueLattice.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+#include <map>
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lazy-value-info"
+
+// This is the number of worklist items we will process to try to discover an
+// answer for a given value.
+static const unsigned MaxProcessedPerValue = 500;
+
+char LazyValueInfoWrapperPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LazyValueInfoWrapperPass, "lazy-value-info",
+                "Lazy Value Information Analysis", false, true)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(LazyValueInfoWrapperPass, "lazy-value-info",
+                "Lazy Value Information Analysis", false, true)
+
+namespace llvm {
+  FunctionPass *createLazyValueInfoPass() { return new LazyValueInfoWrapperPass(); }
+}
+
+AnalysisKey LazyValueAnalysis::Key;
+
+/// Returns true if this lattice value represents at most one possible value.
+/// This is as precise as any lattice value can get while still representing
+/// reachable code.
+static bool hasSingleValue(const ValueLatticeElement &Val) {
+  if (Val.isConstantRange() &&
+      Val.getConstantRange().isSingleElement())
+    // Integer constants are single element ranges
+    return true;
+  if (Val.isConstant())
+    // Non integer constants
+    return true;
+  return false;
+}
+
+/// Combine two sets of facts about the same value into a single set of
+/// facts.  Note that this method is not suitable for merging facts along
+/// different paths in a CFG; that's what the mergeIn function is for.  This
+/// is for merging facts gathered about the same value at the same location
+/// through two independent means.
+/// Notes:
+/// * This method does not promise to return the most precise possible lattice
+///   value implied by A and B.  It is allowed to return any lattice element
+///   which is at least as strong as *either* A or B (unless our facts
+///   conflict, see below).
+/// * Due to unreachable code, the intersection of two lattice values could be
+///   contradictory.  If this happens, we return some valid lattice value so as
+///   not to confuse the rest of LVI.  Ideally, we'd always return Undefined, but
+///   we do not make this guarantee.  TODO: This would be a useful enhancement.
+static ValueLatticeElement intersect(const ValueLatticeElement &A,
+                                     const ValueLatticeElement &B) {
+  // Undefined is the strongest state.  It means the value is known to be along
+  // an unreachable path.
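The documentation above draws the key distinction between combining two facts about the same value at the same point (intersection) and merging facts arriving along different CFG paths (union); the body of intersect() itself continues below. As a standalone illustration of why the two operations differ, consider plain integer intervals; this toy Interval type is not ValueLatticeElement.

// Two facts that both hold here combine by intersection; facts from
// different paths combine by union (widened to the enclosing interval).
#include <algorithm>
#include <cassert>

struct Interval { int Lo, Hi; }; // inclusive, assumed non-empty for brevity

static Interval intersectFacts(Interval A, Interval B) {
  // "x is in A" and "x is in B" both hold at this point.
  return {std::max(A.Lo, B.Lo), std::min(A.Hi, B.Hi)};
}

static Interval mergeAcrossPaths(Interval A, Interval B) {
  // Along one path x is in A, along the other x is in B.
  return {std::min(A.Lo, B.Lo), std::max(A.Hi, B.Hi)};
}

int main() {
  Interval FromRangeMD = {0, 100}; // e.g. a fact derived from metadata
  Interval FromGuard = {10, 200};  // e.g. a fact from a dominating condition
  Interval Here = intersectFacts(FromRangeMD, FromGuard);
  assert(Here.Lo == 10 && Here.Hi == 100);
  Interval Joined = mergeAcrossPaths({0, 5}, {50, 60});
  assert(Joined.Lo == 0 && Joined.Hi == 60);
  return 0;
}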
+  if (A.isUndefined()) +    return A; +  if (B.isUndefined()) +    return B; + +  // If we gave up for one, but got a useable fact from the other, use it. +  if (A.isOverdefined()) +    return B; +  if (B.isOverdefined()) +    return A; + +  // Can't get any more precise than constants. +  if (hasSingleValue(A)) +    return A; +  if (hasSingleValue(B)) +    return B; + +  // Could be either constant range or not constant here. +  if (!A.isConstantRange() || !B.isConstantRange()) { +    // TODO: Arbitrary choice, could be improved +    return A; +  } + +  // Intersect two constant ranges +  ConstantRange Range = +    A.getConstantRange().intersectWith(B.getConstantRange()); +  // Note: An empty range is implicitly converted to overdefined internally. +  // TODO: We could instead use Undefined here since we've proven a conflict +  // and thus know this path must be unreachable. +  return ValueLatticeElement::getRange(std::move(Range)); +} + +//===----------------------------------------------------------------------===// +//                          LazyValueInfoCache Decl +//===----------------------------------------------------------------------===// + +namespace { +  /// A callback value handle updates the cache when values are erased. +  class LazyValueInfoCache; +  struct LVIValueHandle final : public CallbackVH { +    // Needs to access getValPtr(), which is protected. +    friend struct DenseMapInfo<LVIValueHandle>; + +    LazyValueInfoCache *Parent; + +    LVIValueHandle(Value *V, LazyValueInfoCache *P) +      : CallbackVH(V), Parent(P) { } + +    void deleted() override; +    void allUsesReplacedWith(Value *V) override { +      deleted(); +    } +  }; +} // end anonymous namespace + +namespace { +  /// This is the cache kept by LazyValueInfo which +  /// maintains information about queries across the clients' queries. +  class LazyValueInfoCache { +    /// This is all of the cached block information for exactly one Value*. +    /// The entries are sorted by the BasicBlock* of the +    /// entries, allowing us to do a lookup with a binary search. +    /// Over-defined lattice values are recorded in OverDefinedCache to reduce +    /// memory overhead. +    struct ValueCacheEntryTy { +      ValueCacheEntryTy(Value *V, LazyValueInfoCache *P) : Handle(V, P) {} +      LVIValueHandle Handle; +      SmallDenseMap<PoisoningVH<BasicBlock>, ValueLatticeElement, 4> BlockVals; +    }; + +    /// This tracks, on a per-block basis, the set of values that are +    /// over-defined at the end of that block. +    typedef DenseMap<PoisoningVH<BasicBlock>, SmallPtrSet<Value *, 4>> +        OverDefinedCacheTy; +    /// Keep track of all blocks that we have ever seen, so we +    /// don't spend time removing unused blocks from our caches. +    DenseSet<PoisoningVH<BasicBlock> > SeenBlocks; + +    /// This is all of the cached information for all values, +    /// mapped from Value* to key information. +    DenseMap<Value *, std::unique_ptr<ValueCacheEntryTy>> ValueCache; +    OverDefinedCacheTy OverDefinedCache; + + +  public: +    void insertResult(Value *Val, BasicBlock *BB, +                      const ValueLatticeElement &Result) { +      SeenBlocks.insert(BB); + +      // Insert over-defined values into their own cache to reduce memory +      // overhead. 
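The cache declared above is deliberately two-tiered: overdefined results are recorded only as membership in a per-block set, while everything else gets a full per-value, per-block lattice entry (insertResult's real body continues below). A simplified standalone model of that split, with strings standing in for Value*, BasicBlock* and the lattice value:

// Two-tier result cache: cheap per-block "overdefined" sets plus a full
// per-value map of per-block results. Types here are illustrative only.
#include <cassert>
#include <map>
#include <set>
#include <string>

struct MiniLVICache {
  // block -> values known only to be overdefined there
  std::map<std::string, std::set<std::string>> OverDefined;
  // value -> (block -> cached result, encoded as a string for brevity)
  std::map<std::string, std::map<std::string, std::string>> ValueCache;

  void insertResult(const std::string &Val, const std::string &BB,
                    const std::string &Result, bool IsOverdefined) {
    if (IsOverdefined)
      OverDefined[BB].insert(Val);  // one set entry, no lattice payload
    else
      ValueCache[Val][BB] = Result; // full per-block result
  }

  bool hasCachedValueInfo(const std::string &Val, const std::string &BB) const {
    auto ODI = OverDefined.find(BB);
    if (ODI != OverDefined.end() && ODI->second.count(Val))
      return true;
    auto I = ValueCache.find(Val);
    return I != ValueCache.end() && I->second.count(BB);
  }
};

int main() {
  MiniLVICache C;
  C.insertResult("%x", "entry", "[0,10)", /*IsOverdefined=*/false);
  C.insertResult("%y", "loop", "", /*IsOverdefined=*/true);
  assert(C.hasCachedValueInfo("%x", "entry") && C.hasCachedValueInfo("%y", "loop"));
  assert(!C.hasCachedValueInfo("%x", "loop"));
  return 0;
}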
+      if (Result.isOverdefined()) +        OverDefinedCache[BB].insert(Val); +      else { +        auto It = ValueCache.find_as(Val); +        if (It == ValueCache.end()) { +          ValueCache[Val] = make_unique<ValueCacheEntryTy>(Val, this); +          It = ValueCache.find_as(Val); +          assert(It != ValueCache.end() && "Val was just added to the map!"); +        } +        It->second->BlockVals[BB] = Result; +      } +    } + +    bool isOverdefined(Value *V, BasicBlock *BB) const { +      auto ODI = OverDefinedCache.find(BB); + +      if (ODI == OverDefinedCache.end()) +        return false; + +      return ODI->second.count(V); +    } + +    bool hasCachedValueInfo(Value *V, BasicBlock *BB) const { +      if (isOverdefined(V, BB)) +        return true; + +      auto I = ValueCache.find_as(V); +      if (I == ValueCache.end()) +        return false; + +      return I->second->BlockVals.count(BB); +    } + +    ValueLatticeElement getCachedValueInfo(Value *V, BasicBlock *BB) const { +      if (isOverdefined(V, BB)) +        return ValueLatticeElement::getOverdefined(); + +      auto I = ValueCache.find_as(V); +      if (I == ValueCache.end()) +        return ValueLatticeElement(); +      auto BBI = I->second->BlockVals.find(BB); +      if (BBI == I->second->BlockVals.end()) +        return ValueLatticeElement(); +      return BBI->second; +    } + +    /// clear - Empty the cache. +    void clear() { +      SeenBlocks.clear(); +      ValueCache.clear(); +      OverDefinedCache.clear(); +    } + +    /// Inform the cache that a given value has been deleted. +    void eraseValue(Value *V); + +    /// This is part of the update interface to inform the cache +    /// that a block has been deleted. +    void eraseBlock(BasicBlock *BB); + +    /// Updates the cache to remove any influence an overdefined value in +    /// OldSucc might have (unless also overdefined in NewSucc).  This just +    /// flushes elements from the cache and does not add any. +    void threadEdgeImpl(BasicBlock *OldSucc,BasicBlock *NewSucc); + +    friend struct LVIValueHandle; +  }; +} + +void LazyValueInfoCache::eraseValue(Value *V) { +  for (auto I = OverDefinedCache.begin(), E = OverDefinedCache.end(); I != E;) { +    // Copy and increment the iterator immediately so we can erase behind +    // ourselves. +    auto Iter = I++; +    SmallPtrSetImpl<Value *> &ValueSet = Iter->second; +    ValueSet.erase(V); +    if (ValueSet.empty()) +      OverDefinedCache.erase(Iter); +  } + +  ValueCache.erase(V); +} + +void LVIValueHandle::deleted() { +  // This erasure deallocates *this, so it MUST happen after we're done +  // using any and all members of *this. +  Parent->eraseValue(*this); +} + +void LazyValueInfoCache::eraseBlock(BasicBlock *BB) { +  // Shortcut if we have never seen this block. +  DenseSet<PoisoningVH<BasicBlock> >::iterator I = SeenBlocks.find(BB); +  if (I == SeenBlocks.end()) +    return; +  SeenBlocks.erase(I); + +  auto ODI = OverDefinedCache.find(BB); +  if (ODI != OverDefinedCache.end()) +    OverDefinedCache.erase(ODI); + +  for (auto &I : ValueCache) +    I.second->BlockVals.erase(BB); +} + +void LazyValueInfoCache::threadEdgeImpl(BasicBlock *OldSucc, +                                        BasicBlock *NewSucc) { +  // When an edge in the graph has been threaded, values that we could not +  // determine a value for before (i.e. were marked overdefined) may be +  // possible to solve now. We do NOT try to proactively update these values. 
+  // Instead, we clear their entries from the cache, and allow lazy updating to +  // recompute them when needed. + +  // The updating process is fairly simple: we need to drop cached info +  // for all values that were marked overdefined in OldSucc, and for those same +  // values in any successor of OldSucc (except NewSucc) in which they were +  // also marked overdefined. +  std::vector<BasicBlock*> worklist; +  worklist.push_back(OldSucc); + +  auto I = OverDefinedCache.find(OldSucc); +  if (I == OverDefinedCache.end()) +    return; // Nothing to process here. +  SmallVector<Value *, 4> ValsToClear(I->second.begin(), I->second.end()); + +  // Use a worklist to perform a depth-first search of OldSucc's successors. +  // NOTE: We do not need a visited list since any blocks we have already +  // visited will have had their overdefined markers cleared already, and we +  // thus won't loop to their successors. +  while (!worklist.empty()) { +    BasicBlock *ToUpdate = worklist.back(); +    worklist.pop_back(); + +    // Skip blocks only accessible through NewSucc. +    if (ToUpdate == NewSucc) continue; + +    // If a value was marked overdefined in OldSucc, and is here too... +    auto OI = OverDefinedCache.find(ToUpdate); +    if (OI == OverDefinedCache.end()) +      continue; +    SmallPtrSetImpl<Value *> &ValueSet = OI->second; + +    bool changed = false; +    for (Value *V : ValsToClear) { +      if (!ValueSet.erase(V)) +        continue; + +      // If we removed anything, then we potentially need to update +      // blocks successors too. +      changed = true; + +      if (ValueSet.empty()) { +        OverDefinedCache.erase(OI); +        break; +      } +    } + +    if (!changed) continue; + +    worklist.insert(worklist.end(), succ_begin(ToUpdate), succ_end(ToUpdate)); +  } +} + + +namespace { +/// An assembly annotator class to print LazyValueCache information in +/// comments. +class LazyValueInfoImpl; +class LazyValueInfoAnnotatedWriter : public AssemblyAnnotationWriter { +  LazyValueInfoImpl *LVIImpl; +  // While analyzing which blocks we can solve values for, we need the dominator +  // information. Since this is an optional parameter in LVI, we require this +  // DomTreeAnalysis pass in the printer pass, and pass the dominator +  // tree to the LazyValueInfoAnnotatedWriter. +  DominatorTree &DT; + +public: +  LazyValueInfoAnnotatedWriter(LazyValueInfoImpl *L, DominatorTree &DTree) +      : LVIImpl(L), DT(DTree) {} + +  virtual void emitBasicBlockStartAnnot(const BasicBlock *BB, +                                        formatted_raw_ostream &OS); + +  virtual void emitInstructionAnnot(const Instruction *I, +                                    formatted_raw_ostream &OS); +}; +} +namespace { +  // The actual implementation of the lazy analysis and update.  Note that the +  // inheritance from LazyValueInfoCache is intended to be temporary while +  // splitting the code and then transitioning to a has-a relationship. +  class LazyValueInfoImpl { + +    /// Cached results from previous queries +    LazyValueInfoCache TheCache; + +    /// This stack holds the state of the value solver during a query. +    /// It basically emulates the callstack of the naive +    /// recursive value lookup process. +    SmallVector<std::pair<BasicBlock*, Value*>, 8> BlockValueStack; + +    /// Keeps track of which block-value pairs are in BlockValueStack. +    DenseSet<std::pair<BasicBlock*, Value*> > BlockValueSet; + +    /// Push BV onto BlockValueStack unless it's already in there. 
+    /// Returns true on success. +    bool pushBlockValue(const std::pair<BasicBlock *, Value *> &BV) { +      if (!BlockValueSet.insert(BV).second) +        return false;  // It's already in the stack. + +      LLVM_DEBUG(dbgs() << "PUSH: " << *BV.second << " in " +                        << BV.first->getName() << "\n"); +      BlockValueStack.push_back(BV); +      return true; +    } + +    AssumptionCache *AC;  ///< A pointer to the cache of @llvm.assume calls. +    const DataLayout &DL; ///< A mandatory DataLayout +    DominatorTree *DT;    ///< An optional DT pointer. +    DominatorTree *DisabledDT; ///< Stores DT if it's disabled. + +  ValueLatticeElement getBlockValue(Value *Val, BasicBlock *BB); +  bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T, +                    ValueLatticeElement &Result, Instruction *CxtI = nullptr); +  bool hasBlockValue(Value *Val, BasicBlock *BB); + +  // These methods process one work item and may add more. A false value +  // returned means that the work item was not completely processed and must +  // be revisited after going through the new items. +  bool solveBlockValue(Value *Val, BasicBlock *BB); +  bool solveBlockValueImpl(ValueLatticeElement &Res, Value *Val, +                           BasicBlock *BB); +  bool solveBlockValueNonLocal(ValueLatticeElement &BBLV, Value *Val, +                               BasicBlock *BB); +  bool solveBlockValuePHINode(ValueLatticeElement &BBLV, PHINode *PN, +                              BasicBlock *BB); +  bool solveBlockValueSelect(ValueLatticeElement &BBLV, SelectInst *S, +                             BasicBlock *BB); +  bool solveBlockValueBinaryOp(ValueLatticeElement &BBLV, BinaryOperator *BBI, +                               BasicBlock *BB); +  bool solveBlockValueCast(ValueLatticeElement &BBLV, CastInst *CI, +                           BasicBlock *BB); +  void intersectAssumeOrGuardBlockValueConstantRange(Value *Val, +                                                     ValueLatticeElement &BBLV, +                                                     Instruction *BBI); + +  void solve(); + +  public: +    /// This is the query interface to determine the lattice +    /// value for the specified Value* at the end of the specified block. +    ValueLatticeElement getValueInBlock(Value *V, BasicBlock *BB, +                                        Instruction *CxtI = nullptr); + +    /// This is the query interface to determine the lattice +    /// value for the specified Value* at the specified instruction (generally +    /// from an assume intrinsic). +    ValueLatticeElement getValueAt(Value *V, Instruction *CxtI); + +    /// This is the query interface to determine the lattice +    /// value for the specified Value* that is true on the specified edge. +    ValueLatticeElement getValueOnEdge(Value *V, BasicBlock *FromBB, +                                       BasicBlock *ToBB, +                                   Instruction *CxtI = nullptr); + +    /// Complete flush all previously computed values +    void clear() { +      TheCache.clear(); +    } + +    /// Printing the LazyValueInfo Analysis. +    void printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) { +        LazyValueInfoAnnotatedWriter Writer(this, DTree); +        F.print(OS, &Writer); +    } + +    /// This is part of the update interface to inform the cache +    /// that a block has been deleted. 
+    void eraseBlock(BasicBlock *BB) { +      TheCache.eraseBlock(BB); +    } + +    /// Disables use of the DominatorTree within LVI. +    void disableDT() { +      if (DT) { +        assert(!DisabledDT && "Both DT and DisabledDT are not nullptr!"); +        std::swap(DT, DisabledDT); +      } +    } + +    /// Enables use of the DominatorTree within LVI. Does nothing if the class +    /// instance was initialized without a DT pointer. +    void enableDT() { +      if (DisabledDT) { +        assert(!DT && "Both DT and DisabledDT are not nullptr!"); +        std::swap(DT, DisabledDT); +      } +    } + +    /// This is the update interface to inform the cache that an edge from +    /// PredBB to OldSucc has been threaded to be from PredBB to NewSucc. +    void threadEdge(BasicBlock *PredBB,BasicBlock *OldSucc,BasicBlock *NewSucc); + +    LazyValueInfoImpl(AssumptionCache *AC, const DataLayout &DL, +                       DominatorTree *DT = nullptr) +        : AC(AC), DL(DL), DT(DT), DisabledDT(nullptr) {} +  }; +} // end anonymous namespace + + +void LazyValueInfoImpl::solve() { +  SmallVector<std::pair<BasicBlock *, Value *>, 8> StartingStack( +      BlockValueStack.begin(), BlockValueStack.end()); + +  unsigned processedCount = 0; +  while (!BlockValueStack.empty()) { +    processedCount++; +    // Abort if we have to process too many values to get a result for this one. +    // Because of the design of the overdefined cache currently being per-block +    // to avoid naming-related issues (IE it wants to try to give different +    // results for the same name in different blocks), overdefined results don't +    // get cached globally, which in turn means we will often try to rediscover +    // the same overdefined result again and again.  Once something like +    // PredicateInfo is used in LVI or CVP, we should be able to make the +    // overdefined cache global, and remove this throttle. +    if (processedCount > MaxProcessedPerValue) { +      LLVM_DEBUG( +          dbgs() << "Giving up on stack because we are getting too deep\n"); +      // Fill in the original values +      while (!StartingStack.empty()) { +        std::pair<BasicBlock *, Value *> &e = StartingStack.back(); +        TheCache.insertResult(e.second, e.first, +                              ValueLatticeElement::getOverdefined()); +        StartingStack.pop_back(); +      } +      BlockValueSet.clear(); +      BlockValueStack.clear(); +      return; +    } +    std::pair<BasicBlock *, Value *> e = BlockValueStack.back(); +    assert(BlockValueSet.count(e) && "Stack value should be in BlockValueSet!"); + +    if (solveBlockValue(e.second, e.first)) { +      // The work item was completely processed. +      assert(BlockValueStack.back() == e && "Nothing should have been pushed!"); +      assert(TheCache.hasCachedValueInfo(e.second, e.first) && +             "Result should be in cache!"); + +      LLVM_DEBUG( +          dbgs() << "POP " << *e.second << " in " << e.first->getName() << " = " +                 << TheCache.getCachedValueInfo(e.second, e.first) << "\n"); + +      BlockValueStack.pop_back(); +      BlockValueSet.erase(e); +    } else { +      // More work needs to be done before revisiting. +      assert(BlockValueStack.back() != e && "Stack should have been pushed!"); +    } +  } +} + +bool LazyValueInfoImpl::hasBlockValue(Value *Val, BasicBlock *BB) { +  // If already a constant, there is nothing to compute. 
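The solve() loop above follows a demand-driven discipline: a work item stays on the stack while the prerequisites it discovers are pushed on top of it, and it is only popped once a visit completes without pushing anything. The toy solver below sketches just that discipline; the integer items and the Deps map are invented stand-ins for (BasicBlock*, Value*) pairs and the block-value dependencies.

// Demand-driven worklist: revisit an item after its prerequisites resolve.
#include <cassert>
#include <map>
#include <set>
#include <vector>

int main() {
  // Item -> items whose results it needs first (acyclic here; the real code
  // settles cycles and excessive work by recording overdefined instead).
  std::map<int, std::vector<int>> Deps = {{0, {1, 2}}, {1, {2}}, {2, {}}};
  std::set<int> Solved;
  std::vector<int> Stack = {0}; // the query we actually care about
  std::set<int> OnStack = {0};  // mirrors BlockValueSet's role above

  while (!Stack.empty()) {
    int Item = Stack.back();
    bool PushedWork = false;
    for (int D : Deps[Item])
      if (!Solved.count(D) && OnStack.insert(D).second) {
        Stack.push_back(D); // explore the prerequisite first
        PushedWork = true;
      }
    if (PushedWork)
      continue;             // revisit Item after the new work is done
    Solved.insert(Item);    // all prerequisites available: finish it
    Stack.pop_back();
    OnStack.erase(Item);
  }
  assert(Solved.count(0) && Solved.count(1) && Solved.count(2));
  return 0;
}

The MaxProcessedPerValue throttle above plays the role of a hard cap on this loop: when too many items pile up for one query, the original items are force-cached as overdefined and the stacks are cleared.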
+  if (isa<Constant>(Val))
+    return true;
+
+  return TheCache.hasCachedValueInfo(Val, BB);
+}
+
+ValueLatticeElement LazyValueInfoImpl::getBlockValue(Value *Val,
+                                                     BasicBlock *BB) {
+  // If already a constant, there is nothing to compute.
+  if (Constant *VC = dyn_cast<Constant>(Val))
+    return ValueLatticeElement::get(VC);
+
+  return TheCache.getCachedValueInfo(Val, BB);
+}
+
+static ValueLatticeElement getFromRangeMetadata(Instruction *BBI) {
+  switch (BBI->getOpcode()) {
+  default: break;
+  case Instruction::Load:
+  case Instruction::Call:
+  case Instruction::Invoke:
+    if (MDNode *Ranges = BBI->getMetadata(LLVMContext::MD_range))
+      if (isa<IntegerType>(BBI->getType())) {
+        return ValueLatticeElement::getRange(
+            getConstantRangeFromMetadata(*Ranges));
+      }
+    break;
+  };
+  // Nothing known - will be intersected with other facts
+  return ValueLatticeElement::getOverdefined();
+}
+
+bool LazyValueInfoImpl::solveBlockValue(Value *Val, BasicBlock *BB) {
+  if (isa<Constant>(Val))
+    return true;
+
+  if (TheCache.hasCachedValueInfo(Val, BB)) {
+    // If we have a cached value, use that.
+    LLVM_DEBUG(dbgs() << "  reuse BB '" << BB->getName() << "' val="
+                      << TheCache.getCachedValueInfo(Val, BB) << '\n');
+
+    // Since we're reusing a cached value, we don't need to update the
+    // OverDefinedCache. The cache will have been properly updated whenever the
+    // cached value was inserted.
+    return true;
+  }
+
+  // Hold off inserting this value into the Cache in case we have to return
+  // false and come back later.
+  ValueLatticeElement Res;
+  if (!solveBlockValueImpl(Res, Val, BB))
+    // Work pushed, will revisit
+    return false;
+
+  TheCache.insertResult(Val, BB, Res);
+  return true;
+}
+
+bool LazyValueInfoImpl::solveBlockValueImpl(ValueLatticeElement &Res,
+                                            Value *Val, BasicBlock *BB) {
+
+  Instruction *BBI = dyn_cast<Instruction>(Val);
+  if (!BBI || BBI->getParent() != BB)
+    return solveBlockValueNonLocal(Res, Val, BB);
+
+  if (PHINode *PN = dyn_cast<PHINode>(BBI))
+    return solveBlockValuePHINode(Res, PN, BB);
+
+  if (auto *SI = dyn_cast<SelectInst>(BBI))
+    return solveBlockValueSelect(Res, SI, BB);
+
+  // If this value is a nonnull pointer, record its range and bail out.  Note
+  // that for all other pointer typed values, we terminate the search at the
+  // definition.  We could easily extend this to look through geps, bitcasts,
+  // and the like to prove non-nullness, but it's not clear that's worth it
+  // compile time wise.  The context-insensitive value walk done inside
+  // isKnownNonZero gets most of the profitable cases at much less expense.
+  // This does mean that we have a sensitivity to where the defining
+  // instruction is placed, even if it could legally be hoisted much higher.
+  // That is unfortunate.
+  PointerType *PT = dyn_cast<PointerType>(BBI->getType()); +  if (PT && isKnownNonZero(BBI, DL)) { +    Res = ValueLatticeElement::getNot(ConstantPointerNull::get(PT)); +    return true; +  } +  if (BBI->getType()->isIntegerTy()) { +    if (auto *CI = dyn_cast<CastInst>(BBI)) +      return solveBlockValueCast(Res, CI, BB); + +    BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI); +    if (BO && isa<ConstantInt>(BO->getOperand(1))) +      return solveBlockValueBinaryOp(Res, BO, BB); +  } + +  LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() +                    << "' - unknown inst def found.\n"); +  Res = getFromRangeMetadata(BBI); +  return true; +} + +static bool InstructionDereferencesPointer(Instruction *I, Value *Ptr) { +  if (LoadInst *L = dyn_cast<LoadInst>(I)) { +    return L->getPointerAddressSpace() == 0 && +           GetUnderlyingObject(L->getPointerOperand(), +                               L->getModule()->getDataLayout()) == Ptr; +  } +  if (StoreInst *S = dyn_cast<StoreInst>(I)) { +    return S->getPointerAddressSpace() == 0 && +           GetUnderlyingObject(S->getPointerOperand(), +                               S->getModule()->getDataLayout()) == Ptr; +  } +  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) { +    if (MI->isVolatile()) return false; + +    // FIXME: check whether it has a valuerange that excludes zero? +    ConstantInt *Len = dyn_cast<ConstantInt>(MI->getLength()); +    if (!Len || Len->isZero()) return false; + +    if (MI->getDestAddressSpace() == 0) +      if (GetUnderlyingObject(MI->getRawDest(), +                              MI->getModule()->getDataLayout()) == Ptr) +        return true; +    if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) +      if (MTI->getSourceAddressSpace() == 0) +        if (GetUnderlyingObject(MTI->getRawSource(), +                                MTI->getModule()->getDataLayout()) == Ptr) +          return true; +  } +  return false; +} + +/// Return true if the allocation associated with Val is ever dereferenced +/// within the given basic block.  This establishes the fact Val is not null, +/// but does not imply that the memory at Val is dereferenceable.  (Val may +/// point off the end of the dereferenceable part of the object.) +static bool isObjectDereferencedInBlock(Value *Val, BasicBlock *BB) { +  assert(Val->getType()->isPointerTy()); + +  const DataLayout &DL = BB->getModule()->getDataLayout(); +  Value *UnderlyingVal = GetUnderlyingObject(Val, DL); +  // If 'GetUnderlyingObject' didn't converge, skip it. It won't converge +  // inside InstructionDereferencesPointer either. +  if (UnderlyingVal == GetUnderlyingObject(UnderlyingVal, DL, 1)) +    for (Instruction &I : *BB) +      if (InstructionDereferencesPointer(&I, UnderlyingVal)) +        return true; +  return false; +} + +bool LazyValueInfoImpl::solveBlockValueNonLocal(ValueLatticeElement &BBLV, +                                                 Value *Val, BasicBlock *BB) { +  ValueLatticeElement Result;  // Start Undefined. + +  // If this is the entry block, we must be asking about an argument.  The +  // value is overdefined. +  if (BB == &BB->getParent()->getEntryBlock()) { +    assert(isa<Argument>(Val) && "Unknown live-in to the entry block"); +    // Before giving up, see if we can prove the pointer non-null local to +    // this particular block. 
+    PointerType *PTy = dyn_cast<PointerType>(Val->getType()); +    if (PTy && +        (isKnownNonZero(Val, DL) || +          (isObjectDereferencedInBlock(Val, BB) && +           !NullPointerIsDefined(BB->getParent(), PTy->getAddressSpace())))) { +      Result = ValueLatticeElement::getNot(ConstantPointerNull::get(PTy)); +    } else { +      Result = ValueLatticeElement::getOverdefined(); +    } +    BBLV = Result; +    return true; +  } + +  // Loop over all of our predecessors, merging what we know from them into +  // result.  If we encounter an unexplored predecessor, we eagerly explore it +  // in a depth first manner.  In practice, this has the effect of discovering +  // paths we can't analyze eagerly without spending compile times analyzing +  // other paths.  This heuristic benefits from the fact that predecessors are +  // frequently arranged such that dominating ones come first and we quickly +  // find a path to function entry.  TODO: We should consider explicitly +  // canonicalizing to make this true rather than relying on this happy +  // accident. +  for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { +    ValueLatticeElement EdgeResult; +    if (!getEdgeValue(Val, *PI, BB, EdgeResult)) +      // Explore that input, then return here +      return false; + +    Result.mergeIn(EdgeResult, DL); + +    // If we hit overdefined, exit early.  The BlockVals entry is already set +    // to overdefined. +    if (Result.isOverdefined()) { +      LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() +                        << "' - overdefined because of pred (non local).\n"); +      // Before giving up, see if we can prove the pointer non-null local to +      // this particular block. +      PointerType *PTy = dyn_cast<PointerType>(Val->getType()); +      if (PTy && isObjectDereferencedInBlock(Val, BB) && +          !NullPointerIsDefined(BB->getParent(), PTy->getAddressSpace())) { +        Result = ValueLatticeElement::getNot(ConstantPointerNull::get(PTy)); +      } + +      BBLV = Result; +      return true; +    } +  } + +  // Return the merged value, which is more precise than 'overdefined'. +  assert(!Result.isOverdefined()); +  BBLV = Result; +  return true; +} + +bool LazyValueInfoImpl::solveBlockValuePHINode(ValueLatticeElement &BBLV, +                                               PHINode *PN, BasicBlock *BB) { +  ValueLatticeElement Result;  // Start Undefined. + +  // Loop over all of our predecessors, merging what we know from them into +  // result.  See the comment about the chosen traversal order in +  // solveBlockValueNonLocal; the same reasoning applies here. +  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { +    BasicBlock *PhiBB = PN->getIncomingBlock(i); +    Value *PhiVal = PN->getIncomingValue(i); +    ValueLatticeElement EdgeResult; +    // Note that we can provide PN as the context value to getEdgeValue, even +    // though the results will be cached, because PN is the value being used as +    // the cache key in the caller. +    if (!getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN)) +      // Explore that input, then return here +      return false; + +    Result.mergeIn(EdgeResult, DL); + +    // If we hit overdefined, exit early.  The BlockVals entry is already set +    // to overdefined. 
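The predecessor loops above (and the PHI merge that follows) all share one shape: start from the undefined state, fold in what each incoming edge proves, and stop as soon as the result degrades to overdefined, since further merging cannot recover precision. A standalone sketch of that loop over a toy interval lattice; this is not ValueLatticeElement, and the bounds are invented for the example.

// Merge per-predecessor facts with an early exit on "overdefined".
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

struct Interval { int Lo, Hi; }; // inclusive bounds

static bool isOverdefined(const Interval &I, int Min, int Max) {
  return I.Lo <= Min && I.Hi >= Max; // covers the whole domain
}

static std::optional<Interval>
mergePreds(const std::vector<Interval> &PredFacts, int Min, int Max) {
  std::optional<Interval> Result; // "undefined" to start
  for (const Interval &Edge : PredFacts) {
    if (!Result)
      Result = Edge;
    else
      Result = Interval{std::min(Result->Lo, Edge.Lo),
                        std::max(Result->Hi, Edge.Hi)};
    if (isOverdefined(*Result, Min, Max))
      return Result;              // early exit: no way back down the lattice
  }
  return Result;
}

int main() {
  auto R = mergePreds({{0, 3}, {10, 12}}, -128, 127);
  assert(R && R->Lo == 0 && R->Hi == 12);
  auto Wide = mergePreds({{-128, 0}, {0, 127}}, -128, 127);
  assert(Wide && isOverdefined(*Wide, -128, 127));
  return 0;
}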
+    if (Result.isOverdefined()) { +      LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() +                        << "' - overdefined because of pred (local).\n"); + +      BBLV = Result; +      return true; +    } +  } + +  // Return the merged value, which is more precise than 'overdefined'. +  assert(!Result.isOverdefined() && "Possible PHI in entry block?"); +  BBLV = Result; +  return true; +} + +static ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, +                                                 bool isTrueDest = true); + +// If we can determine a constraint on the value given conditions assumed by +// the program, intersect those constraints with BBLV +void LazyValueInfoImpl::intersectAssumeOrGuardBlockValueConstantRange( +        Value *Val, ValueLatticeElement &BBLV, Instruction *BBI) { +  BBI = BBI ? BBI : dyn_cast<Instruction>(Val); +  if (!BBI) +    return; + +  for (auto &AssumeVH : AC->assumptionsFor(Val)) { +    if (!AssumeVH) +      continue; +    auto *I = cast<CallInst>(AssumeVH); +    if (!isValidAssumeForContext(I, BBI, DT)) +      continue; + +    BBLV = intersect(BBLV, getValueFromCondition(Val, I->getArgOperand(0))); +  } + +  // If guards are not used in the module, don't spend time looking for them +  auto *GuardDecl = BBI->getModule()->getFunction( +          Intrinsic::getName(Intrinsic::experimental_guard)); +  if (!GuardDecl || GuardDecl->use_empty()) +    return; + +  for (Instruction &I : make_range(BBI->getIterator().getReverse(), +                                   BBI->getParent()->rend())) { +    Value *Cond = nullptr; +    if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(Cond)))) +      BBLV = intersect(BBLV, getValueFromCondition(Val, Cond)); +  } +} + +bool LazyValueInfoImpl::solveBlockValueSelect(ValueLatticeElement &BBLV, +                                              SelectInst *SI, BasicBlock *BB) { + +  // Recurse on our inputs if needed +  if (!hasBlockValue(SI->getTrueValue(), BB)) { +    if (pushBlockValue(std::make_pair(BB, SI->getTrueValue()))) +      return false; +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  } +  ValueLatticeElement TrueVal = getBlockValue(SI->getTrueValue(), BB); +  // If we hit overdefined, don't ask more queries.  We want to avoid poisoning +  // extra slots in the table if we can. +  if (TrueVal.isOverdefined()) { +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  } + +  if (!hasBlockValue(SI->getFalseValue(), BB)) { +    if (pushBlockValue(std::make_pair(BB, SI->getFalseValue()))) +      return false; +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  } +  ValueLatticeElement FalseVal = getBlockValue(SI->getFalseValue(), BB); +  // If we hit overdefined, don't ask more queries.  We want to avoid poisoning +  // extra slots in the table if we can. +  if (FalseVal.isOverdefined()) { +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  } + +  if (TrueVal.isConstantRange() && FalseVal.isConstantRange()) { +    const ConstantRange &TrueCR = TrueVal.getConstantRange(); +    const ConstantRange &FalseCR = FalseVal.getConstantRange(); +    Value *LHS = nullptr; +    Value *RHS = nullptr; +    SelectPatternResult SPR = matchSelectPattern(SI, LHS, RHS); +    // Is this a min specifically of our two inputs?  (Avoid the risk of +    // ValueTracking getting smarter looking back past our immediate inputs.) 
+    if (SelectPatternResult::isMinOrMax(SPR.Flavor) && +        LHS == SI->getTrueValue() && RHS == SI->getFalseValue()) { +      ConstantRange ResultCR = [&]() { +        switch (SPR.Flavor) { +        default: +          llvm_unreachable("unexpected minmax type!"); +        case SPF_SMIN:                   /// Signed minimum +          return TrueCR.smin(FalseCR); +        case SPF_UMIN:                   /// Unsigned minimum +          return TrueCR.umin(FalseCR); +        case SPF_SMAX:                   /// Signed maximum +          return TrueCR.smax(FalseCR); +        case SPF_UMAX:                   /// Unsigned maximum +          return TrueCR.umax(FalseCR); +        }; +      }(); +      BBLV = ValueLatticeElement::getRange(ResultCR); +      return true; +    } + +    // TODO: ABS, NABS from the SelectPatternResult +  } + +  // Can we constrain the facts about the true and false values by using the +  // condition itself?  This shows up with idioms like e.g. select(a > 5, a, 5). +  // TODO: We could potentially refine an overdefined true value above. +  Value *Cond = SI->getCondition(); +  TrueVal = intersect(TrueVal, +                      getValueFromCondition(SI->getTrueValue(), Cond, true)); +  FalseVal = intersect(FalseVal, +                       getValueFromCondition(SI->getFalseValue(), Cond, false)); + +  // Handle clamp idioms such as: +  //   %24 = constantrange<0, 17> +  //   %39 = icmp eq i32 %24, 0 +  //   %40 = add i32 %24, -1 +  //   %siv.next = select i1 %39, i32 16, i32 %40 +  //   %siv.next = constantrange<0, 17> not <-1, 17> +  // In general, this can handle any clamp idiom which tests the edge +  // condition via an equality or inequality. +  if (auto *ICI = dyn_cast<ICmpInst>(Cond)) { +    ICmpInst::Predicate Pred = ICI->getPredicate(); +    Value *A = ICI->getOperand(0); +    if (ConstantInt *CIBase = dyn_cast<ConstantInt>(ICI->getOperand(1))) { +      auto addConstants = [](ConstantInt *A, ConstantInt *B) { +        assert(A->getType() == B->getType()); +        return ConstantInt::get(A->getType(), A->getValue() + B->getValue()); +      }; +      // See if either input is A + C2, subject to the constraint from the +      // condition that A != C when that input is used.  We can assume that +      // that input doesn't include C + C2. +      ConstantInt *CIAdded; +      switch (Pred) { +      default: break; +      case ICmpInst::ICMP_EQ: +        if (match(SI->getFalseValue(), m_Add(m_Specific(A), +                                             m_ConstantInt(CIAdded)))) { +          auto ResNot = addConstants(CIBase, CIAdded); +          FalseVal = intersect(FalseVal, +                               ValueLatticeElement::getNot(ResNot)); +        } +        break; +      case ICmpInst::ICMP_NE: +        if (match(SI->getTrueValue(), m_Add(m_Specific(A), +                                            m_ConstantInt(CIAdded)))) { +          auto ResNot = addConstants(CIBase, CIAdded); +          TrueVal = intersect(TrueVal, +                              ValueLatticeElement::getNot(ResNot)); +        } +        break; +      }; +    } +  } + +  ValueLatticeElement Result;  // Start Undefined. 
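The min/max case above reduces to plain range arithmetic on the two arms. A standalone sketch of what ConstantRange::umin computes, with made-up example ranges (the general merge of the true and false arms follows below):

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include <cassert>
  using namespace llvm;

  void minMaxRangeSketch() {
    // %s = select i1 %c, i8 %a, i8 %b, matched as an unsigned-min pattern,
    // with %a known to be in [0, 17) and %b in [10, 30).
    ConstantRange TrueCR(APInt(8, 0), APInt(8, 17));
    ConstantRange FalseCR(APInt(8, 10), APInt(8, 30));
    // umin is bounded by the smaller of the two lower bounds and the smaller
    // of the two upper bounds, so the select is in [0, 17).
    ConstantRange Res = TrueCR.umin(FalseCR);
    assert(Res.getLower() == 0 && Res.getUpper() == 17);
  }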
+  Result.mergeIn(TrueVal, DL); +  Result.mergeIn(FalseVal, DL); +  BBLV = Result; +  return true; +} + +bool LazyValueInfoImpl::solveBlockValueCast(ValueLatticeElement &BBLV, +                                            CastInst *CI, +                                            BasicBlock *BB) { +  if (!CI->getOperand(0)->getType()->isSized()) { +    // Without knowing how wide the input is, we can't analyze it in any useful +    // way. +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  } + +  // Filter out casts we don't know how to reason about before attempting to +  // recurse on our operand.  This can cut a long search short if we know we're +  // not going to be able to get any useful information anways. +  switch (CI->getOpcode()) { +  case Instruction::Trunc: +  case Instruction::SExt: +  case Instruction::ZExt: +  case Instruction::BitCast: +    break; +  default: +    // Unhandled instructions are overdefined. +    LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() +                      << "' - overdefined (unknown cast).\n"); +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  } + +  // Figure out the range of the LHS.  If that fails, we still apply the +  // transfer rule on the full set since we may be able to locally infer +  // interesting facts. +  if (!hasBlockValue(CI->getOperand(0), BB)) +    if (pushBlockValue(std::make_pair(BB, CI->getOperand(0)))) +      // More work to do before applying this transfer rule. +      return false; + +  const unsigned OperandBitWidth = +    DL.getTypeSizeInBits(CI->getOperand(0)->getType()); +  ConstantRange LHSRange = ConstantRange(OperandBitWidth); +  if (hasBlockValue(CI->getOperand(0), BB)) { +    ValueLatticeElement LHSVal = getBlockValue(CI->getOperand(0), BB); +    intersectAssumeOrGuardBlockValueConstantRange(CI->getOperand(0), LHSVal, +                                                  CI); +    if (LHSVal.isConstantRange()) +      LHSRange = LHSVal.getConstantRange(); +  } + +  const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth(); + +  // NOTE: We're currently limited by the set of operations that ConstantRange +  // can evaluate symbolically.  Enhancing that set will allows us to analyze +  // more definitions. +  BBLV = ValueLatticeElement::getRange(LHSRange.castOp(CI->getOpcode(), +                                                       ResultBitWidth)); +  return true; +} + +bool LazyValueInfoImpl::solveBlockValueBinaryOp(ValueLatticeElement &BBLV, +                                                BinaryOperator *BO, +                                                BasicBlock *BB) { + +  assert(BO->getOperand(0)->getType()->isSized() && +         "all operands to binary operators are sized"); + +  // Filter out operators we don't know how to reason about before attempting to +  // recurse on our operand(s).  This can cut a long search short if we know +  // we're not going to be able to get any useful information anyways. +  switch (BO->getOpcode()) { +  case Instruction::Add: +  case Instruction::Sub: +  case Instruction::Mul: +  case Instruction::UDiv: +  case Instruction::Shl: +  case Instruction::LShr: +  case Instruction::AShr: +  case Instruction::And: +  case Instruction::Or: +    // continue into the code below +    break; +  default: +    // Unhandled instructions are overdefined. 
+    LLVM_DEBUG(dbgs() << " compute BB '" << BB->getName() +                      << "' - overdefined (unknown binary operator).\n"); +    BBLV = ValueLatticeElement::getOverdefined(); +    return true; +  }; + +  // Figure out the range of the LHS.  If that fails, use a conservative range, +  // but apply the transfer rule anyways.  This lets us pick up facts from +  // expressions like "and i32 (call i32 @foo()), 32" +  if (!hasBlockValue(BO->getOperand(0), BB)) +    if (pushBlockValue(std::make_pair(BB, BO->getOperand(0)))) +      // More work to do before applying this transfer rule. +      return false; + +  const unsigned OperandBitWidth = +    DL.getTypeSizeInBits(BO->getOperand(0)->getType()); +  ConstantRange LHSRange = ConstantRange(OperandBitWidth); +  if (hasBlockValue(BO->getOperand(0), BB)) { +    ValueLatticeElement LHSVal = getBlockValue(BO->getOperand(0), BB); +    intersectAssumeOrGuardBlockValueConstantRange(BO->getOperand(0), LHSVal, +                                                  BO); +    if (LHSVal.isConstantRange()) +      LHSRange = LHSVal.getConstantRange(); +  } + +  ConstantInt *RHS = cast<ConstantInt>(BO->getOperand(1)); +  ConstantRange RHSRange = ConstantRange(RHS->getValue()); + +  // NOTE: We're currently limited by the set of operations that ConstantRange +  // can evaluate symbolically.  Enhancing that set will allows us to analyze +  // more definitions. +  Instruction::BinaryOps BinOp = BO->getOpcode(); +  BBLV = ValueLatticeElement::getRange(LHSRange.binaryOp(BinOp, RHSRange)); +  return true; +} + +static ValueLatticeElement getValueFromICmpCondition(Value *Val, ICmpInst *ICI, +                                                     bool isTrueDest) { +  Value *LHS = ICI->getOperand(0); +  Value *RHS = ICI->getOperand(1); +  CmpInst::Predicate Predicate = ICI->getPredicate(); + +  if (isa<Constant>(RHS)) { +    if (ICI->isEquality() && LHS == Val) { +      // We know that V has the RHS constant if this is a true SETEQ or +      // false SETNE. +      if (isTrueDest == (Predicate == ICmpInst::ICMP_EQ)) +        return ValueLatticeElement::get(cast<Constant>(RHS)); +      else +        return ValueLatticeElement::getNot(cast<Constant>(RHS)); +    } +  } + +  if (!Val->getType()->isIntegerTy()) +    return ValueLatticeElement::getOverdefined(); + +  // Use ConstantRange::makeAllowedICmpRegion in order to determine the possible +  // range of Val guaranteed by the condition. Recognize comparisons in the from +  // of: +  //  icmp <pred> Val, ... +  //  icmp <pred> (add Val, Offset), ... +  // The latter is the range checking idiom that InstCombine produces. Subtract +  // the offset from the allowed range for RHS in this case. 
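A standalone sketch of that computation, reusing the ConstantRange API from above; the width, predicate and constants are made up for illustration:

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include "llvm/IR/InstrTypes.h"
  #include <cassert>
  using namespace llvm;

  void rangeCheckIdiomSketch() {
    // On the true edge of:  %c = icmp ult i32 (add i32 %x, 5), 20
    // the add is known to be in [0, 20), so %x is in [0, 20) - 5 = [-5, 15).
    ConstantRange AddRange = ConstantRange::makeAllowedICmpRegion(
        CmpInst::ICMP_ULT, ConstantRange(APInt(32, 20)));
    ConstantRange XRange = AddRange.subtract(APInt(32, 5));
    assert(XRange.contains(APInt(32, 14)));
    assert(!XRange.contains(APInt(32, 15)));
    assert(XRange.contains(APInt(32, -5, /*isSigned=*/true))); // 0xFFFFFFFB
  }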
+ +  // Val or (add Val, Offset) can be on either hand of the comparison +  if (LHS != Val && !match(LHS, m_Add(m_Specific(Val), m_ConstantInt()))) { +    std::swap(LHS, RHS); +    Predicate = CmpInst::getSwappedPredicate(Predicate); +  } + +  ConstantInt *Offset = nullptr; +  if (LHS != Val) +    match(LHS, m_Add(m_Specific(Val), m_ConstantInt(Offset))); + +  if (LHS == Val || Offset) { +    // Calculate the range of values that are allowed by the comparison +    ConstantRange RHSRange(RHS->getType()->getIntegerBitWidth(), +                           /*isFullSet=*/true); +    if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) +      RHSRange = ConstantRange(CI->getValue()); +    else if (Instruction *I = dyn_cast<Instruction>(RHS)) +      if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) +        RHSRange = getConstantRangeFromMetadata(*Ranges); + +    // If we're interested in the false dest, invert the condition +    CmpInst::Predicate Pred = +            isTrueDest ? Predicate : CmpInst::getInversePredicate(Predicate); +    ConstantRange TrueValues = +            ConstantRange::makeAllowedICmpRegion(Pred, RHSRange); + +    if (Offset) // Apply the offset from above. +      TrueValues = TrueValues.subtract(Offset->getValue()); + +    return ValueLatticeElement::getRange(std::move(TrueValues)); +  } + +  return ValueLatticeElement::getOverdefined(); +} + +static ValueLatticeElement +getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest, +                      DenseMap<Value*, ValueLatticeElement> &Visited); + +static ValueLatticeElement +getValueFromConditionImpl(Value *Val, Value *Cond, bool isTrueDest, +                          DenseMap<Value*, ValueLatticeElement> &Visited) { +  if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cond)) +    return getValueFromICmpCondition(Val, ICI, isTrueDest); + +  // Handle conditions in the form of (cond1 && cond2), we know that on the +  // true dest path both of the conditions hold. Similarly for conditions of +  // the form (cond1 || cond2), we know that on the false dest path neither +  // condition holds. 
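For intuition, intersecting the two operands' regions is exactly the right combination on the true edge of an 'and'; a small sketch with made-up bounds (the lattice-level intersect() in the code below does the same job):

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include "llvm/IR/InstrTypes.h"
  #include <cassert>
  using namespace llvm;

  void andOfConditionsSketch() {
    // %c = and i1 (icmp ugt i32 %x, 10), (icmp ult i32 %x, 100)
    // On the true edge both comparisons hold, so %x lies in the intersection.
    ConstantRange Gt10 = ConstantRange::makeAllowedICmpRegion(
        CmpInst::ICMP_UGT, ConstantRange(APInt(32, 10)));
    ConstantRange Lt100 = ConstantRange::makeAllowedICmpRegion(
        CmpInst::ICMP_ULT, ConstantRange(APInt(32, 100)));
    ConstantRange X = Gt10.intersectWith(Lt100); // [11, 100)
    assert(X.contains(APInt(32, 11)) && !X.contains(APInt(32, 100)));
    // On the false edge of an 'or', the same intersection is applied to the
    // inverse predicates.
  }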
+  BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond); +  if (!BO || (isTrueDest && BO->getOpcode() != BinaryOperator::And) || +             (!isTrueDest && BO->getOpcode() != BinaryOperator::Or)) +    return ValueLatticeElement::getOverdefined(); + +  // Prevent infinite recursion if Cond references itself as in this example: +  //  Cond: "%tmp4 = and i1 %tmp4, undef" +  //    BL: "%tmp4 = and i1 %tmp4, undef" +  //    BR: "i1 undef" +  Value *BL = BO->getOperand(0); +  Value *BR = BO->getOperand(1); +  if (BL == Cond || BR == Cond) +    return ValueLatticeElement::getOverdefined(); + +  return intersect(getValueFromCondition(Val, BL, isTrueDest, Visited), +                   getValueFromCondition(Val, BR, isTrueDest, Visited)); +} + +static ValueLatticeElement +getValueFromCondition(Value *Val, Value *Cond, bool isTrueDest, +                      DenseMap<Value*, ValueLatticeElement> &Visited) { +  auto I = Visited.find(Cond); +  if (I != Visited.end()) +    return I->second; + +  auto Result = getValueFromConditionImpl(Val, Cond, isTrueDest, Visited); +  Visited[Cond] = Result; +  return Result; +} + +ValueLatticeElement getValueFromCondition(Value *Val, Value *Cond, +                                          bool isTrueDest) { +  assert(Cond && "precondition"); +  DenseMap<Value*, ValueLatticeElement> Visited; +  return getValueFromCondition(Val, Cond, isTrueDest, Visited); +} + +// Return true if Usr has Op as an operand, otherwise false. +static bool usesOperand(User *Usr, Value *Op) { +  return find(Usr->operands(), Op) != Usr->op_end(); +} + +// Return true if the instruction type of Val is supported by +// constantFoldUser(). Currently CastInst and BinaryOperator only.  Call this +// before calling constantFoldUser() to find out if it's even worth attempting +// to call it. +static bool isOperationFoldable(User *Usr) { +  return isa<CastInst>(Usr) || isa<BinaryOperator>(Usr); +} + +// Check if Usr can be simplified to an integer constant when the value of one +// of its operands Op is an integer constant OpConstVal. If so, return it as an +// lattice value range with a single element or otherwise return an overdefined +// lattice value. +static ValueLatticeElement constantFoldUser(User *Usr, Value *Op, +                                            const APInt &OpConstVal, +                                            const DataLayout &DL) { +  assert(isOperationFoldable(Usr) && "Precondition"); +  Constant* OpConst = Constant::getIntegerValue(Op->getType(), OpConstVal); +  // Check if Usr can be simplified to a constant. +  if (auto *CI = dyn_cast<CastInst>(Usr)) { +    assert(CI->getOperand(0) == Op && "Operand 0 isn't Op"); +    if (auto *C = dyn_cast_or_null<ConstantInt>( +            SimplifyCastInst(CI->getOpcode(), OpConst, +                             CI->getDestTy(), DL))) { +      return ValueLatticeElement::getRange(ConstantRange(C->getValue())); +    } +  } else if (auto *BO = dyn_cast<BinaryOperator>(Usr)) { +    bool Op0Match = BO->getOperand(0) == Op; +    bool Op1Match = BO->getOperand(1) == Op; +    assert((Op0Match || Op1Match) && +           "Operand 0 nor Operand 1 isn't a match"); +    Value *LHS = Op0Match ? OpConst : BO->getOperand(0); +    Value *RHS = Op1Match ? 
OpConst : BO->getOperand(1); +    if (auto *C = dyn_cast_or_null<ConstantInt>( +            SimplifyBinOp(BO->getOpcode(), LHS, RHS, DL))) { +      return ValueLatticeElement::getRange(ConstantRange(C->getValue())); +    } +  } +  return ValueLatticeElement::getOverdefined(); +} + +/// Compute the value of Val on the edge BBFrom -> BBTo. Returns false if +/// Val is not constrained on the edge.  Result is unspecified if return value +/// is false. +static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, +                              BasicBlock *BBTo, ValueLatticeElement &Result) { +  // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we +  // know that v != 0. +  if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) { +    // If this is a conditional branch and only one successor goes to BBTo, then +    // we may be able to infer something from the condition. +    if (BI->isConditional() && +        BI->getSuccessor(0) != BI->getSuccessor(1)) { +      bool isTrueDest = BI->getSuccessor(0) == BBTo; +      assert(BI->getSuccessor(!isTrueDest) == BBTo && +             "BBTo isn't a successor of BBFrom"); +      Value *Condition = BI->getCondition(); + +      // If V is the condition of the branch itself, then we know exactly what +      // it is. +      if (Condition == Val) { +        Result = ValueLatticeElement::get(ConstantInt::get( +                              Type::getInt1Ty(Val->getContext()), isTrueDest)); +        return true; +      } + +      // If the condition of the branch is an equality comparison, we may be +      // able to infer the value. +      Result = getValueFromCondition(Val, Condition, isTrueDest); +      if (!Result.isOverdefined()) +        return true; + +      if (User *Usr = dyn_cast<User>(Val)) { +        assert(Result.isOverdefined() && "Result isn't overdefined"); +        // Check with isOperationFoldable() first to avoid linearly iterating +        // over the operands unnecessarily which can be expensive for +        // instructions with many operands. +        if (isa<IntegerType>(Usr->getType()) && isOperationFoldable(Usr)) { +          const DataLayout &DL = BBTo->getModule()->getDataLayout(); +          if (usesOperand(Usr, Condition)) { +            // If Val has Condition as an operand and Val can be folded into a +            // constant with either Condition == true or Condition == false, +            // propagate the constant. +            // eg. +            //   ; %Val is true on the edge to %then. +            //   %Val = and i1 %Condition, true. +            //   br %Condition, label %then, label %else +            APInt ConditionVal(1, isTrueDest ? 1 : 0); +            Result = constantFoldUser(Usr, Condition, ConditionVal, DL); +          } else { +            // If one of Val's operand has an inferred value, we may be able to +            // infer the value of Val. +            // eg. +            //    ; %Val is 94 on the edge to %then. 
+            //    %Val = add i8 %Op, 1 +            //    %Condition = icmp eq i8 %Op, 93 +            //    br i1 %Condition, label %then, label %else +            for (unsigned i = 0; i < Usr->getNumOperands(); ++i) { +              Value *Op = Usr->getOperand(i); +              ValueLatticeElement OpLatticeVal = +                  getValueFromCondition(Op, Condition, isTrueDest); +              if (Optional<APInt> OpConst = OpLatticeVal.asConstantInteger()) { +                Result = constantFoldUser(Usr, Op, OpConst.getValue(), DL); +                break; +              } +            } +          } +        } +      } +      if (!Result.isOverdefined()) +        return true; +    } +  } + +  // If the edge was formed by a switch on the value, then we may know exactly +  // what it is. +  if (SwitchInst *SI = dyn_cast<SwitchInst>(BBFrom->getTerminator())) { +    Value *Condition = SI->getCondition(); +    if (!isa<IntegerType>(Val->getType())) +      return false; +    bool ValUsesConditionAndMayBeFoldable = false; +    if (Condition != Val) { +      // Check if Val has Condition as an operand. +      if (User *Usr = dyn_cast<User>(Val)) +        ValUsesConditionAndMayBeFoldable = isOperationFoldable(Usr) && +            usesOperand(Usr, Condition); +      if (!ValUsesConditionAndMayBeFoldable) +        return false; +    } +    assert((Condition == Val || ValUsesConditionAndMayBeFoldable) && +           "Condition != Val nor Val doesn't use Condition"); + +    bool DefaultCase = SI->getDefaultDest() == BBTo; +    unsigned BitWidth = Val->getType()->getIntegerBitWidth(); +    ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/); + +    for (auto Case : SI->cases()) { +      APInt CaseValue = Case.getCaseValue()->getValue(); +      ConstantRange EdgeVal(CaseValue); +      if (ValUsesConditionAndMayBeFoldable) { +        User *Usr = cast<User>(Val); +        const DataLayout &DL = BBTo->getModule()->getDataLayout(); +        ValueLatticeElement EdgeLatticeVal = +            constantFoldUser(Usr, Condition, CaseValue, DL); +        if (EdgeLatticeVal.isOverdefined()) +          return false; +        EdgeVal = EdgeLatticeVal.getConstantRange(); +      } +      if (DefaultCase) { +        // It is possible that the default destination is the destination of +        // some cases. We cannot perform difference for those cases. +        // We know Condition != CaseValue in BBTo.  In some cases we can use +        // this to infer Val == f(Condition) is != f(CaseValue).  For now, we +        // only do this when f is identity (i.e. Val == Condition), but we +        // should be able to do this for any injective f. +        if (Case.getCaseSuccessor() != BBTo && Condition == Val) +          EdgesVals = EdgesVals.difference(EdgeVal); +      } else if (Case.getCaseSuccessor() == BBTo) +        EdgesVals = EdgesVals.unionWith(EdgeVal); +    } +    Result = ValueLatticeElement::getRange(std::move(EdgesVals)); +    return true; +  } +  return false; +} + +/// Compute the value of Val on the edge BBFrom -> BBTo or the value at +/// the basic block if the edge does not constrain Val. +bool LazyValueInfoImpl::getEdgeValue(Value *Val, BasicBlock *BBFrom, +                                     BasicBlock *BBTo, +                                     ValueLatticeElement &Result, +                                     Instruction *CxtI) { +  // If already a constant, there is nothing to compute. 
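The switch handling above builds the edge range with ConstantRange::unionWith, and with ConstantRange::difference for the edge to the default destination. A standalone sketch with made-up case values:

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include <cassert>
  using namespace llvm;

  void switchEdgeRangeSketch() {
    // switch i8 %v: cases 1, 2, 3 -> %bb, default -> %other
    const uint64_t Cases[] = {1, 2, 3};

    // Edge to %bb: the union of the matching case values.
    ConstantRange ToBB(/*BitWidth=*/8, /*isFullSet=*/false); // start empty
    for (uint64_t C : Cases)
      ToBB = ToBB.unionWith(ConstantRange(APInt(8, C)));
    assert(ToBB.contains(APInt(8, 2)) && !ToBB.contains(APInt(8, 4)));

    // Edge to %other: start from the full set and remove the values that are
    // known to branch elsewhere.
    ConstantRange ToDefault(/*BitWidth=*/8, /*isFullSet=*/true);
    for (uint64_t C : Cases)
      ToDefault = ToDefault.difference(ConstantRange(APInt(8, C)));
    assert(!ToDefault.contains(APInt(8, 2)) && ToDefault.contains(APInt(8, 4)));
  }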
+  if (Constant *VC = dyn_cast<Constant>(Val)) { +    Result = ValueLatticeElement::get(VC); +    return true; +  } + +  ValueLatticeElement LocalResult; +  if (!getEdgeValueLocal(Val, BBFrom, BBTo, LocalResult)) +    // If we couldn't constrain the value on the edge, LocalResult doesn't +    // provide any information. +    LocalResult = ValueLatticeElement::getOverdefined(); + +  if (hasSingleValue(LocalResult)) { +    // Can't get any more precise here +    Result = LocalResult; +    return true; +  } + +  if (!hasBlockValue(Val, BBFrom)) { +    if (pushBlockValue(std::make_pair(BBFrom, Val))) +      return false; +    // No new information. +    Result = LocalResult; +    return true; +  } + +  // Try to intersect ranges of the BB and the constraint on the edge. +  ValueLatticeElement InBlock = getBlockValue(Val, BBFrom); +  intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, +                                                BBFrom->getTerminator()); +  // We can use the context instruction (generically the ultimate instruction +  // the calling pass is trying to simplify) here, even though the result of +  // this function is generally cached when called from the solve* functions +  // (and that cached result might be used with queries using a different +  // context instruction), because when this function is called from the solve* +  // functions, the context instruction is not provided. When called from +  // LazyValueInfoImpl::getValueOnEdge, the context instruction is provided, +  // but then the result is not cached. +  intersectAssumeOrGuardBlockValueConstantRange(Val, InBlock, CxtI); + +  Result = intersect(LocalResult, InBlock); +  return true; +} + +ValueLatticeElement LazyValueInfoImpl::getValueInBlock(Value *V, BasicBlock *BB, +                                                       Instruction *CxtI) { +  LLVM_DEBUG(dbgs() << "LVI Getting block end value " << *V << " at '" +                    << BB->getName() << "'\n"); + +  assert(BlockValueStack.empty() && BlockValueSet.empty()); +  if (!hasBlockValue(V, BB)) { +    pushBlockValue(std::make_pair(BB, V)); +    solve(); +  } +  ValueLatticeElement Result = getBlockValue(V, BB); +  intersectAssumeOrGuardBlockValueConstantRange(V, Result, CxtI); + +  LLVM_DEBUG(dbgs() << "  Result = " << Result << "\n"); +  return Result; +} + +ValueLatticeElement LazyValueInfoImpl::getValueAt(Value *V, Instruction *CxtI) { +  LLVM_DEBUG(dbgs() << "LVI Getting value " << *V << " at '" << CxtI->getName() +                    << "'\n"); + +  if (auto *C = dyn_cast<Constant>(V)) +    return ValueLatticeElement::get(C); + +  ValueLatticeElement Result = ValueLatticeElement::getOverdefined(); +  if (auto *I = dyn_cast<Instruction>(V)) +    Result = getFromRangeMetadata(I); +  intersectAssumeOrGuardBlockValueConstantRange(V, Result, CxtI); + +  LLVM_DEBUG(dbgs() << "  Result = " << Result << "\n"); +  return Result; +} + +ValueLatticeElement LazyValueInfoImpl:: +getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB, +               Instruction *CxtI) { +  LLVM_DEBUG(dbgs() << "LVI Getting edge value " << *V << " from '" +                    << FromBB->getName() << "' to '" << ToBB->getName() +                    << "'\n"); + +  ValueLatticeElement Result; +  if (!getEdgeValue(V, FromBB, ToBB, Result, CxtI)) { +    solve(); +    bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result, CxtI); +    (void)WasFastQuery; +    assert(WasFastQuery && "More work to do after problem solved?"); +  } + +  LLVM_DEBUG(dbgs() << "  Result = " << 
Result << "\n"); +  return Result; +} + +void LazyValueInfoImpl::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, +                                   BasicBlock *NewSucc) { +  TheCache.threadEdgeImpl(OldSucc, NewSucc); +} + +//===----------------------------------------------------------------------===// +//                            LazyValueInfo Impl +//===----------------------------------------------------------------------===// + +/// This lazily constructs the LazyValueInfoImpl. +static LazyValueInfoImpl &getImpl(void *&PImpl, AssumptionCache *AC, +                                  const DataLayout *DL, +                                  DominatorTree *DT = nullptr) { +  if (!PImpl) { +    assert(DL && "getCache() called with a null DataLayout"); +    PImpl = new LazyValueInfoImpl(AC, *DL, DT); +  } +  return *static_cast<LazyValueInfoImpl*>(PImpl); +} + +bool LazyValueInfoWrapperPass::runOnFunction(Function &F) { +  Info.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +  const DataLayout &DL = F.getParent()->getDataLayout(); + +  DominatorTreeWrapperPass *DTWP = +      getAnalysisIfAvailable<DominatorTreeWrapperPass>(); +  Info.DT = DTWP ? &DTWP->getDomTree() : nullptr; +  Info.TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + +  if (Info.PImpl) +    getImpl(Info.PImpl, Info.AC, &DL, Info.DT).clear(); + +  // Fully lazy. +  return false; +} + +void LazyValueInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<AssumptionCacheTracker>(); +  AU.addRequired<TargetLibraryInfoWrapperPass>(); +} + +LazyValueInfo &LazyValueInfoWrapperPass::getLVI() { return Info; } + +LazyValueInfo::~LazyValueInfo() { releaseMemory(); } + +void LazyValueInfo::releaseMemory() { +  // If the cache was allocated, free it. +  if (PImpl) { +    delete &getImpl(PImpl, AC, nullptr); +    PImpl = nullptr; +  } +} + +bool LazyValueInfo::invalidate(Function &F, const PreservedAnalyses &PA, +                               FunctionAnalysisManager::Invalidator &Inv) { +  // We need to invalidate if we have either failed to preserve this analyses +  // result directly or if any of its dependencies have been invalidated. +  auto PAC = PA.getChecker<LazyValueAnalysis>(); +  if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || +      (DT && Inv.invalidate<DominatorTreeAnalysis>(F, PA))) +    return true; + +  return false; +} + +void LazyValueInfoWrapperPass::releaseMemory() { Info.releaseMemory(); } + +LazyValueInfo LazyValueAnalysis::run(Function &F, +                                     FunctionAnalysisManager &FAM) { +  auto &AC = FAM.getResult<AssumptionAnalysis>(F); +  auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); +  auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); + +  return LazyValueInfo(&AC, &F.getParent()->getDataLayout(), &TLI, DT); +} + +/// Returns true if we can statically tell that this value will never be a +/// "useful" constant.  In practice, this means we've got something like an +/// alloca or a malloc call for which a comparison against a constant can +/// only be guarding dead code.  Note that we are potentially giving up some +/// precision in dead code (a constant result) in favour of avoiding a +/// expensive search for a easily answered common query. +static bool isKnownNonConstant(Value *V) { +  V = V->stripPointerCasts(); +  // The return val of alloc cannot be a Constant. 
+  if (isa<AllocaInst>(V)) +    return true; +  return false; +} + +Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB, +                                     Instruction *CxtI) { +  // Bail out early if V is known not to be a Constant. +  if (isKnownNonConstant(V)) +    return nullptr; + +  const DataLayout &DL = BB->getModule()->getDataLayout(); +  ValueLatticeElement Result = +      getImpl(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI); + +  if (Result.isConstant()) +    return Result.getConstant(); +  if (Result.isConstantRange()) { +    const ConstantRange &CR = Result.getConstantRange(); +    if (const APInt *SingleVal = CR.getSingleElement()) +      return ConstantInt::get(V->getContext(), *SingleVal); +  } +  return nullptr; +} + +ConstantRange LazyValueInfo::getConstantRange(Value *V, BasicBlock *BB, +                                              Instruction *CxtI) { +  assert(V->getType()->isIntegerTy()); +  unsigned Width = V->getType()->getIntegerBitWidth(); +  const DataLayout &DL = BB->getModule()->getDataLayout(); +  ValueLatticeElement Result = +      getImpl(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI); +  if (Result.isUndefined()) +    return ConstantRange(Width, /*isFullSet=*/false); +  if (Result.isConstantRange()) +    return Result.getConstantRange(); +  // We represent ConstantInt constants as constant ranges but other kinds +  // of integer constants, i.e. ConstantExpr will be tagged as constants +  assert(!(Result.isConstant() && isa<ConstantInt>(Result.getConstant())) && +         "ConstantInt value must be represented as constantrange"); +  return ConstantRange(Width, /*isFullSet=*/true); +} + +/// Determine whether the specified value is known to be a +/// constant on the specified edge. Return null if not. +Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, +                                           BasicBlock *ToBB, +                                           Instruction *CxtI) { +  const DataLayout &DL = FromBB->getModule()->getDataLayout(); +  ValueLatticeElement Result = +      getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + +  if (Result.isConstant()) +    return Result.getConstant(); +  if (Result.isConstantRange()) { +    const ConstantRange &CR = Result.getConstantRange(); +    if (const APInt *SingleVal = CR.getSingleElement()) +      return ConstantInt::get(V->getContext(), *SingleVal); +  } +  return nullptr; +} + +ConstantRange LazyValueInfo::getConstantRangeOnEdge(Value *V, +                                                    BasicBlock *FromBB, +                                                    BasicBlock *ToBB, +                                                    Instruction *CxtI) { +  unsigned Width = V->getType()->getIntegerBitWidth(); +  const DataLayout &DL = FromBB->getModule()->getDataLayout(); +  ValueLatticeElement Result = +      getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + +  if (Result.isUndefined()) +    return ConstantRange(Width, /*isFullSet=*/false); +  if (Result.isConstantRange()) +    return Result.getConstantRange(); +  // We represent ConstantInt constants as constant ranges but other kinds +  // of integer constants, i.e. 
ConstantExpr will be tagged as constants +  assert(!(Result.isConstant() && isa<ConstantInt>(Result.getConstant())) && +         "ConstantInt value must be represented as constantrange"); +  return ConstantRange(Width, /*isFullSet=*/true); +} + +static LazyValueInfo::Tristate +getPredicateResult(unsigned Pred, Constant *C, const ValueLatticeElement &Val, +                   const DataLayout &DL, TargetLibraryInfo *TLI) { +  // If we know the value is a constant, evaluate the conditional. +  Constant *Res = nullptr; +  if (Val.isConstant()) { +    Res = ConstantFoldCompareInstOperands(Pred, Val.getConstant(), C, DL, TLI); +    if (ConstantInt *ResCI = dyn_cast<ConstantInt>(Res)) +      return ResCI->isZero() ? LazyValueInfo::False : LazyValueInfo::True; +    return LazyValueInfo::Unknown; +  } + +  if (Val.isConstantRange()) { +    ConstantInt *CI = dyn_cast<ConstantInt>(C); +    if (!CI) return LazyValueInfo::Unknown; + +    const ConstantRange &CR = Val.getConstantRange(); +    if (Pred == ICmpInst::ICMP_EQ) { +      if (!CR.contains(CI->getValue())) +        return LazyValueInfo::False; + +      if (CR.isSingleElement()) +        return LazyValueInfo::True; +    } else if (Pred == ICmpInst::ICMP_NE) { +      if (!CR.contains(CI->getValue())) +        return LazyValueInfo::True; + +      if (CR.isSingleElement()) +        return LazyValueInfo::False; +    } else { +      // Handle more complex predicates. +      ConstantRange TrueValues = ConstantRange::makeExactICmpRegion( +          (ICmpInst::Predicate)Pred, CI->getValue()); +      if (TrueValues.contains(CR)) +        return LazyValueInfo::True; +      if (TrueValues.inverse().contains(CR)) +        return LazyValueInfo::False; +    } +    return LazyValueInfo::Unknown; +  } + +  if (Val.isNotConstant()) { +    // If this is an equality comparison, we can try to fold it knowing that +    // "V != C1". +    if (Pred == ICmpInst::ICMP_EQ) { +      // !C1 == C -> false iff C1 == C. +      Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, +                                            Val.getNotConstant(), C, DL, +                                            TLI); +      if (Res->isNullValue()) +        return LazyValueInfo::False; +    } else if (Pred == ICmpInst::ICMP_NE) { +      // !C1 != C -> true iff C1 == C. +      Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, +                                            Val.getNotConstant(), C, DL, +                                            TLI); +      if (Res->isNullValue()) +        return LazyValueInfo::True; +    } +    return LazyValueInfo::Unknown; +  } + +  return LazyValueInfo::Unknown; +} + +/// Determine whether the specified value comparison with a constant is known to +/// be true or false on the specified CFG edge. Pred is a CmpInst predicate. +LazyValueInfo::Tristate +LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, +                                  BasicBlock *FromBB, BasicBlock *ToBB, +                                  Instruction *CxtI) { +  const DataLayout &DL = FromBB->getModule()->getDataLayout(); +  ValueLatticeElement Result = +      getImpl(PImpl, AC, &DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + +  return getPredicateResult(Pred, C, Result, DL, TLI); +} + +LazyValueInfo::Tristate +LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C, +                              Instruction *CxtI) { +  // Is or is not NonNull are common predicates being queried. 
If +  // isKnownNonZero can tell us the result of the predicate, we can +  // return it quickly. But this is only a fastpath, and falling +  // through would still be correct. +  const DataLayout &DL = CxtI->getModule()->getDataLayout(); +  if (V->getType()->isPointerTy() && C->isNullValue() && +      isKnownNonZero(V->stripPointerCasts(), DL)) { +    if (Pred == ICmpInst::ICMP_EQ) +      return LazyValueInfo::False; +    else if (Pred == ICmpInst::ICMP_NE) +      return LazyValueInfo::True; +  } +  ValueLatticeElement Result = getImpl(PImpl, AC, &DL, DT).getValueAt(V, CxtI); +  Tristate Ret = getPredicateResult(Pred, C, Result, DL, TLI); +  if (Ret != Unknown) +    return Ret; + +  // Note: The following bit of code is somewhat distinct from the rest of LVI; +  // LVI as a whole tries to compute a lattice value which is conservatively +  // correct at a given location.  In this case, we have a predicate which we +  // weren't able to prove about the merged result, and we're pushing that +  // predicate back along each incoming edge to see if we can prove it +  // separately for each input.  As a motivating example, consider: +  // bb1: +  //   %v1 = ... ; constantrange<1, 5> +  //   br label %merge +  // bb2: +  //   %v2 = ... ; constantrange<10, 20> +  //   br label %merge +  // merge: +  //   %phi = phi [%v1, %v2] ; constantrange<1,20> +  //   %pred = icmp eq i32 %phi, 8 +  // We can't tell from the lattice value for '%phi' that '%pred' is false +  // along each path, but by checking the predicate over each input separately, +  // we can. +  // We limit the search to one step backwards from the current BB and value. +  // We could consider extending this to search further backwards through the +  // CFG and/or value graph, but there are non-obvious compile time vs quality +  // tradeoffs. +  if (CxtI) { +    BasicBlock *BB = CxtI->getParent(); + +    // Function entry or an unreachable block.  Bail to avoid confusing +    // analysis below. +    pred_iterator PI = pred_begin(BB), PE = pred_end(BB); +    if (PI == PE) +      return Unknown; + +    // If V is a PHI node in the same block as the context, we need to ask +    // questions about the predicate as applied to the incoming value along +    // each edge. This is useful for eliminating cases where the predicate is +    // known along all incoming edges. +    if (auto *PHI = dyn_cast<PHINode>(V)) +      if (PHI->getParent() == BB) { +        Tristate Baseline = Unknown; +        for (unsigned i = 0, e = PHI->getNumIncomingValues(); i < e; i++) { +          Value *Incoming = PHI->getIncomingValue(i); +          BasicBlock *PredBB = PHI->getIncomingBlock(i); +          // Note that PredBB may be BB itself. +          Tristate Result = getPredicateOnEdge(Pred, Incoming, C, PredBB, BB, +                                               CxtI); + +          // Keep going as long as we've seen a consistent known result for +          // all inputs. +          Baseline = (i == 0) ? Result /* First iteration */ +            : (Baseline == Result ? Baseline : Unknown); /* All others */ +          if (Baseline == Unknown) +            break; +        } +        if (Baseline != Unknown) +          return Baseline; +      } + +    // For a comparison where the V is outside this block, it's possible +    // that we've branched on it before. Look to see if the value is known +    // on all incoming edges. 
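The precision loss that motivates this per-edge re-query is easy to see at the range level; a standalone check using the example ranges from the comment above:

  #include "llvm/ADT/APInt.h"
  #include "llvm/IR/ConstantRange.h"
  #include <cassert>
  using namespace llvm;

  void mergedRangeLosesPrecision() {
    ConstantRange V1(APInt(32, 1), APInt(32, 5));   // %v1: constantrange<1, 5>
    ConstantRange V2(APInt(32, 10), APInt(32, 20)); // %v2: constantrange<10, 20>
    APInt Eight(32, 8);

    // On each incoming edge "== 8" is provably false...
    assert(!V1.contains(Eight) && !V2.contains(Eight));

    // ...but the phi's merged range is the smallest single range covering
    // both inputs, [1, 20), which does contain 8, so the merged lattice value
    // alone cannot settle the predicate.
    assert(V1.unionWith(V2).contains(Eight));
  }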
+    if (!isa<Instruction>(V) || +        cast<Instruction>(V)->getParent() != BB) { +      // For predecessor edge, determine if the comparison is true or false +      // on that edge. If they're all true or all false, we can conclude +      // the value of the comparison in this block. +      Tristate Baseline = getPredicateOnEdge(Pred, V, C, *PI, BB, CxtI); +      if (Baseline != Unknown) { +        // Check that all remaining incoming values match the first one. +        while (++PI != PE) { +          Tristate Ret = getPredicateOnEdge(Pred, V, C, *PI, BB, CxtI); +          if (Ret != Baseline) break; +        } +        // If we terminated early, then one of the values didn't match. +        if (PI == PE) { +          return Baseline; +        } +      } +    } +  } +  return Unknown; +} + +void LazyValueInfo::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, +                               BasicBlock *NewSucc) { +  if (PImpl) { +    const DataLayout &DL = PredBB->getModule()->getDataLayout(); +    getImpl(PImpl, AC, &DL, DT).threadEdge(PredBB, OldSucc, NewSucc); +  } +} + +void LazyValueInfo::eraseBlock(BasicBlock *BB) { +  if (PImpl) { +    const DataLayout &DL = BB->getModule()->getDataLayout(); +    getImpl(PImpl, AC, &DL, DT).eraseBlock(BB); +  } +} + + +void LazyValueInfo::printLVI(Function &F, DominatorTree &DTree, raw_ostream &OS) { +  if (PImpl) { +    getImpl(PImpl, AC, DL, DT).printLVI(F, DTree, OS); +  } +} + +void LazyValueInfo::disableDT() { +  if (PImpl) +    getImpl(PImpl, AC, DL, DT).disableDT(); +} + +void LazyValueInfo::enableDT() { +  if (PImpl) +    getImpl(PImpl, AC, DL, DT).enableDT(); +} + +// Print the LVI for the function arguments at the start of each basic block. +void LazyValueInfoAnnotatedWriter::emitBasicBlockStartAnnot( +    const BasicBlock *BB, formatted_raw_ostream &OS) { +  // Find if there are latticevalues defined for arguments of the function. +  auto *F = BB->getParent(); +  for (auto &Arg : F->args()) { +    ValueLatticeElement Result = LVIImpl->getValueInBlock( +        const_cast<Argument *>(&Arg), const_cast<BasicBlock *>(BB)); +    if (Result.isUndefined()) +      continue; +    OS << "; LatticeVal for: '" << Arg << "' is: " << Result << "\n"; +  } +} + +// This function prints the LVI analysis for the instruction I at the beginning +// of various basic blocks. It relies on calculated values that are stored in +// the LazyValueInfoCache, and in the absence of cached values, recalculate the +// LazyValueInfo for `I`, and print that info. +void LazyValueInfoAnnotatedWriter::emitInstructionAnnot( +    const Instruction *I, formatted_raw_ostream &OS) { + +  auto *ParentBB = I->getParent(); +  SmallPtrSet<const BasicBlock*, 16> BlocksContainingLVI; +  // We can generate (solve) LVI values only for blocks that are dominated by +  // the I's parent. However, to avoid generating LVI for all dominating blocks, +  // that contain redundant/uninteresting information, we print LVI for +  // blocks that may use this LVI information (such as immediate successor +  // blocks, and blocks that contain uses of `I`). 
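A hedged sketch of the same filtering from a client's point of view, using only the public interface; the helper and its arguments are hypothetical, and getConstantRange is only legal for integer-typed values:

  #include "llvm/Analysis/LazyValueInfo.h"
  #include "llvm/IR/CFG.h"
  #include "llvm/IR/Dominators.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/Support/raw_ostream.h"
  using namespace llvm;

  static void dumpRangesInDominatedSuccessors(LazyValueInfo &LVI,
                                              DominatorTree &DT,
                                              Instruction *I) {
    if (!I->getType()->isIntegerTy())
      return;
    BasicBlock *BB = I->getParent();
    // Mirror the printer: only successors dominated by I's block are queried.
    for (BasicBlock *Succ : successors(BB))
      if (DT.dominates(BB, Succ))
        errs() << Succ->getName() << ": "
               << LVI.getConstantRange(I, Succ) << "\n";
  }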
+  auto printResult = [&](const BasicBlock *BB) { +    if (!BlocksContainingLVI.insert(BB).second) +      return; +    ValueLatticeElement Result = LVIImpl->getValueInBlock( +        const_cast<Instruction *>(I), const_cast<BasicBlock *>(BB)); +      OS << "; LatticeVal for: '" << *I << "' in BB: '"; +      BB->printAsOperand(OS, false); +      OS << "' is: " << Result << "\n"; +  }; + +  printResult(ParentBB); +  // Print the LVI analysis results for the immediate successor blocks, that +  // are dominated by `ParentBB`. +  for (auto *BBSucc : successors(ParentBB)) +    if (DT.dominates(ParentBB, BBSucc)) +      printResult(BBSucc); + +  // Print LVI in blocks where `I` is used. +  for (auto *U : I->users()) +    if (auto *UseI = dyn_cast<Instruction>(U)) +      if (!isa<PHINode>(UseI) || DT.dominates(ParentBB, UseI->getParent())) +        printResult(UseI->getParent()); + +} + +namespace { +// Printer class for LazyValueInfo results. +class LazyValueInfoPrinter : public FunctionPass { +public: +  static char ID; // Pass identification, replacement for typeid +  LazyValueInfoPrinter() : FunctionPass(ID) { +    initializeLazyValueInfoPrinterPass(*PassRegistry::getPassRegistry()); +  } + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesAll(); +    AU.addRequired<LazyValueInfoWrapperPass>(); +    AU.addRequired<DominatorTreeWrapperPass>(); +  } + +  // Get the mandatory dominator tree analysis and pass this in to the +  // LVIPrinter. We cannot rely on the LVI's DT, since it's optional. +  bool runOnFunction(Function &F) override { +    dbgs() << "LVI for function '" << F.getName() << "':\n"; +    auto &LVI = getAnalysis<LazyValueInfoWrapperPass>().getLVI(); +    auto &DTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +    LVI.printLVI(F, DTree, dbgs()); +    return false; +  } +}; +} + +char LazyValueInfoPrinter::ID = 0; +INITIALIZE_PASS_BEGIN(LazyValueInfoPrinter, "print-lazy-value-info", +                "Lazy Value Info Printer Pass", false, false) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_END(LazyValueInfoPrinter, "print-lazy-value-info", +                "Lazy Value Info Printer Pass", false, false) diff --git a/contrib/llvm/lib/Analysis/Lint.cpp b/contrib/llvm/lib/Analysis/Lint.cpp new file mode 100644 index 000000000000..db919bd233bf --- /dev/null +++ b/contrib/llvm/lib/Analysis/Lint.cpp @@ -0,0 +1,753 @@ +//===-- Lint.cpp - Check for common errors in LLVM IR ---------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass statically checks for common and easily-identified constructs +// which produce undefined or likely unintended behavior in LLVM IR. +// +// It is not a guarantee of correctness, in two ways. First, it isn't +// comprehensive. There are checks which could be done statically which are +// not yet implemented. Some of these are indicated by TODO comments, but +// those aren't comprehensive either. Second, many conditions cannot be +// checked statically. This pass does no dynamic instrumentation, so it +// can't check for all possible problems. +// +// Another limitation is that it assumes all code will be executed. A store +// through a null pointer in a basic block which is never reached is harmless, +// but this pass will warn about it anyway. 
This is the main reason why most +// of these checks live here instead of in the Verifier pass. +// +// Optimization passes may make conditions that this pass checks for more or +// less obvious. If an optimization pass appears to be introducing a warning, +// it may be that the optimization pass is merely exposing an existing +// condition in the code. +// +// This code may be run before instcombine. In many cases, instcombine checks +// for the same kinds of things and turns instructions with undefined behavior +// into unreachable (or equivalent). Because of this, this pass makes some +// effort to look through bitcasts and so on. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Lint.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <string> + +using namespace llvm; + +namespace { +  namespace MemRef { +    static const unsigned Read     = 1; +    static const unsigned Write    = 2; +    static const unsigned Callee   = 4; +    static const unsigned Branchee = 8; +  } // end namespace MemRef + +  class Lint : public FunctionPass, public InstVisitor<Lint> { +    friend class InstVisitor<Lint>; + +    void visitFunction(Function &F); + +    void visitCallSite(CallSite CS); +    void visitMemoryReference(Instruction &I, Value *Ptr, +                              uint64_t Size, unsigned Align, +                              Type *Ty, unsigned Flags); +    void visitEHBeginCatch(IntrinsicInst *II); +    void visitEHEndCatch(IntrinsicInst *II); + +    void visitCallInst(CallInst &I); +    void visitInvokeInst(InvokeInst &I); +    void visitReturnInst(ReturnInst &I); +    void visitLoadInst(LoadInst &I); +    void visitStoreInst(StoreInst &I); +    void visitXor(BinaryOperator &I); +    void visitSub(BinaryOperator &I); +    void visitLShr(BinaryOperator &I); +    void visitAShr(BinaryOperator &I); +    void visitShl(BinaryOperator &I); +    void visitSDiv(BinaryOperator &I); +    void visitUDiv(BinaryOperator &I); +    void visitSRem(BinaryOperator &I); +    void visitURem(BinaryOperator &I); +    void visitAllocaInst(AllocaInst &I); +    void visitVAArgInst(VAArgInst &I); +    void 
visitIndirectBrInst(IndirectBrInst &I); +    void visitExtractElementInst(ExtractElementInst &I); +    void visitInsertElementInst(InsertElementInst &I); +    void visitUnreachableInst(UnreachableInst &I); + +    Value *findValue(Value *V, bool OffsetOk) const; +    Value *findValueImpl(Value *V, bool OffsetOk, +                         SmallPtrSetImpl<Value *> &Visited) const; + +  public: +    Module *Mod; +    const DataLayout *DL; +    AliasAnalysis *AA; +    AssumptionCache *AC; +    DominatorTree *DT; +    TargetLibraryInfo *TLI; + +    std::string Messages; +    raw_string_ostream MessagesStr; + +    static char ID; // Pass identification, replacement for typeid +    Lint() : FunctionPass(ID), MessagesStr(Messages) { +      initializeLintPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override; + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +      AU.addRequired<AAResultsWrapperPass>(); +      AU.addRequired<AssumptionCacheTracker>(); +      AU.addRequired<TargetLibraryInfoWrapperPass>(); +      AU.addRequired<DominatorTreeWrapperPass>(); +    } +    void print(raw_ostream &O, const Module *M) const override {} + +    void WriteValues(ArrayRef<const Value *> Vs) { +      for (const Value *V : Vs) { +        if (!V) +          continue; +        if (isa<Instruction>(V)) { +          MessagesStr << *V << '\n'; +        } else { +          V->printAsOperand(MessagesStr, true, Mod); +          MessagesStr << '\n'; +        } +      } +    } + +    /// A check failed, so printout out the condition and the message. +    /// +    /// This provides a nice place to put a breakpoint if you want to see why +    /// something is not correct. +    void CheckFailed(const Twine &Message) { MessagesStr << Message << '\n'; } + +    /// A check failed (with values to print). +    /// +    /// This calls the Message-only version so that the above is easier to set +    /// a breakpoint on. +    template <typename T1, typename... Ts> +    void CheckFailed(const Twine &Message, const T1 &V1, const Ts &...Vs) { +      CheckFailed(Message); +      WriteValues({V1, Vs...}); +    } +  }; +} // end anonymous namespace + +char Lint::ID = 0; +INITIALIZE_PASS_BEGIN(Lint, "lint", "Statically lint-checks LLVM IR", +                      false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(Lint, "lint", "Statically lint-checks LLVM IR", +                    false, true) + +// Assert - We know that cond should be true, if not print an error message. +#define Assert(C, ...) \ +    do { if (!(C)) { CheckFailed(__VA_ARGS__); return; } } while (false) + +// Lint::run - This is the main Analysis entry point for a +// function. +// +bool Lint::runOnFunction(Function &F) { +  Mod = F.getParent(); +  DL = &F.getParent()->getDataLayout(); +  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); +  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +  visit(F); +  dbgs() << MessagesStr.str(); +  Messages.clear(); +  return false; +} + +void Lint::visitFunction(Function &F) { +  // This isn't undefined behavior, it's just a little unusual, and it's a +  // fairly common mistake to neglect to name a function. 
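Because the findings are buffered per function and flushed to dbgs() at the end of runOnFunction, the checker is also convenient to drive programmatically. A hedged sketch, assuming the lintFunction/lintModule helpers declared in llvm/Analysis/Lint.h; the wrapper itself is hypothetical:

  #include "llvm/Analysis/Lint.h"
  #include "llvm/IR/Function.h"
  #include "llvm/IR/Module.h"
  using namespace llvm;

  // Run the lint checks over freshly generated IR, e.g. from a JIT frontend.
  static void lintGeneratedIR(Module &M) {
    for (Function &F : M)
      if (!F.isDeclaration())
        lintFunction(F); // per-function findings go to debug output

    // Or check the whole module at once.
    lintModule(M);
  }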
+  Assert(F.hasName() || F.hasLocalLinkage(), +         "Unusual: Unnamed function with non-local linkage", &F); + +  // TODO: Check for irreducible control flow. +} + +void Lint::visitCallSite(CallSite CS) { +  Instruction &I = *CS.getInstruction(); +  Value *Callee = CS.getCalledValue(); + +  visitMemoryReference(I, Callee, MemoryLocation::UnknownSize, 0, nullptr, +                       MemRef::Callee); + +  if (Function *F = dyn_cast<Function>(findValue(Callee, +                                                 /*OffsetOk=*/false))) { +    Assert(CS.getCallingConv() == F->getCallingConv(), +           "Undefined behavior: Caller and callee calling convention differ", +           &I); + +    FunctionType *FT = F->getFunctionType(); +    unsigned NumActualArgs = CS.arg_size(); + +    Assert(FT->isVarArg() ? FT->getNumParams() <= NumActualArgs +                          : FT->getNumParams() == NumActualArgs, +           "Undefined behavior: Call argument count mismatches callee " +           "argument count", +           &I); + +    Assert(FT->getReturnType() == I.getType(), +           "Undefined behavior: Call return type mismatches " +           "callee return type", +           &I); + +    // Check argument types (in case the callee was casted) and attributes. +    // TODO: Verify that caller and callee attributes are compatible. +    Function::arg_iterator PI = F->arg_begin(), PE = F->arg_end(); +    CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end(); +    for (; AI != AE; ++AI) { +      Value *Actual = *AI; +      if (PI != PE) { +        Argument *Formal = &*PI++; +        Assert(Formal->getType() == Actual->getType(), +               "Undefined behavior: Call argument type mismatches " +               "callee parameter type", +               &I); + +        // Check that noalias arguments don't alias other arguments. This is +        // not fully precise because we don't know the sizes of the dereferenced +        // memory regions. +        if (Formal->hasNoAliasAttr() && Actual->getType()->isPointerTy()) { +          AttributeList PAL = CS.getAttributes(); +          unsigned ArgNo = 0; +          for (CallSite::arg_iterator BI = CS.arg_begin(); BI != AE; ++BI) { +            // Skip ByVal arguments since they will be memcpy'd to the callee's +            // stack so we're not really passing the pointer anyway. +            if (PAL.hasParamAttribute(ArgNo++, Attribute::ByVal)) +              continue; +            if (AI != BI && (*BI)->getType()->isPointerTy()) { +              AliasResult Result = AA->alias(*AI, *BI); +              Assert(Result != MustAlias && Result != PartialAlias, +                     "Unusual: noalias argument aliases another argument", &I); +            } +          } +        } + +        // Check that an sret argument points to valid memory. 
+        if (Formal->hasStructRetAttr() && Actual->getType()->isPointerTy()) { +          Type *Ty = +            cast<PointerType>(Formal->getType())->getElementType(); +          visitMemoryReference(I, Actual, DL->getTypeStoreSize(Ty), +                               DL->getABITypeAlignment(Ty), Ty, +                               MemRef::Read | MemRef::Write); +        } +      } +    } +  } + +  if (CS.isCall()) { +    const CallInst *CI = cast<CallInst>(CS.getInstruction()); +    if (CI->isTailCall()) { +      const AttributeList &PAL = CI->getAttributes(); +      unsigned ArgNo = 0; +      for (Value *Arg : CS.args()) { +        // Skip ByVal arguments since they will be memcpy'd to the callee's +        // stack anyway. +        if (PAL.hasParamAttribute(ArgNo++, Attribute::ByVal)) +          continue; +        Value *Obj = findValue(Arg, /*OffsetOk=*/true); +        Assert(!isa<AllocaInst>(Obj), +               "Undefined behavior: Call with \"tail\" keyword references " +               "alloca", +               &I); +      } +    } +  } + + +  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I)) +    switch (II->getIntrinsicID()) { +    default: break; + +    // TODO: Check more intrinsics + +    case Intrinsic::memcpy: { +      MemCpyInst *MCI = cast<MemCpyInst>(&I); +      // TODO: If the size is known, use it. +      visitMemoryReference(I, MCI->getDest(), MemoryLocation::UnknownSize, +                           MCI->getDestAlignment(), nullptr, MemRef::Write); +      visitMemoryReference(I, MCI->getSource(), MemoryLocation::UnknownSize, +                           MCI->getSourceAlignment(), nullptr, MemRef::Read); + +      // Check that the memcpy arguments don't overlap. The AliasAnalysis API +      // isn't expressive enough for what we really want to do. Known partial +      // overlap is not distinguished from the case where nothing is known. +      uint64_t Size = 0; +      if (const ConstantInt *Len = +              dyn_cast<ConstantInt>(findValue(MCI->getLength(), +                                              /*OffsetOk=*/false))) +        if (Len->getValue().isIntN(32)) +          Size = Len->getValue().getZExtValue(); +      Assert(AA->alias(MCI->getSource(), Size, MCI->getDest(), Size) != +                 MustAlias, +             "Undefined behavior: memcpy source and destination overlap", &I); +      break; +    } +    case Intrinsic::memmove: { +      MemMoveInst *MMI = cast<MemMoveInst>(&I); +      // TODO: If the size is known, use it. +      visitMemoryReference(I, MMI->getDest(), MemoryLocation::UnknownSize, +                           MMI->getDestAlignment(), nullptr, MemRef::Write); +      visitMemoryReference(I, MMI->getSource(), MemoryLocation::UnknownSize, +                           MMI->getSourceAlignment(), nullptr, MemRef::Read); +      break; +    } +    case Intrinsic::memset: { +      MemSetInst *MSI = cast<MemSetInst>(&I); +      // TODO: If the size is known, use it. 
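// Editorial sketch (not part of this file): the memcpy handling above flags a
// call whose source and destination must-alias, since overlapping memcpy is
// undefined behavior; the comment there also notes that merely partial
// overlap is not distinguished from "nothing known". A source-level
// illustration of the flagged pattern and the defined alternative; the buffer
// and sizes are made up.
#include <cstring>

void redundant_copy(char *P, std::size_t N) {
  // std::memcpy(P, P, N);  // source and destination must-alias: the kind of
  //                        // overlap the check above reports as UB
  std::memmove(P, P, N);    // memmove permits overlap (here it is a no-op)
}

int main() {
  char Buf[8] = "abcdefg";
  redundant_copy(Buf, sizeof Buf);
}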
+      visitMemoryReference(I, MSI->getDest(), MemoryLocation::UnknownSize, +                           MSI->getDestAlignment(), nullptr, MemRef::Write); +      break; +    } + +    case Intrinsic::vastart: +      Assert(I.getParent()->getParent()->isVarArg(), +             "Undefined behavior: va_start called in a non-varargs function", +             &I); + +      visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, +                           nullptr, MemRef::Read | MemRef::Write); +      break; +    case Intrinsic::vacopy: +      visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, +                           nullptr, MemRef::Write); +      visitMemoryReference(I, CS.getArgument(1), MemoryLocation::UnknownSize, 0, +                           nullptr, MemRef::Read); +      break; +    case Intrinsic::vaend: +      visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, +                           nullptr, MemRef::Read | MemRef::Write); +      break; + +    case Intrinsic::stackrestore: +      // Stackrestore doesn't read or write memory, but it sets the +      // stack pointer, which the compiler may read from or write to +      // at any time, so check it for both readability and writeability. +      visitMemoryReference(I, CS.getArgument(0), MemoryLocation::UnknownSize, 0, +                           nullptr, MemRef::Read | MemRef::Write); +      break; +    } +} + +void Lint::visitCallInst(CallInst &I) { +  return visitCallSite(&I); +} + +void Lint::visitInvokeInst(InvokeInst &I) { +  return visitCallSite(&I); +} + +void Lint::visitReturnInst(ReturnInst &I) { +  Function *F = I.getParent()->getParent(); +  Assert(!F->doesNotReturn(), +         "Unusual: Return statement in function with noreturn attribute", &I); + +  if (Value *V = I.getReturnValue()) { +    Value *Obj = findValue(V, /*OffsetOk=*/true); +    Assert(!isa<AllocaInst>(Obj), "Unusual: Returning alloca value", &I); +  } +} + +// TODO: Check that the reference is in bounds. +// TODO: Check readnone/readonly function attributes. +void Lint::visitMemoryReference(Instruction &I, +                                Value *Ptr, uint64_t Size, unsigned Align, +                                Type *Ty, unsigned Flags) { +  // If no memory is being referenced, it doesn't matter if the pointer +  // is valid. 
+  if (Size == 0) +    return; + +  Value *UnderlyingObject = findValue(Ptr, /*OffsetOk=*/true); +  Assert(!isa<ConstantPointerNull>(UnderlyingObject), +         "Undefined behavior: Null pointer dereference", &I); +  Assert(!isa<UndefValue>(UnderlyingObject), +         "Undefined behavior: Undef pointer dereference", &I); +  Assert(!isa<ConstantInt>(UnderlyingObject) || +             !cast<ConstantInt>(UnderlyingObject)->isMinusOne(), +         "Unusual: All-ones pointer dereference", &I); +  Assert(!isa<ConstantInt>(UnderlyingObject) || +             !cast<ConstantInt>(UnderlyingObject)->isOne(), +         "Unusual: Address one pointer dereference", &I); + +  if (Flags & MemRef::Write) { +    if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(UnderlyingObject)) +      Assert(!GV->isConstant(), "Undefined behavior: Write to read-only memory", +             &I); +    Assert(!isa<Function>(UnderlyingObject) && +               !isa<BlockAddress>(UnderlyingObject), +           "Undefined behavior: Write to text section", &I); +  } +  if (Flags & MemRef::Read) { +    Assert(!isa<Function>(UnderlyingObject), "Unusual: Load from function body", +           &I); +    Assert(!isa<BlockAddress>(UnderlyingObject), +           "Undefined behavior: Load from block address", &I); +  } +  if (Flags & MemRef::Callee) { +    Assert(!isa<BlockAddress>(UnderlyingObject), +           "Undefined behavior: Call to block address", &I); +  } +  if (Flags & MemRef::Branchee) { +    Assert(!isa<Constant>(UnderlyingObject) || +               isa<BlockAddress>(UnderlyingObject), +           "Undefined behavior: Branch to non-blockaddress", &I); +  } + +  // Check for buffer overflows and misalignment. +  // Only handles memory references that read/write something simple like an +  // alloca instruction or a global variable. +  int64_t Offset = 0; +  if (Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, *DL)) { +    // OK, so the access is to a constant offset from Ptr.  Check that Ptr is +    // something we can handle and if so extract the size of this base object +    // along with its alignment. +    uint64_t BaseSize = MemoryLocation::UnknownSize; +    unsigned BaseAlign = 0; + +    if (AllocaInst *AI = dyn_cast<AllocaInst>(Base)) { +      Type *ATy = AI->getAllocatedType(); +      if (!AI->isArrayAllocation() && ATy->isSized()) +        BaseSize = DL->getTypeAllocSize(ATy); +      BaseAlign = AI->getAlignment(); +      if (BaseAlign == 0 && ATy->isSized()) +        BaseAlign = DL->getABITypeAlignment(ATy); +    } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) { +      // If the global may be defined differently in another compilation unit +      // then don't warn about funky memory accesses. +      if (GV->hasDefinitiveInitializer()) { +        Type *GTy = GV->getValueType(); +        if (GTy->isSized()) +          BaseSize = DL->getTypeAllocSize(GTy); +        BaseAlign = GV->getAlignment(); +        if (BaseAlign == 0 && GTy->isSized()) +          BaseAlign = DL->getABITypeAlignment(GTy); +      } +    } + +    // Accesses from before the start or after the end of the object are not +    // defined. +    Assert(Size == MemoryLocation::UnknownSize || +               BaseSize == MemoryLocation::UnknownSize || +               (Offset >= 0 && Offset + Size <= BaseSize), +           "Undefined behavior: Buffer overflow", &I); + +    // Accesses that say that the memory is more aligned than it is are not +    // defined. 
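// Editorial sketch (not part of this file): the buffer-overflow assertion
// just above only fires when both the access size and the base object size
// are known and the range [Offset, Offset + Size) does not fit inside the
// object. Standalone restatement with plain integers; accessInBounds is an
// invented name and ~0ULL stands in for MemoryLocation::UnknownSize.
#include <cstdint>
#include <iostream>

constexpr uint64_t UnknownSize = ~0ULL;

bool accessInBounds(int64_t Offset, uint64_t Size, uint64_t BaseSize) {
  if (Size == UnknownSize || BaseSize == UnknownSize)
    return true; // not enough information to complain
  return Offset >= 0 && static_cast<uint64_t>(Offset) + Size <= BaseSize;
}

int main() {
  // An 8-byte access at offset 12 of a 16-byte alloca overruns by 4 bytes.
  std::cout << accessInBounds(12, 8, 16) << '\n'; // 0 -> "Buffer overflow"
  std::cout << accessInBounds(8, 8, 16) << '\n';  // 1 -> fits exactly
}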
+    if (Align == 0 && Ty && Ty->isSized()) +      Align = DL->getABITypeAlignment(Ty); +    Assert(!BaseAlign || Align <= MinAlign(BaseAlign, Offset), +           "Undefined behavior: Memory reference address is misaligned", &I); +  } +} + +void Lint::visitLoadInst(LoadInst &I) { +  visitMemoryReference(I, I.getPointerOperand(), +                       DL->getTypeStoreSize(I.getType()), I.getAlignment(), +                       I.getType(), MemRef::Read); +} + +void Lint::visitStoreInst(StoreInst &I) { +  visitMemoryReference(I, I.getPointerOperand(), +                       DL->getTypeStoreSize(I.getOperand(0)->getType()), +                       I.getAlignment(), +                       I.getOperand(0)->getType(), MemRef::Write); +} + +void Lint::visitXor(BinaryOperator &I) { +  Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), +         "Undefined result: xor(undef, undef)", &I); +} + +void Lint::visitSub(BinaryOperator &I) { +  Assert(!isa<UndefValue>(I.getOperand(0)) || !isa<UndefValue>(I.getOperand(1)), +         "Undefined result: sub(undef, undef)", &I); +} + +void Lint::visitLShr(BinaryOperator &I) { +  if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(1), +                                                        /*OffsetOk=*/false))) +    Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), +           "Undefined result: Shift count out of range", &I); +} + +void Lint::visitAShr(BinaryOperator &I) { +  if (ConstantInt *CI = +          dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) +    Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), +           "Undefined result: Shift count out of range", &I); +} + +void Lint::visitShl(BinaryOperator &I) { +  if (ConstantInt *CI = +          dyn_cast<ConstantInt>(findValue(I.getOperand(1), /*OffsetOk=*/false))) +    Assert(CI->getValue().ult(cast<IntegerType>(I.getType())->getBitWidth()), +           "Undefined result: Shift count out of range", &I); +} + +static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, +                   AssumptionCache *AC) { +  // Assume undef could be zero. 
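// Editorial sketch (not part of this file): the misalignment rule above says
// a declared alignment is only believable up to MinAlign(BaseAlign, Offset),
// i.e. the lowest set bit of (BaseAlign | Offset). minAlign below restates
// that bit trick with plain integers; it is not the LLVM implementation.
#include <cstdint>
#include <iostream>

uint64_t minAlign(uint64_t BaseAlign, uint64_t Offset) {
  uint64_t X = BaseAlign | Offset;
  return X & (~X + 1); // isolate the lowest set bit
}

int main() {
  // A 16-byte aligned base plus offset 4 is only guaranteed 4-byte aligned...
  std::cout << minAlign(16, 4) << '\n'; // 4
  // ...so an access there that claims 8-byte alignment would be flagged.
  unsigned DeclaredAlign = 8;
  std::cout << (DeclaredAlign <= minAlign(16, 4)) << '\n'; // 0 -> misaligned
}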
+  if (isa<UndefValue>(V)) +    return true; + +  VectorType *VecTy = dyn_cast<VectorType>(V->getType()); +  if (!VecTy) { +    KnownBits Known = computeKnownBits(V, DL, 0, AC, dyn_cast<Instruction>(V), DT); +    return Known.isZero(); +  } + +  // Per-component check doesn't work with zeroinitializer +  Constant *C = dyn_cast<Constant>(V); +  if (!C) +    return false; + +  if (C->isZeroValue()) +    return true; + +  // For a vector, KnownZero will only be true if all values are zero, so check +  // this per component +  for (unsigned I = 0, N = VecTy->getNumElements(); I != N; ++I) { +    Constant *Elem = C->getAggregateElement(I); +    if (isa<UndefValue>(Elem)) +      return true; + +    KnownBits Known = computeKnownBits(Elem, DL); +    if (Known.isZero()) +      return true; +  } + +  return false; +} + +void Lint::visitSDiv(BinaryOperator &I) { +  Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), +         "Undefined behavior: Division by zero", &I); +} + +void Lint::visitUDiv(BinaryOperator &I) { +  Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), +         "Undefined behavior: Division by zero", &I); +} + +void Lint::visitSRem(BinaryOperator &I) { +  Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), +         "Undefined behavior: Division by zero", &I); +} + +void Lint::visitURem(BinaryOperator &I) { +  Assert(!isZero(I.getOperand(1), I.getModule()->getDataLayout(), DT, AC), +         "Undefined behavior: Division by zero", &I); +} + +void Lint::visitAllocaInst(AllocaInst &I) { +  if (isa<ConstantInt>(I.getArraySize())) +    // This isn't undefined behavior, it's just an obvious pessimization. +    Assert(&I.getParent()->getParent()->getEntryBlock() == I.getParent(), +           "Pessimization: Static alloca outside of entry block", &I); + +  // TODO: Check for an unusual size (MSB set?) +} + +void Lint::visitVAArgInst(VAArgInst &I) { +  visitMemoryReference(I, I.getOperand(0), MemoryLocation::UnknownSize, 0, +                       nullptr, MemRef::Read | MemRef::Write); +} + +void Lint::visitIndirectBrInst(IndirectBrInst &I) { +  visitMemoryReference(I, I.getAddress(), MemoryLocation::UnknownSize, 0, +                       nullptr, MemRef::Branchee); + +  Assert(I.getNumDestinations() != 0, +         "Undefined behavior: indirectbr with no destinations", &I); +} + +void Lint::visitExtractElementInst(ExtractElementInst &I) { +  if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getIndexOperand(), +                                                        /*OffsetOk=*/false))) +    Assert(CI->getValue().ult(I.getVectorOperandType()->getNumElements()), +           "Undefined result: extractelement index out of range", &I); +} + +void Lint::visitInsertElementInst(InsertElementInst &I) { +  if (ConstantInt *CI = dyn_cast<ConstantInt>(findValue(I.getOperand(2), +                                                        /*OffsetOk=*/false))) +    Assert(CI->getValue().ult(I.getType()->getNumElements()), +           "Undefined result: insertelement index out of range", &I); +} + +void Lint::visitUnreachableInst(UnreachableInst &I) { +  // This isn't undefined behavior, it's merely suspicious. 
+  Assert(&I == &I.getParent()->front() || +             std::prev(I.getIterator())->mayHaveSideEffects(), +         "Unusual: unreachable immediately preceded by instruction without " +         "side effects", +         &I); +} + +/// findValue - Look through bitcasts and simple memory reference patterns +/// to identify an equivalent, but more informative, value.  If OffsetOk +/// is true, look through getelementptrs with non-zero offsets too. +/// +/// Most analysis passes don't require this logic, because instcombine +/// will simplify most of these kinds of things away. But it's a goal of +/// this Lint pass to be useful even on non-optimized IR. +Value *Lint::findValue(Value *V, bool OffsetOk) const { +  SmallPtrSet<Value *, 4> Visited; +  return findValueImpl(V, OffsetOk, Visited); +} + +/// findValueImpl - Implementation helper for findValue. +Value *Lint::findValueImpl(Value *V, bool OffsetOk, +                           SmallPtrSetImpl<Value *> &Visited) const { +  // Detect self-referential values. +  if (!Visited.insert(V).second) +    return UndefValue::get(V->getType()); + +  // TODO: Look through sext or zext cast, when the result is known to +  // be interpreted as signed or unsigned, respectively. +  // TODO: Look through eliminable cast pairs. +  // TODO: Look through calls with unique return values. +  // TODO: Look through vector insert/extract/shuffle. +  V = OffsetOk ? GetUnderlyingObject(V, *DL) : V->stripPointerCasts(); +  if (LoadInst *L = dyn_cast<LoadInst>(V)) { +    BasicBlock::iterator BBI = L->getIterator(); +    BasicBlock *BB = L->getParent(); +    SmallPtrSet<BasicBlock *, 4> VisitedBlocks; +    for (;;) { +      if (!VisitedBlocks.insert(BB).second) +        break; +      if (Value *U = +          FindAvailableLoadedValue(L, BB, BBI, DefMaxInstsToScan, AA)) +        return findValueImpl(U, OffsetOk, Visited); +      if (BBI != BB->begin()) break; +      BB = BB->getUniquePredecessor(); +      if (!BB) break; +      BBI = BB->end(); +    } +  } else if (PHINode *PN = dyn_cast<PHINode>(V)) { +    if (Value *W = PN->hasConstantValue()) +      if (W != V) +        return findValueImpl(W, OffsetOk, Visited); +  } else if (CastInst *CI = dyn_cast<CastInst>(V)) { +    if (CI->isNoopCast(*DL)) +      return findValueImpl(CI->getOperand(0), OffsetOk, Visited); +  } else if (ExtractValueInst *Ex = dyn_cast<ExtractValueInst>(V)) { +    if (Value *W = FindInsertedValue(Ex->getAggregateOperand(), +                                     Ex->getIndices())) +      if (W != V) +        return findValueImpl(W, OffsetOk, Visited); +  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { +    // Same as above, but for ConstantExpr instead of Instruction. +    if (Instruction::isCast(CE->getOpcode())) { +      if (CastInst::isNoopCast(Instruction::CastOps(CE->getOpcode()), +                               CE->getOperand(0)->getType(), CE->getType(), +                               *DL)) +        return findValueImpl(CE->getOperand(0), OffsetOk, Visited); +    } else if (CE->getOpcode() == Instruction::ExtractValue) { +      ArrayRef<unsigned> Indices = CE->getIndices(); +      if (Value *W = FindInsertedValue(CE->getOperand(0), Indices)) +        if (W != V) +          return findValueImpl(W, OffsetOk, Visited); +    } +  } + +  // As a last resort, try SimplifyInstruction or constant folding. 
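// Editorial sketch (not part of this file): findValueImpl above guards its
// recursion with a Visited set, using the fact that insert() reports whether
// the element was new; a failed insertion means the chain is self-referential
// and the walk stops. The same idiom over a toy "forwards to" map; findBase
// and the map contents are invented for illustration.
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

using Graph = std::unordered_map<std::string, std::string>;

std::string findBase(const Graph &G, const std::string &V,
                     std::unordered_set<std::string> &Visited) {
  if (!Visited.insert(V).second)
    return V; // already seen: self-referential chain, give up here
  auto It = G.find(V);
  return It == G.end() ? V : findBase(G, It->second, Visited);
}

int main() {
  Graph G{{"a", "b"}, {"b", "c"}, {"x", "y"}, {"y", "x"}};
  std::unordered_set<std::string> Visited;
  std::cout << findBase(G, "a", Visited) << '\n'; // c
  Visited.clear();
  std::cout << findBase(G, "x", Visited) << '\n'; // stops once the cycle closes
}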
+  if (Instruction *Inst = dyn_cast<Instruction>(V)) { +    if (Value *W = SimplifyInstruction(Inst, {*DL, TLI, DT, AC})) +      return findValueImpl(W, OffsetOk, Visited); +  } else if (auto *C = dyn_cast<Constant>(V)) { +    if (Value *W = ConstantFoldConstant(C, *DL, TLI)) +      if (W && W != V) +        return findValueImpl(W, OffsetOk, Visited); +  } + +  return V; +} + +//===----------------------------------------------------------------------===// +//  Implement the public interfaces to this file... +//===----------------------------------------------------------------------===// + +FunctionPass *llvm::createLintPass() { +  return new Lint(); +} + +/// lintFunction - Check a function for errors, printing messages on stderr. +/// +void llvm::lintFunction(const Function &f) { +  Function &F = const_cast<Function&>(f); +  assert(!F.isDeclaration() && "Cannot lint external functions"); + +  legacy::FunctionPassManager FPM(F.getParent()); +  Lint *V = new Lint(); +  FPM.add(V); +  FPM.run(F); +} + +/// lintModule - Check a module for errors, printing messages on stderr. +/// +void llvm::lintModule(const Module &M) { +  legacy::PassManager PM; +  Lint *V = new Lint(); +  PM.add(V); +  PM.run(const_cast<Module&>(M)); +} diff --git a/contrib/llvm/lib/Analysis/Loads.cpp b/contrib/llvm/lib/Analysis/Loads.cpp new file mode 100644 index 000000000000..d319d4c249d3 --- /dev/null +++ b/contrib/llvm/lib/Analysis/Loads.cpp @@ -0,0 +1,441 @@ +//===- Loads.cpp - Local load analysis ------------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines simple local analyses for load instructions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Statepoint.h" + +using namespace llvm; + +static bool isAligned(const Value *Base, const APInt &Offset, unsigned Align, +                      const DataLayout &DL) { +  APInt BaseAlign(Offset.getBitWidth(), Base->getPointerAlignment(DL)); + +  if (!BaseAlign) { +    Type *Ty = Base->getType()->getPointerElementType(); +    if (!Ty->isSized()) +      return false; +    BaseAlign = DL.getABITypeAlignment(Ty); +  } + +  APInt Alignment(Offset.getBitWidth(), Align); + +  assert(Alignment.isPowerOf2() && "must be a power of 2!"); +  return BaseAlign.uge(Alignment) && !(Offset & (Alignment-1)); +} + +static bool isAligned(const Value *Base, unsigned Align, const DataLayout &DL) { +  Type *Ty = Base->getType(); +  assert(Ty->isSized() && "must be sized"); +  APInt Offset(DL.getTypeStoreSizeInBits(Ty), 0); +  return isAligned(Base, Offset, Align, DL); +} + +/// Test if V is always a pointer to allocated and suitably aligned memory for +/// a simple load or store. 
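// Editorial sketch (not part of this file): the isAligned helpers above
// reduce to "the base must be at least Align-aligned and the constant offset
// must be a multiple of Align", where Align is a power of two so the multiple
// test is a mask. Standalone restatement with plain integers; isAlignedAt is
// an invented name.
#include <cassert>
#include <cstdint>
#include <iostream>

bool isAlignedAt(uint64_t BaseAlign, uint64_t Offset, uint64_t Align) {
  assert(Align && (Align & (Align - 1)) == 0 && "must be a power of 2!");
  return BaseAlign >= Align && (Offset & (Align - 1)) == 0;
}

int main() {
  std::cout << isAlignedAt(16, 8, 8) << '\n';  // 1: 16-aligned base, offset 8
  std::cout << isAlignedAt(16, 12, 8) << '\n'; // 0: offset breaks 8-alignment
  std::cout << isAlignedAt(4, 0, 8) << '\n';   // 0: base not aligned enough
}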
+static bool isDereferenceableAndAlignedPointer( +    const Value *V, unsigned Align, const APInt &Size, const DataLayout &DL, +    const Instruction *CtxI, const DominatorTree *DT, +    SmallPtrSetImpl<const Value *> &Visited) { +  // Already visited?  Bail out, we've likely hit unreachable code. +  if (!Visited.insert(V).second) +    return false; + +  // Note that it is not safe to speculate into a malloc'd region because +  // malloc may return null. + +  // bitcast instructions are no-ops as far as dereferenceability is concerned. +  if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(V)) +    return isDereferenceableAndAlignedPointer(BC->getOperand(0), Align, Size, +                                              DL, CtxI, DT, Visited); + +  bool CheckForNonNull = false; +  APInt KnownDerefBytes(Size.getBitWidth(), +                        V->getPointerDereferenceableBytes(DL, CheckForNonNull)); +  if (KnownDerefBytes.getBoolValue()) { +    if (KnownDerefBytes.uge(Size)) +      if (!CheckForNonNull || isKnownNonZero(V, DL, 0, nullptr, CtxI, DT)) +        return isAligned(V, Align, DL); +  } + +  // For GEPs, determine if the indexing lands within the allocated object. +  if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { +    const Value *Base = GEP->getPointerOperand(); + +    APInt Offset(DL.getIndexTypeSizeInBits(GEP->getType()), 0); +    if (!GEP->accumulateConstantOffset(DL, Offset) || Offset.isNegative() || +        !Offset.urem(APInt(Offset.getBitWidth(), Align)).isMinValue()) +      return false; + +    // If the base pointer is dereferenceable for Offset+Size bytes, then the +    // GEP (== Base + Offset) is dereferenceable for Size bytes.  If the base +    // pointer is aligned to Align bytes, and the Offset is divisible by Align +    // then the GEP (== Base + Offset == k_0 * Align + k_1 * Align) is also +    // aligned to Align bytes. + +    // Offset and Size may have different bit widths if we have visited an +    // addrspacecast, so we can't do arithmetic directly on the APInt values. +    return isDereferenceableAndAlignedPointer( +        Base, Align, Offset + Size.sextOrTrunc(Offset.getBitWidth()), +        DL, CtxI, DT, Visited); +  } + +  // For gc.relocate, look through relocations +  if (const GCRelocateInst *RelocateInst = dyn_cast<GCRelocateInst>(V)) +    return isDereferenceableAndAlignedPointer( +        RelocateInst->getDerivedPtr(), Align, Size, DL, CtxI, DT, Visited); + +  if (const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(V)) +    return isDereferenceableAndAlignedPointer(ASC->getOperand(0), Align, Size, +                                              DL, CtxI, DT, Visited); + +  if (auto CS = ImmutableCallSite(V)) +    if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) +      return isDereferenceableAndAlignedPointer(RP, Align, Size, DL, CtxI, DT, +                                                Visited); + +  // If we don't know, assume the worst. 
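// Editorial sketch (not part of this file): the GEP case above relies on the
// rule spelled out in its comments: Base + Offset is dereferenceable for Size
// bytes when the base is dereferenceable for Offset + Size bytes, and it
// inherits the base's alignment when Offset is a multiple of that alignment.
// Plain-integer restatement; gepDerefAndAligned is an invented name.
#include <cstdint>
#include <iostream>

bool gepDerefAndAligned(uint64_t BaseDerefBytes, uint64_t BaseAlign,
                        uint64_t Offset, uint64_t Size, uint64_t Align) {
  bool Deref = Offset + Size <= BaseDerefBytes;
  bool Aligned = BaseAlign >= Align && Offset % Align == 0;
  return Deref && Aligned;
}

int main() {
  // A base known dereferenceable for 64 bytes and 16-byte aligned:
  std::cout << gepDerefAndAligned(64, 16, 32, 16, 16) << '\n'; // 1
  std::cout << gepDerefAndAligned(64, 16, 56, 16, 16) << '\n'; // 0: runs past the end
  std::cout << gepDerefAndAligned(64, 16, 8, 16, 16) << '\n';  // 0: misaligned offset
}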
+  return false; +} + +bool llvm::isDereferenceableAndAlignedPointer(const Value *V, unsigned Align, +                                              const APInt &Size, +                                              const DataLayout &DL, +                                              const Instruction *CtxI, +                                              const DominatorTree *DT) { +  SmallPtrSet<const Value *, 32> Visited; +  return ::isDereferenceableAndAlignedPointer(V, Align, Size, DL, CtxI, DT, +                                              Visited); +} + +bool llvm::isDereferenceableAndAlignedPointer(const Value *V, unsigned Align, +                                              const DataLayout &DL, +                                              const Instruction *CtxI, +                                              const DominatorTree *DT) { +  // When dereferenceability information is provided by a dereferenceable +  // attribute, we know exactly how many bytes are dereferenceable. If we can +  // determine the exact offset to the attributed variable, we can use that +  // information here. +  Type *VTy = V->getType(); +  Type *Ty = VTy->getPointerElementType(); + +  // Require ABI alignment for loads without alignment specification +  if (Align == 0) +    Align = DL.getABITypeAlignment(Ty); + +  if (!Ty->isSized()) +    return false; + +  SmallPtrSet<const Value *, 32> Visited; +  return ::isDereferenceableAndAlignedPointer( +      V, Align, APInt(DL.getIndexTypeSizeInBits(VTy), DL.getTypeStoreSize(Ty)), DL, +      CtxI, DT, Visited); +} + +bool llvm::isDereferenceablePointer(const Value *V, const DataLayout &DL, +                                    const Instruction *CtxI, +                                    const DominatorTree *DT) { +  return isDereferenceableAndAlignedPointer(V, 1, DL, CtxI, DT); +} + +/// Test if A and B will obviously have the same value. +/// +/// This includes recognizing that %t0 and %t1 will have the same +/// value in code like this: +/// \code +///   %t0 = getelementptr \@a, 0, 3 +///   store i32 0, i32* %t0 +///   %t1 = getelementptr \@a, 0, 3 +///   %t2 = load i32* %t1 +/// \endcode +/// +static bool AreEquivalentAddressValues(const Value *A, const Value *B) { +  // Test if the values are trivially equivalent. +  if (A == B) +    return true; + +  // Test if the values come from identical arithmetic instructions. +  // Use isIdenticalToWhenDefined instead of isIdenticalTo because +  // this function is only used when one address use dominates the +  // other, which means that they'll always either have the same +  // value or one of them will have an undefined value. +  if (isa<BinaryOperator>(A) || isa<CastInst>(A) || isa<PHINode>(A) || +      isa<GetElementPtrInst>(A)) +    if (const Instruction *BI = dyn_cast<Instruction>(B)) +      if (cast<Instruction>(A)->isIdenticalToWhenDefined(BI)) +        return true; + +  // Otherwise they may not be equivalent. +  return false; +} + +/// Check if executing a load of this pointer value cannot trap. +/// +/// If DT and ScanFrom are specified this method performs context-sensitive +/// analysis and returns true if it is safe to load immediately before ScanFrom. +/// +/// If it is not obviously safe to load from the specified pointer, we do +/// a quick local scan of the basic block containing \c ScanFrom, to determine +/// if the address is already accessed. +/// +/// This uses the pointee type to determine how many bytes need to be safe to +/// load from the pointer. 
+bool llvm::isSafeToLoadUnconditionally(Value *V, unsigned Align,
+                                       const DataLayout &DL,
+                                       Instruction *ScanFrom,
+                                       const DominatorTree *DT) {
+  // Zero alignment means that the load has the ABI alignment for the target.
+  if (Align == 0)
+    Align = DL.getABITypeAlignment(V->getType()->getPointerElementType());
+  assert(isPowerOf2_32(Align));
+
+  // If DT is not specified we can't make a context-sensitive query.
+  const Instruction* CtxI = DT ? ScanFrom : nullptr;
+  if (isDereferenceableAndAlignedPointer(V, Align, DL, CtxI, DT))
+    return true;
+
+  int64_t ByteOffset = 0;
+  Value *Base = V;
+  Base = GetPointerBaseWithConstantOffset(V, ByteOffset, DL);
+
+  if (ByteOffset < 0) // out of bounds
+    return false;
+
+  Type *BaseType = nullptr;
+  unsigned BaseAlign = 0;
+  if (const AllocaInst *AI = dyn_cast<AllocaInst>(Base)) {
+    // An alloca is safe to load from as long as it is suitably aligned.
+    BaseType = AI->getAllocatedType();
+    BaseAlign = AI->getAlignment();
+  } else if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) {
+    // Global variables are not necessarily safe to load from if they are
+    // interposed arbitrarily. Their size may change or they may be weak and
+    // require a test to determine if they were in fact provided.
+    if (!GV->isInterposable()) {
+      BaseType = GV->getType()->getElementType();
+      BaseAlign = GV->getAlignment();
+    }
+  }
+
+  PointerType *AddrTy = cast<PointerType>(V->getType());
+  uint64_t LoadSize = DL.getTypeStoreSize(AddrTy->getElementType());
+
+  // If we found a base allocated type from either an alloca or global variable,
+  // try to see if we are definitively within the allocated region. We need to
+  // know the size of the base type and the loaded type to do anything in this
+  // case.
+  if (BaseType && BaseType->isSized()) {
+    if (BaseAlign == 0)
+      BaseAlign = DL.getPrefTypeAlignment(BaseType);
+
+    if (Align <= BaseAlign) {
+      // Check if the load is within the bounds of the underlying object.
+      if (ByteOffset + LoadSize <= DL.getTypeAllocSize(BaseType) &&
+          ((ByteOffset % Align) == 0))
+        return true;
+    }
+  }
+
+  if (!ScanFrom)
+    return false;
+
+  // Otherwise, be a little bit aggressive by scanning the local block where we
+  // want to check to see if the pointer is already being loaded or stored
+  // from/to.  If so, the previous load or store would have already trapped,
+  // so there is no harm doing an extra load (also, CSE will later eliminate
+  // the load entirely).
+  BasicBlock::iterator BBI = ScanFrom->getIterator(),
+                       E = ScanFrom->getParent()->begin();
+
+  // We can at least always strip pointer casts even though we can't use the
+  // base here.
+  V = V->stripPointerCasts();
+
+  while (BBI != E) {
+    --BBI;
+
+    // If we see a free or a call which may write to memory (i.e. which might do
+    // a free), the pointer could be marked invalid.
+    if (isa<CallInst>(BBI) && BBI->mayWriteToMemory() && +        !isa<DbgInfoIntrinsic>(BBI)) +      return false; + +    Value *AccessedPtr; +    unsigned AccessedAlign; +    if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { +      AccessedPtr = LI->getPointerOperand(); +      AccessedAlign = LI->getAlignment(); +    } else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) { +      AccessedPtr = SI->getPointerOperand(); +      AccessedAlign = SI->getAlignment(); +    } else +      continue; + +    Type *AccessedTy = AccessedPtr->getType()->getPointerElementType(); +    if (AccessedAlign == 0) +      AccessedAlign = DL.getABITypeAlignment(AccessedTy); +    if (AccessedAlign < Align) +      continue; + +    // Handle trivial cases. +    if (AccessedPtr == V) +      return true; + +    if (AreEquivalentAddressValues(AccessedPtr->stripPointerCasts(), V) && +        LoadSize <= DL.getTypeStoreSize(AccessedTy)) +      return true; +  } +  return false; +} + +/// DefMaxInstsToScan - the default number of maximum instructions +/// to scan in the block, used by FindAvailableLoadedValue(). +/// FindAvailableLoadedValue() was introduced in r60148, to improve jump +/// threading in part by eliminating partially redundant loads. +/// At that point, the value of MaxInstsToScan was already set to '6' +/// without documented explanation. +cl::opt<unsigned> +llvm::DefMaxInstsToScan("available-load-scan-limit", cl::init(6), cl::Hidden, +  cl::desc("Use this to specify the default maximum number of instructions " +           "to scan backward from a given instruction, when searching for " +           "available loaded value")); + +Value *llvm::FindAvailableLoadedValue(LoadInst *Load, +                                      BasicBlock *ScanBB, +                                      BasicBlock::iterator &ScanFrom, +                                      unsigned MaxInstsToScan, +                                      AliasAnalysis *AA, bool *IsLoad, +                                      unsigned *NumScanedInst) { +  // Don't CSE load that is volatile or anything stronger than unordered. +  if (!Load->isUnordered()) +    return nullptr; + +  return FindAvailablePtrLoadStore( +      Load->getPointerOperand(), Load->getType(), Load->isAtomic(), ScanBB, +      ScanFrom, MaxInstsToScan, AA, IsLoad, NumScanedInst); +} + +Value *llvm::FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy, +                                       bool AtLeastAtomic, BasicBlock *ScanBB, +                                       BasicBlock::iterator &ScanFrom, +                                       unsigned MaxInstsToScan, +                                       AliasAnalysis *AA, bool *IsLoadCSE, +                                       unsigned *NumScanedInst) { +  if (MaxInstsToScan == 0) +    MaxInstsToScan = ~0U; + +  const DataLayout &DL = ScanBB->getModule()->getDataLayout(); + +  // Try to get the store size for the type. +  uint64_t AccessSize = DL.getTypeStoreSize(AccessTy); + +  Value *StrippedPtr = Ptr->stripPointerCasts(); + +  while (ScanFrom != ScanBB->begin()) { +    // We must ignore debug info directives when counting (otherwise they +    // would affect codegen). +    Instruction *Inst = &*--ScanFrom; +    if (isa<DbgInfoIntrinsic>(Inst)) +      continue; + +    // Restore ScanFrom to expected value in case next test succeeds +    ScanFrom++; + +    if (NumScanedInst) +      ++(*NumScanedInst); + +    // Don't scan huge blocks. 
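// Editorial sketch (not part of this file): both isSafeToLoadUnconditionally
// and FindAvailablePtrLoadStore walk a block backwards under a scan budget
// (DefMaxInstsToScan defaults to 6 above), forward a value from a store to
// the same address, and give up on any write that might clobber it. Toy model
// over a vector of records; Access and findAvailableValue are invented names.
#include <iostream>
#include <optional>
#include <vector>

struct Access { bool IsStore; const void *Ptr; int Value; bool MayClobber; };

std::optional<int> findAvailableValue(const std::vector<Access> &Block,
                                      const void *Ptr,
                                      unsigned MaxInstsToScan = 6) {
  for (auto It = Block.rbegin(); It != Block.rend(); ++It) {
    if (MaxInstsToScan-- == 0)
      return std::nullopt;  // don't scan huge blocks
    if (It->IsStore && It->Ptr == Ptr)
      return It->Value;     // store to the same address: forward its value
    if (It->MayClobber)
      return std::nullopt;  // unknown write: the value may be stale
  }
  return std::nullopt;      // reached the top of the block without an answer
}

int main() {
  int X = 0;
  std::vector<Access> Block = {{true, &X, 42, false}, {false, &X, 0, false}};
  if (auto V = findAvailableValue(Block, &X))
    std::cout << *V << '\n'; // 42
}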
+    if (MaxInstsToScan-- == 0) +      return nullptr; + +    --ScanFrom; +    // If this is a load of Ptr, the loaded value is available. +    // (This is true even if the load is volatile or atomic, although +    // those cases are unlikely.) +    if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) +      if (AreEquivalentAddressValues( +              LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) && +          CastInst::isBitOrNoopPointerCastable(LI->getType(), AccessTy, DL)) { + +        // We can value forward from an atomic to a non-atomic, but not the +        // other way around. +        if (LI->isAtomic() < AtLeastAtomic) +          return nullptr; + +        if (IsLoadCSE) +            *IsLoadCSE = true; +        return LI; +      } + +    if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { +      Value *StorePtr = SI->getPointerOperand()->stripPointerCasts(); +      // If this is a store through Ptr, the value is available! +      // (This is true even if the store is volatile or atomic, although +      // those cases are unlikely.) +      if (AreEquivalentAddressValues(StorePtr, StrippedPtr) && +          CastInst::isBitOrNoopPointerCastable(SI->getValueOperand()->getType(), +                                               AccessTy, DL)) { + +        // We can value forward from an atomic to a non-atomic, but not the +        // other way around. +        if (SI->isAtomic() < AtLeastAtomic) +          return nullptr; + +        if (IsLoadCSE) +          *IsLoadCSE = false; +        return SI->getOperand(0); +      } + +      // If both StrippedPtr and StorePtr reach all the way to an alloca or +      // global and they are different, ignore the store. This is a trivial form +      // of alias analysis that is important for reg2mem'd code. +      if ((isa<AllocaInst>(StrippedPtr) || isa<GlobalVariable>(StrippedPtr)) && +          (isa<AllocaInst>(StorePtr) || isa<GlobalVariable>(StorePtr)) && +          StrippedPtr != StorePtr) +        continue; + +      // If we have alias analysis and it says the store won't modify the loaded +      // value, ignore the store. +      if (AA && !isModSet(AA->getModRefInfo(SI, StrippedPtr, AccessSize))) +        continue; + +      // Otherwise the store that may or may not alias the pointer, bail out. +      ++ScanFrom; +      return nullptr; +    } + +    // If this is some other instruction that may clobber Ptr, bail out. +    if (Inst->mayWriteToMemory()) { +      // If alias analysis claims that it really won't modify the load, +      // ignore it. +      if (AA && !isModSet(AA->getModRefInfo(Inst, StrippedPtr, AccessSize))) +        continue; + +      // May modify the pointer, bail out. +      ++ScanFrom; +      return nullptr; +    } +  } + +  // Got to the start of the block, we didn't find it, but are done for this +  // block. +  return nullptr; +} diff --git a/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp new file mode 100644 index 000000000000..a24d66011b8d --- /dev/null +++ b/contrib/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -0,0 +1,2377 @@ +//===- LoopAccessAnalysis.cpp - Loop Access Analysis Implementation --------==// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// The implementation for the loop memory dependence that was originally +// developed for the loop vectorizer. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstdlib> +#include <iterator> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "loop-accesses" + +static cl::opt<unsigned, true> +VectorizationFactor("force-vector-width", cl::Hidden, +                    cl::desc("Sets the SIMD width. Zero is autoselect."), +                    cl::location(VectorizerParams::VectorizationFactor)); +unsigned VectorizerParams::VectorizationFactor; + +static cl::opt<unsigned, true> +VectorizationInterleave("force-vector-interleave", cl::Hidden, +                        cl::desc("Sets the vectorization interleave count. 
" +                                 "Zero is autoselect."), +                        cl::location( +                            VectorizerParams::VectorizationInterleave)); +unsigned VectorizerParams::VectorizationInterleave; + +static cl::opt<unsigned, true> RuntimeMemoryCheckThreshold( +    "runtime-memory-check-threshold", cl::Hidden, +    cl::desc("When performing memory disambiguation checks at runtime do not " +             "generate more than this number of comparisons (default = 8)."), +    cl::location(VectorizerParams::RuntimeMemoryCheckThreshold), cl::init(8)); +unsigned VectorizerParams::RuntimeMemoryCheckThreshold; + +/// The maximum iterations used to merge memory checks +static cl::opt<unsigned> MemoryCheckMergeThreshold( +    "memory-check-merge-threshold", cl::Hidden, +    cl::desc("Maximum number of comparisons done when trying to merge " +             "runtime memory checks. (default = 100)"), +    cl::init(100)); + +/// Maximum SIMD width. +const unsigned VectorizerParams::MaxVectorWidth = 64; + +/// We collect dependences up to this threshold. +static cl::opt<unsigned> +    MaxDependences("max-dependences", cl::Hidden, +                   cl::desc("Maximum number of dependences collected by " +                            "loop-access analysis (default = 100)"), +                   cl::init(100)); + +/// This enables versioning on the strides of symbolically striding memory +/// accesses in code like the following. +///   for (i = 0; i < N; ++i) +///     A[i * Stride1] += B[i * Stride2] ... +/// +/// Will be roughly translated to +///    if (Stride1 == 1 && Stride2 == 1) { +///      for (i = 0; i < N; i+=4) +///       A[i:i+3] += ... +///    } else +///      ... +static cl::opt<bool> EnableMemAccessVersioning( +    "enable-mem-access-versioning", cl::init(true), cl::Hidden, +    cl::desc("Enable symbolic stride memory access versioning")); + +/// Enable store-to-load forwarding conflict detection. This option can +/// be disabled for correctness testing. +static cl::opt<bool> EnableForwardingConflictDetection( +    "store-to-load-forwarding-conflict-detection", cl::Hidden, +    cl::desc("Enable conflict detection in loop-access analysis"), +    cl::init(true)); + +bool VectorizerParams::isInterleaveForced() { +  return ::VectorizationInterleave.getNumOccurrences() > 0; +} + +Value *llvm::stripIntegerCast(Value *V) { +  if (auto *CI = dyn_cast<CastInst>(V)) +    if (CI->getOperand(0)->getType()->isIntegerTy()) +      return CI->getOperand(0); +  return V; +} + +const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE, +                                            const ValueToValueMap &PtrToStride, +                                            Value *Ptr, Value *OrigPtr) { +  const SCEV *OrigSCEV = PSE.getSCEV(Ptr); + +  // If there is an entry in the map return the SCEV of the pointer with the +  // symbolic stride replaced by one. +  ValueToValueMap::const_iterator SI = +      PtrToStride.find(OrigPtr ? OrigPtr : Ptr); +  if (SI != PtrToStride.end()) { +    Value *StrideVal = SI->second; + +    // Strip casts. 
+    StrideVal = stripIntegerCast(StrideVal); + +    ScalarEvolution *SE = PSE.getSE(); +    const auto *U = cast<SCEVUnknown>(SE->getSCEV(StrideVal)); +    const auto *CT = +        static_cast<const SCEVConstant *>(SE->getOne(StrideVal->getType())); + +    PSE.addPredicate(*SE->getEqualPredicate(U, CT)); +    auto *Expr = PSE.getSCEV(Ptr); + +    LLVM_DEBUG(dbgs() << "LAA: Replacing SCEV: " << *OrigSCEV +                      << " by: " << *Expr << "\n"); +    return Expr; +  } + +  // Otherwise, just return the SCEV of the original pointer. +  return OrigSCEV; +} + +/// Calculate Start and End points of memory access. +/// Let's assume A is the first access and B is a memory access on N-th loop +/// iteration. Then B is calculated as: +///   B = A + Step*N . +/// Step value may be positive or negative. +/// N is a calculated back-edge taken count: +///     N = (TripCount > 0) ? RoundDown(TripCount -1 , VF) : 0 +/// Start and End points are calculated in the following way: +/// Start = UMIN(A, B) ; End = UMAX(A, B) + SizeOfElt, +/// where SizeOfElt is the size of single memory access in bytes. +/// +/// There is no conflict when the intervals are disjoint: +/// NoConflict = (P2.Start >= P1.End) || (P1.Start >= P2.End) +void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, +                                    unsigned DepSetId, unsigned ASId, +                                    const ValueToValueMap &Strides, +                                    PredicatedScalarEvolution &PSE) { +  // Get the stride replaced scev. +  const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); +  ScalarEvolution *SE = PSE.getSE(); + +  const SCEV *ScStart; +  const SCEV *ScEnd; + +  if (SE->isLoopInvariant(Sc, Lp)) +    ScStart = ScEnd = Sc; +  else { +    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); +    assert(AR && "Invalid addrec expression"); +    const SCEV *Ex = PSE.getBackedgeTakenCount(); + +    ScStart = AR->getStart(); +    ScEnd = AR->evaluateAtIteration(Ex, *SE); +    const SCEV *Step = AR->getStepRecurrence(*SE); + +    // For expressions with negative step, the upper bound is ScStart and the +    // lower bound is ScEnd. +    if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) { +      if (CStep->getValue()->isNegative()) +        std::swap(ScStart, ScEnd); +    } else { +      // Fallback case: the step is not constant, but we can still +      // get the upper and lower bounds of the interval by using min/max +      // expressions. +      ScStart = SE->getUMinExpr(ScStart, ScEnd); +      ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd); +    } +    // Add the size of the pointed element to ScEnd. 
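// Editorial sketch (not part of this file): the insert() routine above
// documents the interval it records for each pointer: with start address A,
// byte step Step and backedge-taken count N, the accessed range is
// [umin(A, A + Step*N), umax(A, A + Step*N) + EltSize), and two pointers need
// no runtime check when their ranges are disjoint. Plain-integer model of the
// SCEV expressions; accessRange and noConflict are invented names.
#include <algorithm>
#include <cstdint>
#include <iostream>

struct Range { uint64_t Start, End; };

Range accessRange(uint64_t A, int64_t Step, uint64_t BackedgeTakenCount,
                  uint64_t EltSize) {
  uint64_t B = A + Step * static_cast<int64_t>(BackedgeTakenCount);
  return {std::min(A, B), std::max(A, B) + EltSize};
}

bool noConflict(const Range &P1, const Range &P2) {
  return P2.Start >= P1.End || P1.Start >= P2.End;
}

int main() {
  Range R1 = accessRange(/*A=*/0, /*Step=*/4, /*N=*/99, /*EltSize=*/4);
  Range R2 = accessRange(/*A=*/1024, /*Step=*/-4, /*N=*/99, /*EltSize=*/4);
  std::cout << R1.Start << ".." << R1.End << '\n'; // 0..400
  std::cout << R2.Start << ".." << R2.End << '\n'; // 628..1028 (negative step)
  std::cout << noConflict(R1, R2) << '\n';         // 1: disjoint, no check
}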
+    unsigned EltSize = +      Ptr->getType()->getPointerElementType()->getScalarSizeInBits() / 8; +    const SCEV *EltSizeSCEV = SE->getConstant(ScEnd->getType(), EltSize); +    ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV); +  } + +  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); +} + +SmallVector<RuntimePointerChecking::PointerCheck, 4> +RuntimePointerChecking::generateChecks() const { +  SmallVector<PointerCheck, 4> Checks; + +  for (unsigned I = 0; I < CheckingGroups.size(); ++I) { +    for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) { +      const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I]; +      const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J]; + +      if (needsChecking(CGI, CGJ)) +        Checks.push_back(std::make_pair(&CGI, &CGJ)); +    } +  } +  return Checks; +} + +void RuntimePointerChecking::generateChecks( +    MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) { +  assert(Checks.empty() && "Checks is not empty"); +  groupChecks(DepCands, UseDependencies); +  Checks = generateChecks(); +} + +bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M, +                                           const CheckingPtrGroup &N) const { +  for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I) +    for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J) +      if (needsChecking(M.Members[I], N.Members[J])) +        return true; +  return false; +} + +/// Compare \p I and \p J and return the minimum. +/// Return nullptr in case we couldn't find an answer. +static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J, +                                   ScalarEvolution *SE) { +  const SCEV *Diff = SE->getMinusSCEV(J, I); +  const SCEVConstant *C = dyn_cast<const SCEVConstant>(Diff); + +  if (!C) +    return nullptr; +  if (C->getValue()->isNegative()) +    return J; +  return I; +} + +bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) { +  const SCEV *Start = RtCheck.Pointers[Index].Start; +  const SCEV *End = RtCheck.Pointers[Index].End; + +  // Compare the starts and ends with the known minimum and maximum +  // of this set. We need to know how we compare against the min/max +  // of the set in order to be able to emit memchecks. +  const SCEV *Min0 = getMinFromExprs(Start, Low, RtCheck.SE); +  if (!Min0) +    return false; + +  const SCEV *Min1 = getMinFromExprs(End, High, RtCheck.SE); +  if (!Min1) +    return false; + +  // Update the low bound  expression if we've found a new min value. +  if (Min0 == Start) +    Low = Start; + +  // Update the high bound expression if we've found a new max value. +  if (Min1 != End) +    High = End; + +  Members.push_back(Index); +  return true; +} + +void RuntimePointerChecking::groupChecks( +    MemoryDepChecker::DepCandidates &DepCands, bool UseDependencies) { +  // We build the groups from dependency candidates equivalence classes +  // because: +  //    - We know that pointers in the same equivalence class share +  //      the same underlying object and therefore there is a chance +  //      that we can compare pointers +  //    - We wouldn't be able to merge two pointers for which we need +  //      to emit a memcheck. The classes in DepCands are already +  //      conveniently built such that no two pointers in the same +  //      class need checking against each other. 
+ +  // We use the following (greedy) algorithm to construct the groups +  // For every pointer in the equivalence class: +  //   For each existing group: +  //   - if the difference between this pointer and the min/max bounds +  //     of the group is a constant, then make the pointer part of the +  //     group and update the min/max bounds of that group as required. + +  CheckingGroups.clear(); + +  // If we need to check two pointers to the same underlying object +  // with a non-constant difference, we shouldn't perform any pointer +  // grouping with those pointers. This is because we can easily get +  // into cases where the resulting check would return false, even when +  // the accesses are safe. +  // +  // The following example shows this: +  // for (i = 0; i < 1000; ++i) +  //   a[5000 + i * m] = a[i] + a[i + 9000] +  // +  // Here grouping gives a check of (5000, 5000 + 1000 * m) against +  // (0, 10000) which is always false. However, if m is 1, there is no +  // dependence. Not grouping the checks for a[i] and a[i + 9000] allows +  // us to perform an accurate check in this case. +  // +  // The above case requires that we have an UnknownDependence between +  // accesses to the same underlying object. This cannot happen unless +  // ShouldRetryWithRuntimeCheck is set, and therefore UseDependencies +  // is also false. In this case we will use the fallback path and create +  // separate checking groups for all pointers. + +  // If we don't have the dependency partitions, construct a new +  // checking pointer group for each pointer. This is also required +  // for correctness, because in this case we can have checking between +  // pointers to the same underlying object. +  if (!UseDependencies) { +    for (unsigned I = 0; I < Pointers.size(); ++I) +      CheckingGroups.push_back(CheckingPtrGroup(I, *this)); +    return; +  } + +  unsigned TotalComparisons = 0; + +  DenseMap<Value *, unsigned> PositionMap; +  for (unsigned Index = 0; Index < Pointers.size(); ++Index) +    PositionMap[Pointers[Index].PointerValue] = Index; + +  // We need to keep track of what pointers we've already seen so we +  // don't process them twice. +  SmallSet<unsigned, 2> Seen; + +  // Go through all equivalence classes, get the "pointer check groups" +  // and add them to the overall solution. We use the order in which accesses +  // appear in 'Pointers' to enforce determinism. +  for (unsigned I = 0; I < Pointers.size(); ++I) { +    // We've seen this pointer before, and therefore already processed +    // its equivalence class. +    if (Seen.count(I)) +      continue; + +    MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue, +                                           Pointers[I].IsWritePtr); + +    SmallVector<CheckingPtrGroup, 2> Groups; +    auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access)); + +    // Because DepCands is constructed by visiting accesses in the order in +    // which they appear in alias sets (which is deterministic) and the +    // iteration order within an equivalence class member is only dependent on +    // the order in which unions and insertions are performed on the +    // equivalence class, the iteration order is deterministic. +    for (auto MI = DepCands.member_begin(LeaderI), ME = DepCands.member_end(); +         MI != ME; ++MI) { +      unsigned Pointer = PositionMap[MI->getPointer()]; +      bool Merged = false; +      // Mark this pointer as seen. 
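// Editorial sketch (not part of this file): the loop above implements the
// greedy grouping just described: each pointer is offered to the existing
// groups in turn, widening a group's [Low, High] bounds when accepted, and a
// new group is started once no group accepts it or the comparison budget
// (MemoryCheckMergeThreshold) is spent. Plain-integer stand-in for the SCEV
// bounds; the span-based acceptance test below replaces the real
// "distance is a constant" test (getMinFromExprs) and is purely illustrative.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct Group { uint64_t Low, High; };

bool tryAddToGroup(Group &G, const Group &R, uint64_t MaxSpan) {
  uint64_t Low = std::min(G.Low, R.Low), High = std::max(G.High, R.High);
  if (High - Low > MaxSpan)
    return false; // too far apart in this toy model: keep the groups separate
  G.Low = Low;
  G.High = High;
  return true;
}

std::vector<Group> groupRanges(const std::vector<Group> &Ranges,
                               unsigned MergeThreshold = 100) {
  std::vector<Group> Groups;
  unsigned TotalComparisons = 0;
  for (const Group &R : Ranges) {
    bool Merged = false;
    for (Group &G : Groups) {
      if (TotalComparisons++ > MergeThreshold)
        break;                 // budget exhausted: stop trying to merge
      if (tryAddToGroup(G, R, /*MaxSpan=*/256)) {
        Merged = true;
        break;
      }
    }
    if (!Merged)
      Groups.push_back(R);     // no group accepted it: start a new one
  }
  return Groups;
}

int main() {
  for (const Group &G : groupRanges({{0, 64}, {32, 128}, {512, 640}}))
    std::cout << G.Low << ".." << G.High << '\n'; // 0..128 and 512..640
}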
+      Seen.insert(Pointer); + +      // Go through all the existing sets and see if we can find one +      // which can include this pointer. +      for (CheckingPtrGroup &Group : Groups) { +        // Don't perform more than a certain amount of comparisons. +        // This should limit the cost of grouping the pointers to something +        // reasonable.  If we do end up hitting this threshold, the algorithm +        // will create separate groups for all remaining pointers. +        if (TotalComparisons > MemoryCheckMergeThreshold) +          break; + +        TotalComparisons++; + +        if (Group.addPointer(Pointer)) { +          Merged = true; +          break; +        } +      } + +      if (!Merged) +        // We couldn't add this pointer to any existing set or the threshold +        // for the number of comparisons has been reached. Create a new group +        // to hold the current pointer. +        Groups.push_back(CheckingPtrGroup(Pointer, *this)); +    } + +    // We've computed the grouped checks for this partition. +    // Save the results and continue with the next one. +    std::copy(Groups.begin(), Groups.end(), std::back_inserter(CheckingGroups)); +  } +} + +bool RuntimePointerChecking::arePointersInSamePartition( +    const SmallVectorImpl<int> &PtrToPartition, unsigned PtrIdx1, +    unsigned PtrIdx2) { +  return (PtrToPartition[PtrIdx1] != -1 && +          PtrToPartition[PtrIdx1] == PtrToPartition[PtrIdx2]); +} + +bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const { +  const PointerInfo &PointerI = Pointers[I]; +  const PointerInfo &PointerJ = Pointers[J]; + +  // No need to check if two readonly pointers intersect. +  if (!PointerI.IsWritePtr && !PointerJ.IsWritePtr) +    return false; + +  // Only need to check pointers between two different dependency sets. +  if (PointerI.DependencySetId == PointerJ.DependencySetId) +    return false; + +  // Only need to check pointers in the same alias set. +  if (PointerI.AliasSetId != PointerJ.AliasSetId) +    return false; + +  return true; +} + +void RuntimePointerChecking::printChecks( +    raw_ostream &OS, const SmallVectorImpl<PointerCheck> &Checks, +    unsigned Depth) const { +  unsigned N = 0; +  for (const auto &Check : Checks) { +    const auto &First = Check.first->Members, &Second = Check.second->Members; + +    OS.indent(Depth) << "Check " << N++ << ":\n"; + +    OS.indent(Depth + 2) << "Comparing group (" << Check.first << "):\n"; +    for (unsigned K = 0; K < First.size(); ++K) +      OS.indent(Depth + 2) << *Pointers[First[K]].PointerValue << "\n"; + +    OS.indent(Depth + 2) << "Against group (" << Check.second << "):\n"; +    for (unsigned K = 0; K < Second.size(); ++K) +      OS.indent(Depth + 2) << *Pointers[Second[K]].PointerValue << "\n"; +  } +} + +void RuntimePointerChecking::print(raw_ostream &OS, unsigned Depth) const { + +  OS.indent(Depth) << "Run-time memory checks:\n"; +  printChecks(OS, Checks, Depth); + +  OS.indent(Depth) << "Grouped accesses:\n"; +  for (unsigned I = 0; I < CheckingGroups.size(); ++I) { +    const auto &CG = CheckingGroups[I]; + +    OS.indent(Depth + 2) << "Group " << &CG << ":\n"; +    OS.indent(Depth + 4) << "(Low: " << *CG.Low << " High: " << *CG.High +                         << ")\n"; +    for (unsigned J = 0; J < CG.Members.size(); ++J) { +      OS.indent(Depth + 6) << "Member: " << *Pointers[CG.Members[J]].Expr +                           << "\n"; +    } +  } +} + +namespace { + +/// Analyses memory accesses in a loop. 
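// Editorial sketch (not part of this file): needsChecking above filters the
// pointer pairs that actually require a runtime check: at least one side must
// be written, the two must come from different dependence sets, and they must
// share an alias set. Standalone predicate over a small record; PtrInfo and
// needsRuntimeCheck are invented names.
#include <iostream>

struct PtrInfo { bool IsWritePtr; unsigned DependencySetId; unsigned AliasSetId; };

bool needsRuntimeCheck(const PtrInfo &A, const PtrInfo &B) {
  if (!A.IsWritePtr && !B.IsWritePtr)
    return false; // two reads can never conflict
  if (A.DependencySetId == B.DependencySetId)
    return false; // same dependence set: handled by dependence analysis
  if (A.AliasSetId != B.AliasSetId)
    return false; // different alias sets cannot alias at all
  return true;
}

int main() {
  PtrInfo Read{false, 1, 1}, Write{true, 2, 1}, OtherAS{true, 3, 2};
  std::cout << needsRuntimeCheck(Read, Write) << '\n';    // 1: needs a check
  std::cout << needsRuntimeCheck(Write, OtherAS) << '\n'; // 0: different alias set
}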
+/// +/// Checks whether run time pointer checks are needed and builds sets for data +/// dependence checking. +class AccessAnalysis { +public: +  /// Read or write access location. +  typedef PointerIntPair<Value *, 1, bool> MemAccessInfo; +  typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList; + +  AccessAnalysis(const DataLayout &Dl, Loop *TheLoop, AliasAnalysis *AA, +                 LoopInfo *LI, MemoryDepChecker::DepCandidates &DA, +                 PredicatedScalarEvolution &PSE) +      : DL(Dl), TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA), +        IsRTCheckAnalysisNeeded(false), PSE(PSE) {} + +  /// Register a load  and whether it is only read from. +  void addLoad(MemoryLocation &Loc, bool IsReadOnly) { +    Value *Ptr = const_cast<Value*>(Loc.Ptr); +    AST.add(Ptr, MemoryLocation::UnknownSize, Loc.AATags); +    Accesses.insert(MemAccessInfo(Ptr, false)); +    if (IsReadOnly) +      ReadOnlyPtr.insert(Ptr); +  } + +  /// Register a store. +  void addStore(MemoryLocation &Loc) { +    Value *Ptr = const_cast<Value*>(Loc.Ptr); +    AST.add(Ptr, MemoryLocation::UnknownSize, Loc.AATags); +    Accesses.insert(MemAccessInfo(Ptr, true)); +  } + +  /// Check if we can emit a run-time no-alias check for \p Access. +  /// +  /// Returns true if we can emit a run-time no alias check for \p Access. +  /// If we can check this access, this also adds it to a dependence set and +  /// adds a run-time to check for it to \p RtCheck. If \p Assume is true, +  /// we will attempt to use additional run-time checks in order to get +  /// the bounds of the pointer. +  bool createCheckForAccess(RuntimePointerChecking &RtCheck, +                            MemAccessInfo Access, +                            const ValueToValueMap &Strides, +                            DenseMap<Value *, unsigned> &DepSetId, +                            Loop *TheLoop, unsigned &RunningDepId, +                            unsigned ASId, bool ShouldCheckStride, +                            bool Assume); + +  /// Check whether we can check the pointers at runtime for +  /// non-intersection. +  /// +  /// Returns true if we need no check or if we do and we can generate them +  /// (i.e. the pointers have computable bounds). +  bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, +                       Loop *TheLoop, const ValueToValueMap &Strides, +                       bool ShouldCheckWrap = false); + +  /// Goes over all memory accesses, checks whether a RT check is needed +  /// and builds sets of dependent accesses. +  void buildDependenceSets() { +    processMemAccesses(); +  } + +  /// Initial processing of memory accesses determined that we need to +  /// perform dependency checking. +  /// +  /// Note that this can later be cleared if we retry memcheck analysis without +  /// dependency checking (i.e. ShouldRetryWithRuntimeCheck). +  bool isDependencyCheckNeeded() { return !CheckDeps.empty(); } + +  /// We decided that no dependence analysis would be used.  Reset the state. +  void resetDepChecks(MemoryDepChecker &DepChecker) { +    CheckDeps.clear(); +    DepChecker.clearDependences(); +  } + +  MemAccessInfoList &getDependenciesToCheck() { return CheckDeps; } + +private: +  typedef SetVector<MemAccessInfo> PtrAccessSet; + +  /// Go over all memory access and check whether runtime pointer checks +  /// are needed and build sets of dependency check candidates. +  void processMemAccesses(); + +  /// Set of all accesses. 
+  PtrAccessSet Accesses;
+
+  const DataLayout &DL;
+
+  /// The loop being checked.
+  const Loop *TheLoop;
+
+  /// List of accesses that need a further dependence check.
+  MemAccessInfoList CheckDeps;
+
+  /// Set of pointers that are read only.
+  SmallPtrSet<Value*, 16> ReadOnlyPtr;
+
+  /// An alias set tracker to partition the access set by underlying object and
+  /// intrinsic property (such as TBAA metadata).
+  AliasSetTracker AST;
+
+  LoopInfo *LI;
+
+  /// Sets of potentially dependent accesses - members of one set share an
+  /// underlying pointer. The set "CheckDeps" identifies which sets really need
+  /// a dependence check.
+  MemoryDepChecker::DepCandidates &DepCands;
+
+  /// Initial processing of memory accesses determined that we may need
+  /// to add memchecks.  Perform the analysis to determine the necessary checks.
+  ///
+  /// Note that this is different from isDependencyCheckNeeded.  When we retry
+  /// memcheck analysis without dependency checking
+  /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared
+  /// while this remains set if we have potentially dependent accesses.
+  bool IsRTCheckAnalysisNeeded;
+
+  /// The SCEV predicate containing all the SCEV-related assumptions.
+  PredicatedScalarEvolution &PSE;
+};
+
+} // end anonymous namespace
+
+/// Check whether a pointer can participate in a runtime bounds check.
+/// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr
+/// by adding run-time checks (overflow checks) if necessary.
+static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
+                                const ValueToValueMap &Strides, Value *Ptr,
+                                Loop *L, bool Assume) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+
+  // The bounds for a loop-invariant pointer are trivial.
+  if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+    return true;
+
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+
+  if (!AR && Assume)
+    AR = PSE.getAsAddRec(Ptr);
+
+  if (!AR)
+    return false;
+
+  return AR->isAffine();
+}
+
+/// Check whether a pointer address cannot wrap.
+static bool isNoWrap(PredicatedScalarEvolution &PSE,
+                     const ValueToValueMap &Strides, Value *Ptr, Loop *L) {
+  const SCEV *PtrScev = PSE.getSCEV(Ptr);
+  if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+    return true;
+
+  int64_t Stride = getPtrStride(PSE, Ptr, L, Strides);
+  if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
+    return true;
+
+  return false;
+}
+
+bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck,
+                                          MemAccessInfo Access,
+                                          const ValueToValueMap &StridesMap,
+                                          DenseMap<Value *, unsigned> &DepSetId,
+                                          Loop *TheLoop, unsigned &RunningDepId,
+                                          unsigned ASId, bool ShouldCheckWrap,
+                                          bool Assume) {
+  Value *Ptr = Access.getPointer();
+
+  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
+    return false;
+
+  // When we run after a failing dependency check we have to make sure
+  // we don't have wrapping pointers.
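+  // (Illustrative note: if the address computation could wrap, an access that
+  // appears to move monotonically through memory may fall back below an
+  // earlier address, so bounds taken from the first and last iteration would
+  // not cover it and the emitted no-alias check would be unsound.)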
+  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
+    auto *Expr = PSE.getSCEV(Ptr);
+    if (!Assume || !isa<SCEVAddRecExpr>(Expr))
+      return false;
+    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+  }
+
+  // The id of the dependence set.
+  unsigned DepId;
+
+  if (isDependencyCheckNeeded()) {
+    Value *Leader = DepCands.getLeaderValue(Access).getPointer();
+    unsigned &LeaderId = DepSetId[Leader];
+    if (!LeaderId)
+      LeaderId = RunningDepId++;
+    DepId = LeaderId;
+  } else
+    // Each access has its own dependence set.
+    DepId = RunningDepId++;
+
+  bool IsWrite = Access.getInt();
+  RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE);
+  LLVM_DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n');
+
+  return true;
+}
+
+bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
+                                     ScalarEvolution *SE, Loop *TheLoop,
+                                     const ValueToValueMap &StridesMap,
+                                     bool ShouldCheckWrap) {
+  // Find pointers with computable bounds. We are going to use this information
+  // to place a runtime bound check.
+  bool CanDoRT = true;
+
+  bool NeedRTCheck = false;
+  if (!IsRTCheckAnalysisNeeded) return true;
+
+  bool IsDepCheckNeeded = isDependencyCheckNeeded();
+
+  // We assign a consecutive id to accesses from different alias sets.
+  // Accesses from different groups don't need to be checked.
+  unsigned ASId = 1;
+  for (auto &AS : AST) {
+    int NumReadPtrChecks = 0;
+    int NumWritePtrChecks = 0;
+    bool CanDoAliasSetRT = true;
+
+    // We assign a consecutive id to accesses from different dependence sets.
+    // Accesses within the same set don't need a runtime check.
+    unsigned RunningDepId = 1;
+    DenseMap<Value *, unsigned> DepSetId;
+
+    SmallVector<MemAccessInfo, 4> Retries;
+
+    for (auto A : AS) {
+      Value *Ptr = A.getValue();
+      bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true));
+      MemAccessInfo Access(Ptr, IsWrite);
+
+      if (IsWrite)
+        ++NumWritePtrChecks;
+      else
+        ++NumReadPtrChecks;
+
+      if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop,
+                                RunningDepId, ASId, ShouldCheckWrap, false)) {
+        LLVM_DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" << *Ptr << '\n');
+        Retries.push_back(Access);
+        CanDoAliasSetRT = false;
+      }
+    }
+
+    // If we have at least two writes or one write and a read then we need to
+    // check them.  But there is no need for checks if there is only one
+    // dependence set for this alias set.
+    //
+    // Note that this function computes CanDoRT and NeedRTCheck independently.
+    // For example CanDoRT=false, NeedRTCheck=false means that we have a pointer
+    // for which we couldn't find the bounds but we don't actually need to emit
+    // any checks, so it does not matter.
+    bool NeedsAliasSetRTCheck = false;
+    if (!(IsDepCheckNeeded && CanDoAliasSetRT && RunningDepId == 2))
+      NeedsAliasSetRTCheck = (NumWritePtrChecks >= 2 ||
+                             (NumReadPtrChecks >= 1 && NumWritePtrChecks >= 1));
+
+    // We need to perform run-time alias checks, but some pointers had bounds
+    // that couldn't be checked.
+    if (NeedsAliasSetRTCheck && !CanDoAliasSetRT) {
+      // Reset the CanDoAliasSetRT flag and retry all accesses that have failed.
+      // We know that we need these checks, so we can now be more aggressive +      // and add further checks if required (overflow checks). +      CanDoAliasSetRT = true; +      for (auto Access : Retries) +        if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, +                                  TheLoop, RunningDepId, ASId, +                                  ShouldCheckWrap, /*Assume=*/true)) { +          CanDoAliasSetRT = false; +          break; +        } +    } + +    CanDoRT &= CanDoAliasSetRT; +    NeedRTCheck |= NeedsAliasSetRTCheck; +    ++ASId; +  } + +  // If the pointers that we would use for the bounds comparison have different +  // address spaces, assume the values aren't directly comparable, so we can't +  // use them for the runtime check. We also have to assume they could +  // overlap. In the future there should be metadata for whether address spaces +  // are disjoint. +  unsigned NumPointers = RtCheck.Pointers.size(); +  for (unsigned i = 0; i < NumPointers; ++i) { +    for (unsigned j = i + 1; j < NumPointers; ++j) { +      // Only need to check pointers between two different dependency sets. +      if (RtCheck.Pointers[i].DependencySetId == +          RtCheck.Pointers[j].DependencySetId) +       continue; +      // Only need to check pointers in the same alias set. +      if (RtCheck.Pointers[i].AliasSetId != RtCheck.Pointers[j].AliasSetId) +        continue; + +      Value *PtrI = RtCheck.Pointers[i].PointerValue; +      Value *PtrJ = RtCheck.Pointers[j].PointerValue; + +      unsigned ASi = PtrI->getType()->getPointerAddressSpace(); +      unsigned ASj = PtrJ->getType()->getPointerAddressSpace(); +      if (ASi != ASj) { +        LLVM_DEBUG( +            dbgs() << "LAA: Runtime check would require comparison between" +                      " different address spaces\n"); +        return false; +      } +    } +  } + +  if (NeedRTCheck && CanDoRT) +    RtCheck.generateChecks(DepCands, IsDepCheckNeeded); + +  LLVM_DEBUG(dbgs() << "LAA: We need to do " << RtCheck.getNumberOfChecks() +                    << " pointer comparisons.\n"); + +  RtCheck.Need = NeedRTCheck; + +  bool CanDoRTIfNeeded = !NeedRTCheck || CanDoRT; +  if (!CanDoRTIfNeeded) +    RtCheck.reset(); +  return CanDoRTIfNeeded; +} + +void AccessAnalysis::processMemAccesses() { +  // We process the set twice: first we process read-write pointers, last we +  // process read-only pointers. This allows us to skip dependence tests for +  // read-only pointers. + +  LLVM_DEBUG(dbgs() << "LAA: Processing memory accesses...\n"); +  LLVM_DEBUG(dbgs() << "  AST: "; AST.dump()); +  LLVM_DEBUG(dbgs() << "LAA:   Accesses(" << Accesses.size() << "):\n"); +  LLVM_DEBUG({ +    for (auto A : Accesses) +      dbgs() << "\t" << *A.getPointer() << " (" << +                (A.getInt() ? "write" : (ReadOnlyPtr.count(A.getPointer()) ? +                                         "read-only" : "read")) << ")\n"; +  }); + +  // The AliasSetTracker has nicely partitioned our pointers by metadata +  // compatibility and potential for underlying-object overlap. As a result, we +  // only need to check for potential pointer dependencies within each alias +  // set. +  for (auto &AS : AST) { +    // Note that both the alias-set tracker and the alias sets themselves used +    // linked lists internally and so the iteration order here is deterministic +    // (matching the original instruction order within each set). + +    bool SetHasWrite = false; + +    // Map of pointers to last access encountered. 
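+    // (The map is keyed by the underlying object; when a later access shares
+    // an underlying object with the access recorded here, the two are unioned
+    // into the same dependence-set candidate below.)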
+    typedef DenseMap<Value*, MemAccessInfo> UnderlyingObjToAccessMap;
+    UnderlyingObjToAccessMap ObjToLastAccess;
+
+    // Set of accesses to check after all writes have been processed.
+    PtrAccessSet DeferredAccesses;
+
+    // Iterate over each alias set twice, once to process read/write pointers,
+    // and then to process read-only pointers.
+    for (int SetIteration = 0; SetIteration < 2; ++SetIteration) {
+      bool UseDeferred = SetIteration > 0;
+      PtrAccessSet &S = UseDeferred ? DeferredAccesses : Accesses;
+
+      for (auto AV : AS) {
+        Value *Ptr = AV.getValue();
+
+        // For a single memory access in AliasSetTracker, Accesses may contain
+        // both a read and a write, and both need to be handled for CheckDeps.
+        for (auto AC : S) {
+          if (AC.getPointer() != Ptr)
+            continue;
+
+          bool IsWrite = AC.getInt();
+
+          // If we're using the deferred access set, then it contains only
+          // reads.
+          bool IsReadOnlyPtr = ReadOnlyPtr.count(Ptr) && !IsWrite;
+          if (UseDeferred && !IsReadOnlyPtr)
+            continue;
+          // Otherwise, the pointer must be in the PtrAccessSet, either as a
+          // read or a write.
+          assert(((IsReadOnlyPtr && UseDeferred) || IsWrite ||
+                  S.count(MemAccessInfo(Ptr, false))) &&
+                 "Alias-set pointer not in the access set?");
+
+          MemAccessInfo Access(Ptr, IsWrite);
+          DepCands.insert(Access);
+
+          // Memorize read-only pointers for later processing and skip them in
+          // the first round (they need to be checked after we have seen all
+          // write pointers). Note: we also mark pointers that are not
+          // consecutive as "read-only" pointers (so that we check
+          // "a[b[i]] +="). Hence, we need the second check for "!IsWrite".
+          if (!UseDeferred && IsReadOnlyPtr) {
+            DeferredAccesses.insert(Access);
+            continue;
+          }
+
+          // If this is a write, check other reads and writes for conflicts.
+          // If this is a read, only check other writes for conflicts (but only
+          // if there is no other write to the ptr - this is an optimization to
+          // catch "a[i] = a[i] + " without having to do a dependence check).
+          if ((IsWrite || IsReadOnlyPtr) && SetHasWrite) {
+            CheckDeps.push_back(Access);
+            IsRTCheckAnalysisNeeded = true;
+          }
+
+          if (IsWrite)
+            SetHasWrite = true;
+
+          // Create sets of pointers connected by a shared alias set and
+          // underlying object.
+          typedef SmallVector<Value *, 16> ValueVector;
+          ValueVector TempObjects;
+
+          GetUnderlyingObjects(Ptr, TempObjects, DL, LI);
+          LLVM_DEBUG(dbgs()
+                     << "Underlying objects for pointer " << *Ptr << "\n");
+          for (Value *UnderlyingObj : TempObjects) {
+            // A null pointer never aliases; don't join sets for pointers that
+            // have "null" in their UnderlyingObjects list.
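+            // (Illustrative: a null underlying object can show up here when
+            // GetUnderlyingObjects walks through a phi or select that has a
+            // literal null incoming value.)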
+            if (isa<ConstantPointerNull>(UnderlyingObj) && +                !NullPointerIsDefined( +                    TheLoop->getHeader()->getParent(), +                    UnderlyingObj->getType()->getPointerAddressSpace())) +              continue; + +            UnderlyingObjToAccessMap::iterator Prev = +                ObjToLastAccess.find(UnderlyingObj); +            if (Prev != ObjToLastAccess.end()) +              DepCands.unionSets(Access, Prev->second); + +            ObjToLastAccess[UnderlyingObj] = Access; +            LLVM_DEBUG(dbgs() << "  " << *UnderlyingObj << "\n"); +          } +        } +      } +    } +  } +} + +static bool isInBoundsGep(Value *Ptr) { +  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr)) +    return GEP->isInBounds(); +  return false; +} + +/// Return true if an AddRec pointer \p Ptr is unsigned non-wrapping, +/// i.e. monotonically increasing/decreasing. +static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, +                           PredicatedScalarEvolution &PSE, const Loop *L) { +  // FIXME: This should probably only return true for NUW. +  if (AR->getNoWrapFlags(SCEV::NoWrapMask)) +    return true; + +  // Scalar evolution does not propagate the non-wrapping flags to values that +  // are derived from a non-wrapping induction variable because non-wrapping +  // could be flow-sensitive. +  // +  // Look through the potentially overflowing instruction to try to prove +  // non-wrapping for the *specific* value of Ptr. + +  // The arithmetic implied by an inbounds GEP can't overflow. +  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr); +  if (!GEP || !GEP->isInBounds()) +    return false; + +  // Make sure there is only one non-const index and analyze that. +  Value *NonConstIndex = nullptr; +  for (Value *Index : make_range(GEP->idx_begin(), GEP->idx_end())) +    if (!isa<ConstantInt>(Index)) { +      if (NonConstIndex) +        return false; +      NonConstIndex = Index; +    } +  if (!NonConstIndex) +    // The recurrence is on the pointer, ignore for now. +    return false; + +  // The index in GEP is signed.  It is non-wrapping if it's derived from a NSW +  // AddRec using a NSW operation. +  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(NonConstIndex)) +    if (OBO->hasNoSignedWrap() && +        // Assume constant for other the operand so that the AddRec can be +        // easily found. +        isa<ConstantInt>(OBO->getOperand(1))) { +      auto *OpScev = PSE.getSCEV(OBO->getOperand(0)); + +      if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev)) +        return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW); +    } + +  return false; +} + +/// Check whether the access through \p Ptr has a constant stride. +int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Value *Ptr, +                           const Loop *Lp, const ValueToValueMap &StridesMap, +                           bool Assume, bool ShouldCheckWrap) { +  Type *Ty = Ptr->getType(); +  assert(Ty->isPointerTy() && "Unexpected non-ptr"); + +  // Make sure that the pointer does not point to aggregate types. 
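+  // (Illustrative reason: the stride computed below is expressed in units of
+  // the pointee type's alloc size, so a single scalar element type is
+  // assumed.)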
+  auto *PtrTy = cast<PointerType>(Ty);
+  if (PtrTy->getElementType()->isAggregateType()) {
+    LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type "
+                      << *Ptr << "\n");
+    return 0;
+  }
+
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr);
+
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
+  if (Assume && !AR)
+    AR = PSE.getAsAddRec(Ptr);
+
+  if (!AR) {
+    LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr
+                      << " SCEV: " << *PtrScev << "\n");
+    return 0;
+  }
+
+  // The access function must stride over the innermost loop.
+  if (Lp != AR->getLoop()) {
+    LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop "
+                      << *Ptr << " SCEV: " << *AR << "\n");
+    return 0;
+  }
+
+  // The address calculation must not wrap. Otherwise, a dependence could be
+  // inverted.
+  // An inbounds getelementptr that is an AddRec with a unit stride
+  // cannot wrap by definition. The unit stride requirement is checked later.
+  // A getelementptr without an inbounds attribute and with a unit stride would
+  // have to access the pointer value "0", which is undefined behavior in
+  // address space 0, so we can also vectorize this case.
+  bool IsInBoundsGEP = isInBoundsGep(Ptr);
+  bool IsNoWrapAddRec = !ShouldCheckWrap ||
+    PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) ||
+    isNoWrapAddRec(Ptr, AR, PSE, Lp);
+  if (!IsNoWrapAddRec && !IsInBoundsGEP &&
+      NullPointerIsDefined(Lp->getHeader()->getParent(),
+                           PtrTy->getAddressSpace())) {
+    if (Assume) {
+      PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW);
+      IsNoWrapAddRec = true;
+      LLVM_DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n"
+                        << "LAA:   Pointer: " << *Ptr << "\n"
+                        << "LAA:   SCEV: " << *AR << "\n"
+                        << "LAA:   Added an overflow assumption\n");
+    } else {
+      LLVM_DEBUG(
+          dbgs() << "LAA: Bad stride - Pointer may wrap in the address space "
+                 << *Ptr << " SCEV: " << *AR << "\n");
+      return 0;
+    }
+  }
+
+  // Check the step is constant.
+  const SCEV *Step = AR->getStepRecurrence(*PSE.getSE());
+
+  // Calculate the pointer stride and check if it is constant.
+  const SCEVConstant *C = dyn_cast<SCEVConstant>(Step);
+  if (!C) {
+    LLVM_DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr
+                      << " SCEV: " << *AR << "\n");
+    return 0;
+  }
+
+  auto &DL = Lp->getHeader()->getModule()->getDataLayout();
+  int64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());
+  const APInt &APStepVal = C->getAPInt();
+
+  // Huge step value - give up.
+  if (APStepVal.getBitWidth() > 64)
+    return 0;
+
+  int64_t StepVal = APStepVal.getSExtValue();
+
+  // Strided access.
+  int64_t Stride = StepVal / Size;
+  int64_t Rem = StepVal % Size;
+  if (Rem)
+    return 0;
+
+  // If the SCEV could wrap but we have an inbounds gep with a unit stride we
+  // know we can't "wrap around the address space". In case of address space
+  // zero we know that this won't happen without triggering undefined behavior.
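+  // (Illustrative, under a typical data layout: an i32 access whose SCEV step
+  // is +8 bytes per iteration gives Size = 4, StepVal = 8, Rem = 0 and
+  // Stride = 2; such a non-unit stride is what the check below is concerned
+  // with.)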
+  if (!IsNoWrapAddRec && Stride != 1 && Stride != -1 && +      (IsInBoundsGEP || !NullPointerIsDefined(Lp->getHeader()->getParent(), +                                              PtrTy->getAddressSpace()))) { +    if (Assume) { +      // We can avoid this case by adding a run-time check. +      LLVM_DEBUG(dbgs() << "LAA: Non unit strided pointer which is not either " +                        << "inbouds or in address space 0 may wrap:\n" +                        << "LAA:   Pointer: " << *Ptr << "\n" +                        << "LAA:   SCEV: " << *AR << "\n" +                        << "LAA:   Added an overflow assumption\n"); +      PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); +    } else +      return 0; +  } + +  return Stride; +} + +bool llvm::sortPtrAccesses(ArrayRef<Value *> VL, const DataLayout &DL, +                           ScalarEvolution &SE, +                           SmallVectorImpl<unsigned> &SortedIndices) { +  assert(llvm::all_of( +             VL, [](const Value *V) { return V->getType()->isPointerTy(); }) && +         "Expected list of pointer operands."); +  SmallVector<std::pair<int64_t, Value *>, 4> OffValPairs; +  OffValPairs.reserve(VL.size()); + +  // Walk over the pointers, and map each of them to an offset relative to +  // first pointer in the array. +  Value *Ptr0 = VL[0]; +  const SCEV *Scev0 = SE.getSCEV(Ptr0); +  Value *Obj0 = GetUnderlyingObject(Ptr0, DL); + +  llvm::SmallSet<int64_t, 4> Offsets; +  for (auto *Ptr : VL) { +    // TODO: Outline this code as a special, more time consuming, version of +    // computeConstantDifference() function. +    if (Ptr->getType()->getPointerAddressSpace() != +        Ptr0->getType()->getPointerAddressSpace()) +      return false; +    // If a pointer refers to a different underlying object, bail - the +    // pointers are by definition incomparable. +    Value *CurrObj = GetUnderlyingObject(Ptr, DL); +    if (CurrObj != Obj0) +      return false; + +    const SCEV *Scev = SE.getSCEV(Ptr); +    const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Scev, Scev0)); +    // The pointers may not have a constant offset from each other, or SCEV +    // may just not be smart enough to figure out they do. Regardless, +    // there's nothing we can do. +    if (!Diff) +      return false; + +    // Check if the pointer with the same offset is found. +    int64_t Offset = Diff->getAPInt().getSExtValue(); +    if (!Offsets.insert(Offset).second) +      return false; +    OffValPairs.emplace_back(Offset, Ptr); +  } +  SortedIndices.clear(); +  SortedIndices.resize(VL.size()); +  std::iota(SortedIndices.begin(), SortedIndices.end(), 0); + +  // Sort the memory accesses and keep the order of their uses in UseOrder. +  std::stable_sort(SortedIndices.begin(), SortedIndices.end(), +                   [&OffValPairs](unsigned Left, unsigned Right) { +                     return OffValPairs[Left].first < OffValPairs[Right].first; +                   }); + +  // Check if the order is consecutive already. +  if (llvm::all_of(SortedIndices, [&SortedIndices](const unsigned I) { +        return I == SortedIndices[I]; +      })) +    SortedIndices.clear(); + +  return true; +} + +/// Take the address space operand from the Load/Store instruction. +/// Returns -1 if this is not a valid Load/Store instruction. 
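+/// (The return type is unsigned, so the -1 sentinel is really ~0U; the caller
+/// below rejects that case through the accompanying null-pointer check before
+/// the value is used as a real address space.)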
+static unsigned getAddressSpaceOperand(Value *I) { +  if (LoadInst *L = dyn_cast<LoadInst>(I)) +    return L->getPointerAddressSpace(); +  if (StoreInst *S = dyn_cast<StoreInst>(I)) +    return S->getPointerAddressSpace(); +  return -1; +} + +/// Returns true if the memory operations \p A and \p B are consecutive. +bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL, +                               ScalarEvolution &SE, bool CheckType) { +  Value *PtrA = getLoadStorePointerOperand(A); +  Value *PtrB = getLoadStorePointerOperand(B); +  unsigned ASA = getAddressSpaceOperand(A); +  unsigned ASB = getAddressSpaceOperand(B); + +  // Check that the address spaces match and that the pointers are valid. +  if (!PtrA || !PtrB || (ASA != ASB)) +    return false; + +  // Make sure that A and B are different pointers. +  if (PtrA == PtrB) +    return false; + +  // Make sure that A and B have the same type if required. +  if (CheckType && PtrA->getType() != PtrB->getType()) +    return false; + +  unsigned IdxWidth = DL.getIndexSizeInBits(ASA); +  Type *Ty = cast<PointerType>(PtrA->getType())->getElementType(); +  APInt Size(IdxWidth, DL.getTypeStoreSize(Ty)); + +  APInt OffsetA(IdxWidth, 0), OffsetB(IdxWidth, 0); +  PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); +  PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); + +  //  OffsetDelta = OffsetB - OffsetA; +  const SCEV *OffsetSCEVA = SE.getConstant(OffsetA); +  const SCEV *OffsetSCEVB = SE.getConstant(OffsetB); +  const SCEV *OffsetDeltaSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA); +  const SCEVConstant *OffsetDeltaC = dyn_cast<SCEVConstant>(OffsetDeltaSCEV); +  const APInt &OffsetDelta = OffsetDeltaC->getAPInt(); +  // Check if they are based on the same pointer. That makes the offsets +  // sufficient. +  if (PtrA == PtrB) +    return OffsetDelta == Size; + +  // Compute the necessary base pointer delta to have the necessary final delta +  // equal to the size. +  // BaseDelta = Size - OffsetDelta; +  const SCEV *SizeSCEV = SE.getConstant(Size); +  const SCEV *BaseDelta = SE.getMinusSCEV(SizeSCEV, OffsetDeltaSCEV); + +  // Otherwise compute the distance with SCEV between the base pointers. 
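+  // (Illustrative: for loads of p[i] and p[i + 1] through two distinct GEPs
+  // there are no constant offsets to strip, BaseDelta equals the element
+  // size, and SCEV folds PtrSCEVA + BaseDelta to the same expression as
+  // PtrSCEVB, which proves the accesses consecutive.)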
+  const SCEV *PtrSCEVA = SE.getSCEV(PtrA); +  const SCEV *PtrSCEVB = SE.getSCEV(PtrB); +  const SCEV *X = SE.getAddExpr(PtrSCEVA, BaseDelta); +  return X == PtrSCEVB; +} + +bool MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) { +  switch (Type) { +  case NoDep: +  case Forward: +  case BackwardVectorizable: +    return true; + +  case Unknown: +  case ForwardButPreventsForwarding: +  case Backward: +  case BackwardVectorizableButPreventsForwarding: +    return false; +  } +  llvm_unreachable("unexpected DepType!"); +} + +bool MemoryDepChecker::Dependence::isBackward() const { +  switch (Type) { +  case NoDep: +  case Forward: +  case ForwardButPreventsForwarding: +  case Unknown: +    return false; + +  case BackwardVectorizable: +  case Backward: +  case BackwardVectorizableButPreventsForwarding: +    return true; +  } +  llvm_unreachable("unexpected DepType!"); +} + +bool MemoryDepChecker::Dependence::isPossiblyBackward() const { +  return isBackward() || Type == Unknown; +} + +bool MemoryDepChecker::Dependence::isForward() const { +  switch (Type) { +  case Forward: +  case ForwardButPreventsForwarding: +    return true; + +  case NoDep: +  case Unknown: +  case BackwardVectorizable: +  case Backward: +  case BackwardVectorizableButPreventsForwarding: +    return false; +  } +  llvm_unreachable("unexpected DepType!"); +} + +bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance, +                                                    uint64_t TypeByteSize) { +  // If loads occur at a distance that is not a multiple of a feasible vector +  // factor store-load forwarding does not take place. +  // Positive dependences might cause troubles because vectorizing them might +  // prevent store-load forwarding making vectorized code run a lot slower. +  //   a[i] = a[i-3] ^ a[i-8]; +  //   The stores to a[i:i+1] don't align with the stores to a[i-3:i-2] and +  //   hence on your typical architecture store-load forwarding does not take +  //   place. Vectorizing in such cases does not make sense. +  // Store-load forwarding distance. + +  // After this many iterations store-to-load forwarding conflicts should not +  // cause any slowdowns. +  const uint64_t NumItersForStoreLoadThroughMemory = 8 * TypeByteSize; +  // Maximum vector factor. +  uint64_t MaxVFWithoutSLForwardIssues = std::min( +      VectorizerParams::MaxVectorWidth * TypeByteSize, MaxSafeDepDistBytes); + +  // Compute the smallest VF at which the store and load would be misaligned. +  for (uint64_t VF = 2 * TypeByteSize; VF <= MaxVFWithoutSLForwardIssues; +       VF *= 2) { +    // If the number of vector iteration between the store and the load are +    // small we could incur conflicts. 
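+    // (Worked example, illustrative: with TypeByteSize = 4 and Distance = 12,
+    // as in the a[i] = a[i-3] case above, the first VF tried is 8 bytes;
+    // Distance % 8 = 4 and Distance / 8 = 1 < 32, so the maximum VF is
+    // clamped to 4 bytes, which is below 2 * TypeByteSize and a possible
+    // store-load forwarding conflict is reported.)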
+    if (Distance % VF && Distance / VF < NumItersForStoreLoadThroughMemory) {
+      MaxVFWithoutSLForwardIssues = (VF >>= 1);
+      break;
+    }
+  }
+
+  if (MaxVFWithoutSLForwardIssues < 2 * TypeByteSize) {
+    LLVM_DEBUG(
+        dbgs() << "LAA: Distance " << Distance
+               << " that could cause a store-load forwarding conflict\n");
+    return true;
+  }
+
+  if (MaxVFWithoutSLForwardIssues < MaxSafeDepDistBytes &&
+      MaxVFWithoutSLForwardIssues !=
+          VectorizerParams::MaxVectorWidth * TypeByteSize)
+    MaxSafeDepDistBytes = MaxVFWithoutSLForwardIssues;
+  return false;
+}
+
+/// Given a non-constant (unknown) dependence-distance \p Dist between two
+/// memory accesses that have the same stride whose absolute value is given
+/// in \p Stride, and that have the same type size \p TypeByteSize,
+/// in a loop whose backedge-taken count is \p BackedgeTakenCount, check if it
+/// is possible to prove statically that the dependence distance is larger
+/// than the range that the accesses will travel through the execution of
+/// the loop. If so, return true; false otherwise. This is useful for
+/// example in loops such as the following (PR31098):
+///     for (i = 0; i < D; ++i) {
+///                = out[i];
+///       out[i+D] =
+///     }
+static bool isSafeDependenceDistance(const DataLayout &DL, ScalarEvolution &SE,
+                                     const SCEV &BackedgeTakenCount,
+                                     const SCEV &Dist, uint64_t Stride,
+                                     uint64_t TypeByteSize) {
+
+  // If we can prove that
+  //      (**) |Dist| > BackedgeTakenCount * Step
+  // where Step is the absolute stride of the memory accesses in bytes,
+  // then there is no dependence.
+  //
+  // Rationale:
+  // We basically want to check if the absolute distance (|Dist/Step|)
+  // is >= the loop iteration count (or > BackedgeTakenCount).
+  // This is equivalent to the Strong SIV Test (Practical Dependence Testing,
+  // Section 4.2.1); note that for vectorization it is sufficient to prove
+  // that the dependence distance is >= VF; this is checked elsewhere.
+  // But in some cases we can prune unknown dependence distances early, and
+  // even before selecting the VF, and without a runtime test, by comparing
+  // the distance against the loop iteration count. Since the vectorized code
+  // will be executed only if LoopCount >= VF, proving distance >= LoopCount
+  // also guarantees that distance >= VF.
+  //
+  const uint64_t ByteStride = Stride * TypeByteSize;
+  const SCEV *Step = SE.getConstant(BackedgeTakenCount.getType(), ByteStride);
+  const SCEV *Product = SE.getMulExpr(&BackedgeTakenCount, Step);
+
+  const SCEV *CastedDist = &Dist;
+  const SCEV *CastedProduct = Product;
+  uint64_t DistTypeSize = DL.getTypeAllocSize(Dist.getType());
+  uint64_t ProductTypeSize = DL.getTypeAllocSize(Product->getType());
+
+  // The dependence distance can be positive/negative, so we sign extend Dist;
+  // The multiplication of the absolute stride in bytes and the
+  // backedge-taken count is non-negative, so we zero extend Product.
+  if (DistTypeSize > ProductTypeSize)
+    CastedProduct = SE.getZeroExtendExpr(Product, Dist.getType());
+  else
+    CastedDist = SE.getNoopOrSignExtend(&Dist, Product->getType());
+
+  // Is  Dist - (BackedgeTakenCount * Step) > 0 ?
+  // (If so, then we have proven (**) because |Dist| >= Dist) +  const SCEV *Minus = SE.getMinusSCEV(CastedDist, CastedProduct); +  if (SE.isKnownPositive(Minus)) +    return true; + +  // Second try: Is  -Dist - (BackedgeTakenCount * Step) > 0 ? +  // (If so, then we have proven (**) because |Dist| >= -1*Dist) +  const SCEV *NegDist = SE.getNegativeSCEV(CastedDist); +  Minus = SE.getMinusSCEV(NegDist, CastedProduct); +  if (SE.isKnownPositive(Minus)) +    return true; + +  return false; +} + +/// Check the dependence for two accesses with the same stride \p Stride. +/// \p Distance is the positive distance and \p TypeByteSize is type size in +/// bytes. +/// +/// \returns true if they are independent. +static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride, +                                          uint64_t TypeByteSize) { +  assert(Stride > 1 && "The stride must be greater than 1"); +  assert(TypeByteSize > 0 && "The type size in byte must be non-zero"); +  assert(Distance > 0 && "The distance must be non-zero"); + +  // Skip if the distance is not multiple of type byte size. +  if (Distance % TypeByteSize) +    return false; + +  uint64_t ScaledDist = Distance / TypeByteSize; + +  // No dependence if the scaled distance is not multiple of the stride. +  // E.g. +  //      for (i = 0; i < 1024 ; i += 4) +  //        A[i+2] = A[i] + 1; +  // +  // Two accesses in memory (scaled distance is 2, stride is 4): +  //     | A[0] |      |      |      | A[4] |      |      |      | +  //     |      |      | A[2] |      |      |      | A[6] |      | +  // +  // E.g. +  //      for (i = 0; i < 1024 ; i += 3) +  //        A[i+4] = A[i] + 1; +  // +  // Two accesses in memory (scaled distance is 4, stride is 3): +  //     | A[0] |      |      | A[3] |      |      | A[6] |      |      | +  //     |      |      |      |      | A[4] |      |      | A[7] |      | +  return ScaledDist % Stride; +} + +MemoryDepChecker::Dependence::DepType +MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, +                              const MemAccessInfo &B, unsigned BIdx, +                              const ValueToValueMap &Strides) { +  assert (AIdx < BIdx && "Must pass arguments in program order"); + +  Value *APtr = A.getPointer(); +  Value *BPtr = B.getPointer(); +  bool AIsWrite = A.getInt(); +  bool BIsWrite = B.getInt(); + +  // Two reads are independent. +  if (!AIsWrite && !BIsWrite) +    return Dependence::NoDep; + +  // We cannot check pointers in different address spaces. +  if (APtr->getType()->getPointerAddressSpace() != +      BPtr->getType()->getPointerAddressSpace()) +    return Dependence::Unknown; + +  int64_t StrideAPtr = getPtrStride(PSE, APtr, InnermostLoop, Strides, true); +  int64_t StrideBPtr = getPtrStride(PSE, BPtr, InnermostLoop, Strides, true); + +  const SCEV *Src = PSE.getSCEV(APtr); +  const SCEV *Sink = PSE.getSCEV(BPtr); + +  // If the induction step is negative we have to invert source and sink of the +  // dependence. 
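+  // (After the swap below, the distance Sink - Src is computed as if both
+  // pointers advanced with a positive step, so the sign-based classification
+  // that follows applies unchanged.)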
+  if (StrideAPtr < 0) { +    std::swap(APtr, BPtr); +    std::swap(Src, Sink); +    std::swap(AIsWrite, BIsWrite); +    std::swap(AIdx, BIdx); +    std::swap(StrideAPtr, StrideBPtr); +  } + +  const SCEV *Dist = PSE.getSE()->getMinusSCEV(Sink, Src); + +  LLVM_DEBUG(dbgs() << "LAA: Src Scev: " << *Src << "Sink Scev: " << *Sink +                    << "(Induction step: " << StrideAPtr << ")\n"); +  LLVM_DEBUG(dbgs() << "LAA: Distance for " << *InstMap[AIdx] << " to " +                    << *InstMap[BIdx] << ": " << *Dist << "\n"); + +  // Need accesses with constant stride. We don't want to vectorize +  // "A[B[i]] += ..." and similar code or pointer arithmetic that could wrap in +  // the address space. +  if (!StrideAPtr || !StrideBPtr || StrideAPtr != StrideBPtr){ +    LLVM_DEBUG(dbgs() << "Pointer access with non-constant stride\n"); +    return Dependence::Unknown; +  } + +  Type *ATy = APtr->getType()->getPointerElementType(); +  Type *BTy = BPtr->getType()->getPointerElementType(); +  auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout(); +  uint64_t TypeByteSize = DL.getTypeAllocSize(ATy); +  uint64_t Stride = std::abs(StrideAPtr); +  const SCEVConstant *C = dyn_cast<SCEVConstant>(Dist); +  if (!C) { +    if (TypeByteSize == DL.getTypeAllocSize(BTy) && +        isSafeDependenceDistance(DL, *(PSE.getSE()), +                                 *(PSE.getBackedgeTakenCount()), *Dist, Stride, +                                 TypeByteSize)) +      return Dependence::NoDep; + +    LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n"); +    ShouldRetryWithRuntimeCheck = true; +    return Dependence::Unknown; +  } + +  const APInt &Val = C->getAPInt(); +  int64_t Distance = Val.getSExtValue(); + +  // Attempt to prove strided accesses independent. +  if (std::abs(Distance) > 0 && Stride > 1 && ATy == BTy && +      areStridedAccessesIndependent(std::abs(Distance), Stride, TypeByteSize)) { +    LLVM_DEBUG(dbgs() << "LAA: Strided accesses are independent\n"); +    return Dependence::NoDep; +  } + +  // Negative distances are not plausible dependencies. +  if (Val.isNegative()) { +    bool IsTrueDataDependence = (AIsWrite && !BIsWrite); +    if (IsTrueDataDependence && EnableForwardingConflictDetection && +        (couldPreventStoreLoadForward(Val.abs().getZExtValue(), TypeByteSize) || +         ATy != BTy)) { +      LLVM_DEBUG(dbgs() << "LAA: Forward but may prevent st->ld forwarding\n"); +      return Dependence::ForwardButPreventsForwarding; +    } + +    LLVM_DEBUG(dbgs() << "LAA: Dependence is negative\n"); +    return Dependence::Forward; +  } + +  // Write to the same location with the same size. +  // Could be improved to assert type sizes are the same (i32 == float, etc). +  if (Val == 0) { +    if (ATy == BTy) +      return Dependence::Forward; +    LLVM_DEBUG( +        dbgs() << "LAA: Zero dependence difference but different types\n"); +    return Dependence::Unknown; +  } + +  assert(Val.isStrictlyPositive() && "Expect a positive value"); + +  if (ATy != BTy) { +    LLVM_DEBUG( +        dbgs() +        << "LAA: ReadWrite-Write positive dependency with different types\n"); +    return Dependence::Unknown; +  } + +  // Bail out early if passed-in parameters make vectorization not feasible. +  unsigned ForcedFactor = (VectorizerParams::VectorizationFactor ? +                           VectorizerParams::VectorizationFactor : 1); +  unsigned ForcedUnroll = (VectorizerParams::VectorizationInterleave ? 
+                           VectorizerParams::VectorizationInterleave : 1);
+  // The minimum number of iterations for a vectorized/unrolled version.
+  unsigned MinNumIter = std::max(ForcedFactor * ForcedUnroll, 2U);
+
+  // It's not vectorizable if the distance is smaller than the minimum distance
+  // needed for a vectorized/unrolled version. Vectorizing each iteration
+  // except the last needs TypeByteSize * Stride. Vectorizing the last
+  // iteration needs TypeByteSize (no need to add the last gap distance).
+  //
+  // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
+  //      foo(int *A) {
+  //        int *B = (int *)((char *)A + 14);
+  //        for (i = 0 ; i < 1024 ; i += 2)
+  //          B[i] = A[i] + 1;
+  //      }
+  //
+  // Two accesses in memory (stride is 2):
+  //     | A[0] |      | A[2] |      | A[4] |      | A[6] |      |
+  //                              | B[0] |      | B[2] |      | B[4] |
+  //
+  // Distance needed for vectorizing iterations except the last iteration:
+  // 4 * 2 * (MinNumIter - 1). Distance needed for the last iteration: 4.
+  // So the minimum distance needed is: 4 * 2 * (MinNumIter - 1) + 4.
+  //
+  // If MinNumIter is 2, it is vectorizable as the minimum distance needed is
+  // 12, which is less than the distance of 14.
+  //
+  // If MinNumIter is 4 (say, if a user forces the vectorization factor to be
+  // 4), the minimum distance needed is 28, which is greater than the distance
+  // of 14. It is not safe to do vectorization.
+  uint64_t MinDistanceNeeded =
+      TypeByteSize * Stride * (MinNumIter - 1) + TypeByteSize;
+  if (MinDistanceNeeded > static_cast<uint64_t>(Distance)) {
+    LLVM_DEBUG(dbgs() << "LAA: Failure because of positive distance "
+                      << Distance << '\n');
+    return Dependence::Backward;
+  }
+
+  // Unsafe if the minimum distance needed is greater than max safe distance.
+  if (MinDistanceNeeded > MaxSafeDepDistBytes) {
+    LLVM_DEBUG(dbgs() << "LAA: Failure because it needs at least "
+                      << MinDistanceNeeded << " size in bytes");
+    return Dependence::Backward;
+  }
+
+  // Positive distance bigger than max vectorization factor.
+  // FIXME: Should use max factor instead of max distance in bytes, which could
+  // not handle different types.
+  // E.g. Assume one char is 1 byte in memory and one int is 4 bytes.
+  //      void foo (int *A, char *B) {
+  //        for (unsigned i = 0; i < 1024; i++) {
+  //          A[i+2] = A[i] + 1;
+  //          B[i+2] = B[i] + 1;
+  //        }
+  //      }
+  //
+  // This case is currently unsafe according to the max safe distance. If we
+  // analyze the two accesses on array B, the max safe dependence distance
+  // is 2. Then we analyze the accesses on array A, the minimum distance needed
+  // is 8, which is greater than 2 and forbids vectorization. But actually
+  // both A and B could be vectorized by 2 iterations.
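+  // (Worked numbers, illustrative, for the example above with MinNumIter = 2:
+  // for B (char) the distance is 2 bytes and MinDistanceNeeded = 1*1*1 + 1 =
+  // 2, so MaxSafeDepDistBytes becomes 2; for A (int) the distance is 8 bytes
+  // and MinDistanceNeeded = 4*1*1 + 4 = 8, which exceeds the recorded 2-byte
+  // bound, so the check above rejects the loop even though a factor of 2
+  // would be safe for both arrays.)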
+  MaxSafeDepDistBytes = +      std::min(static_cast<uint64_t>(Distance), MaxSafeDepDistBytes); + +  bool IsTrueDataDependence = (!AIsWrite && BIsWrite); +  if (IsTrueDataDependence && EnableForwardingConflictDetection && +      couldPreventStoreLoadForward(Distance, TypeByteSize)) +    return Dependence::BackwardVectorizableButPreventsForwarding; + +  uint64_t MaxVF = MaxSafeDepDistBytes / (TypeByteSize * Stride); +  LLVM_DEBUG(dbgs() << "LAA: Positive distance " << Val.getSExtValue() +                    << " with max VF = " << MaxVF << '\n'); +  uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8; +  MaxSafeRegisterWidth = std::min(MaxSafeRegisterWidth, MaxVFInBits); +  return Dependence::BackwardVectorizable; +} + +bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, +                                   MemAccessInfoList &CheckDeps, +                                   const ValueToValueMap &Strides) { + +  MaxSafeDepDistBytes = -1; +  SmallPtrSet<MemAccessInfo, 8> Visited; +  for (MemAccessInfo CurAccess : CheckDeps) { +    if (Visited.count(CurAccess)) +      continue; + +    // Get the relevant memory access set. +    EquivalenceClasses<MemAccessInfo>::iterator I = +      AccessSets.findValue(AccessSets.getLeaderValue(CurAccess)); + +    // Check accesses within this set. +    EquivalenceClasses<MemAccessInfo>::member_iterator AI = +        AccessSets.member_begin(I); +    EquivalenceClasses<MemAccessInfo>::member_iterator AE = +        AccessSets.member_end(); + +    // Check every access pair. +    while (AI != AE) { +      Visited.insert(*AI); +      EquivalenceClasses<MemAccessInfo>::member_iterator OI = std::next(AI); +      while (OI != AE) { +        // Check every accessing instruction pair in program order. +        for (std::vector<unsigned>::iterator I1 = Accesses[*AI].begin(), +             I1E = Accesses[*AI].end(); I1 != I1E; ++I1) +          for (std::vector<unsigned>::iterator I2 = Accesses[*OI].begin(), +               I2E = Accesses[*OI].end(); I2 != I2E; ++I2) { +            auto A = std::make_pair(&*AI, *I1); +            auto B = std::make_pair(&*OI, *I2); + +            assert(*I1 != *I2); +            if (*I1 > *I2) +              std::swap(A, B); + +            Dependence::DepType Type = +                isDependent(*A.first, A.second, *B.first, B.second, Strides); +            SafeForVectorization &= Dependence::isSafeForVectorization(Type); + +            // Gather dependences unless we accumulated MaxDependences +            // dependences.  In that case return as soon as we find the first +            // unsafe dependence.  This puts a limit on this quadratic +            // algorithm. 
+            if (RecordDependences) { +              if (Type != Dependence::NoDep) +                Dependences.push_back(Dependence(A.second, B.second, Type)); + +              if (Dependences.size() >= MaxDependences) { +                RecordDependences = false; +                Dependences.clear(); +                LLVM_DEBUG(dbgs() +                           << "Too many dependences, stopped recording\n"); +              } +            } +            if (!RecordDependences && !SafeForVectorization) +              return false; +          } +        ++OI; +      } +      AI++; +    } +  } + +  LLVM_DEBUG(dbgs() << "Total Dependences: " << Dependences.size() << "\n"); +  return SafeForVectorization; +} + +SmallVector<Instruction *, 4> +MemoryDepChecker::getInstructionsForAccess(Value *Ptr, bool isWrite) const { +  MemAccessInfo Access(Ptr, isWrite); +  auto &IndexVector = Accesses.find(Access)->second; + +  SmallVector<Instruction *, 4> Insts; +  transform(IndexVector, +                 std::back_inserter(Insts), +                 [&](unsigned Idx) { return this->InstMap[Idx]; }); +  return Insts; +} + +const char *MemoryDepChecker::Dependence::DepName[] = { +    "NoDep", "Unknown", "Forward", "ForwardButPreventsForwarding", "Backward", +    "BackwardVectorizable", "BackwardVectorizableButPreventsForwarding"}; + +void MemoryDepChecker::Dependence::print( +    raw_ostream &OS, unsigned Depth, +    const SmallVectorImpl<Instruction *> &Instrs) const { +  OS.indent(Depth) << DepName[Type] << ":\n"; +  OS.indent(Depth + 2) << *Instrs[Source] << " -> \n"; +  OS.indent(Depth + 2) << *Instrs[Destination] << "\n"; +} + +bool LoopAccessInfo::canAnalyzeLoop() { +  // We need to have a loop header. +  LLVM_DEBUG(dbgs() << "LAA: Found a loop in " +                    << TheLoop->getHeader()->getParent()->getName() << ": " +                    << TheLoop->getHeader()->getName() << '\n'); + +  // We can only analyze innermost loops. +  if (!TheLoop->empty()) { +    LLVM_DEBUG(dbgs() << "LAA: loop is not the innermost loop\n"); +    recordAnalysis("NotInnerMostLoop") << "loop is not the innermost loop"; +    return false; +  } + +  // We must have a single backedge. +  if (TheLoop->getNumBackEdges() != 1) { +    LLVM_DEBUG( +        dbgs() << "LAA: loop control flow is not understood by analyzer\n"); +    recordAnalysis("CFGNotUnderstood") +        << "loop control flow is not understood by analyzer"; +    return false; +  } + +  // We must have a single exiting block. +  if (!TheLoop->getExitingBlock()) { +    LLVM_DEBUG( +        dbgs() << "LAA: loop control flow is not understood by analyzer\n"); +    recordAnalysis("CFGNotUnderstood") +        << "loop control flow is not understood by analyzer"; +    return false; +  } + +  // We only handle bottom-tested loops, i.e. loop in which the condition is +  // checked at the end of each iteration. With that we can assume that all +  // instructions in the loop are executed the same number of times. +  if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { +    LLVM_DEBUG( +        dbgs() << "LAA: loop control flow is not understood by analyzer\n"); +    recordAnalysis("CFGNotUnderstood") +        << "loop control flow is not understood by analyzer"; +    return false; +  } + +  // ScalarEvolution needs to be able to find the exit count. 
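+  // (E.g. a loop whose exit condition depends on a value loaded inside the
+  // loop body has no computable backedge-taken count and is rejected here.)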
+  const SCEV *ExitCount = PSE->getBackedgeTakenCount(); +  if (ExitCount == PSE->getSE()->getCouldNotCompute()) { +    recordAnalysis("CantComputeNumberOfIterations") +        << "could not determine number of loop iterations"; +    LLVM_DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n"); +    return false; +  } + +  return true; +} + +void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI, +                                 const TargetLibraryInfo *TLI, +                                 DominatorTree *DT) { +  typedef SmallPtrSet<Value*, 16> ValueSet; + +  // Holds the Load and Store instructions. +  SmallVector<LoadInst *, 16> Loads; +  SmallVector<StoreInst *, 16> Stores; + +  // Holds all the different accesses in the loop. +  unsigned NumReads = 0; +  unsigned NumReadWrites = 0; + +  PtrRtChecking->Pointers.clear(); +  PtrRtChecking->Need = false; + +  const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel(); + +  // For each block. +  for (BasicBlock *BB : TheLoop->blocks()) { +    // Scan the BB and collect legal loads and stores. +    for (Instruction &I : *BB) { +      // If this is a load, save it. If this instruction can read from memory +      // but is not a load, then we quit. Notice that we don't handle function +      // calls that read or write. +      if (I.mayReadFromMemory()) { +        // Many math library functions read the rounding mode. We will only +        // vectorize a loop if it contains known function calls that don't set +        // the flag. Therefore, it is safe to ignore this read from memory. +        auto *Call = dyn_cast<CallInst>(&I); +        if (Call && getVectorIntrinsicIDForCall(Call, TLI)) +          continue; + +        // If the function has an explicit vectorized counterpart, we can safely +        // assume that it can be vectorized. +        if (Call && !Call->isNoBuiltin() && Call->getCalledFunction() && +            TLI->isFunctionVectorizable(Call->getCalledFunction()->getName())) +          continue; + +        auto *Ld = dyn_cast<LoadInst>(&I); +        if (!Ld || (!Ld->isSimple() && !IsAnnotatedParallel)) { +          recordAnalysis("NonSimpleLoad", Ld) +              << "read with atomic ordering or volatile read"; +          LLVM_DEBUG(dbgs() << "LAA: Found a non-simple load.\n"); +          CanVecMem = false; +          return; +        } +        NumLoads++; +        Loads.push_back(Ld); +        DepChecker->addAccess(Ld); +        if (EnableMemAccessVersioning) +          collectStridedAccess(Ld); +        continue; +      } + +      // Save 'store' instructions. Abort if other instructions write to memory. +      if (I.mayWriteToMemory()) { +        auto *St = dyn_cast<StoreInst>(&I); +        if (!St) { +          recordAnalysis("CantVectorizeInstruction", St) +              << "instruction cannot be vectorized"; +          CanVecMem = false; +          return; +        } +        if (!St->isSimple() && !IsAnnotatedParallel) { +          recordAnalysis("NonSimpleStore", St) +              << "write with atomic ordering or volatile write"; +          LLVM_DEBUG(dbgs() << "LAA: Found a non-simple store.\n"); +          CanVecMem = false; +          return; +        } +        NumStores++; +        Stores.push_back(St); +        DepChecker->addAccess(St); +        if (EnableMemAccessVersioning) +          collectStridedAccess(St); +      } +    } // Next instr. +  } // Next block. + +  // Now we have two lists that hold the loads and the stores. +  // Next, we find the pointers that they use. 
+ +  // Check if we see any stores. If there are no stores, then we don't +  // care if the pointers are *restrict*. +  if (!Stores.size()) { +    LLVM_DEBUG(dbgs() << "LAA: Found a read-only loop!\n"); +    CanVecMem = true; +    return; +  } + +  MemoryDepChecker::DepCandidates DependentAccesses; +  AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), +                          TheLoop, AA, LI, DependentAccesses, *PSE); + +  // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects +  // multiple times on the same object. If the ptr is accessed twice, once +  // for read and once for write, it will only appear once (on the write +  // list). This is okay, since we are going to check for conflicts between +  // writes and between reads and writes, but not between reads and reads. +  ValueSet Seen; + +  for (StoreInst *ST : Stores) { +    Value *Ptr = ST->getPointerOperand(); +    // Check for store to loop invariant address. +    StoreToLoopInvariantAddress |= isUniform(Ptr); +    // If we did *not* see this pointer before, insert it to  the read-write +    // list. At this phase it is only a 'write' list. +    if (Seen.insert(Ptr).second) { +      ++NumReadWrites; + +      MemoryLocation Loc = MemoryLocation::get(ST); +      // The TBAA metadata could have a control dependency on the predication +      // condition, so we cannot rely on it when determining whether or not we +      // need runtime pointer checks. +      if (blockNeedsPredication(ST->getParent(), TheLoop, DT)) +        Loc.AATags.TBAA = nullptr; + +      Accesses.addStore(Loc); +    } +  } + +  if (IsAnnotatedParallel) { +    LLVM_DEBUG( +        dbgs() << "LAA: A loop annotated parallel, ignore memory dependency " +               << "checks.\n"); +    CanVecMem = true; +    return; +  } + +  for (LoadInst *LD : Loads) { +    Value *Ptr = LD->getPointerOperand(); +    // If we did *not* see this pointer before, insert it to the +    // read list. If we *did* see it before, then it is already in +    // the read-write list. This allows us to vectorize expressions +    // such as A[i] += x;  Because the address of A[i] is a read-write +    // pointer. This only works if the index of A[i] is consecutive. +    // If the address of i is unknown (for example A[B[i]]) then we may +    // read a few words, modify, and write a few words, and some of the +    // words may be written to the same address. +    bool IsReadOnlyPtr = false; +    if (Seen.insert(Ptr).second || +        !getPtrStride(*PSE, Ptr, TheLoop, SymbolicStrides)) { +      ++NumReads; +      IsReadOnlyPtr = true; +    } + +    MemoryLocation Loc = MemoryLocation::get(LD); +    // The TBAA metadata could have a control dependency on the predication +    // condition, so we cannot rely on it when determining whether or not we +    // need runtime pointer checks. +    if (blockNeedsPredication(LD->getParent(), TheLoop, DT)) +      Loc.AATags.TBAA = nullptr; + +    Accesses.addLoad(Loc, IsReadOnlyPtr); +  } + +  // If we write (or read-write) to a single destination and there are no +  // other reads in this loop then is it safe to vectorize. +  if (NumReadWrites == 1 && NumReads == 0) { +    LLVM_DEBUG(dbgs() << "LAA: Found a write-only loop!\n"); +    CanVecMem = true; +    return; +  } + +  // Build dependence sets and check whether we need a runtime pointer bounds +  // check. +  Accesses.buildDependenceSets(); + +  // Find pointers with computable bounds. 
We are going to use this information +  // to place a runtime bound check. +  bool CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(), +                                                  TheLoop, SymbolicStrides); +  if (!CanDoRTIfNeeded) { +    recordAnalysis("CantIdentifyArrayBounds") << "cannot identify array bounds"; +    LLVM_DEBUG(dbgs() << "LAA: We can't vectorize because we can't find " +                      << "the array bounds.\n"); +    CanVecMem = false; +    return; +  } + +  LLVM_DEBUG( +      dbgs() << "LAA: We can perform a memory runtime check if needed.\n"); + +  CanVecMem = true; +  if (Accesses.isDependencyCheckNeeded()) { +    LLVM_DEBUG(dbgs() << "LAA: Checking memory dependencies\n"); +    CanVecMem = DepChecker->areDepsSafe( +        DependentAccesses, Accesses.getDependenciesToCheck(), SymbolicStrides); +    MaxSafeDepDistBytes = DepChecker->getMaxSafeDepDistBytes(); + +    if (!CanVecMem && DepChecker->shouldRetryWithRuntimeCheck()) { +      LLVM_DEBUG(dbgs() << "LAA: Retrying with memory checks\n"); + +      // Clear the dependency checks. We assume they are not needed. +      Accesses.resetDepChecks(*DepChecker); + +      PtrRtChecking->reset(); +      PtrRtChecking->Need = true; + +      auto *SE = PSE->getSE(); +      CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, SE, TheLoop, +                                                 SymbolicStrides, true); + +      // Check that we found the bounds for the pointer. +      if (!CanDoRTIfNeeded) { +        recordAnalysis("CantCheckMemDepsAtRunTime") +            << "cannot check memory dependencies at runtime"; +        LLVM_DEBUG(dbgs() << "LAA: Can't vectorize with memory checks\n"); +        CanVecMem = false; +        return; +      } + +      CanVecMem = true; +    } +  } + +  if (CanVecMem) +    LLVM_DEBUG( +        dbgs() << "LAA: No unsafe dependent memory operations in loop.  We" +               << (PtrRtChecking->Need ? "" : " don't") +               << " need runtime memory checks.\n"); +  else { +    recordAnalysis("UnsafeMemDep") +        << "unsafe dependent memory operations in loop. Use " +           "#pragma loop distribute(enable) to allow loop distribution " +           "to attempt to isolate the offending operations into a separate " +           "loop"; +    LLVM_DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n"); +  } +} + +bool LoopAccessInfo::blockNeedsPredication(BasicBlock *BB, Loop *TheLoop, +                                           DominatorTree *DT)  { +  assert(TheLoop->contains(BB) && "Unknown block used"); + +  // Blocks that do not dominate the latch need predication. +  BasicBlock* Latch = TheLoop->getLoopLatch(); +  return !DT->dominates(BB, Latch); +} + +OptimizationRemarkAnalysis &LoopAccessInfo::recordAnalysis(StringRef RemarkName, +                                                           Instruction *I) { +  assert(!Report && "Multiple reports generated"); + +  Value *CodeRegion = TheLoop->getHeader(); +  DebugLoc DL = TheLoop->getStartLoc(); + +  if (I) { +    CodeRegion = I->getParent(); +    // If there is no debug location attached to the instruction, revert back to +    // using the loop's. 
+    if (I->getDebugLoc()) +      DL = I->getDebugLoc(); +  } + +  Report = make_unique<OptimizationRemarkAnalysis>(DEBUG_TYPE, RemarkName, DL, +                                                   CodeRegion); +  return *Report; +} + +bool LoopAccessInfo::isUniform(Value *V) const { +  auto *SE = PSE->getSE(); +  // Since we rely on SCEV for uniformity, if the type is not SCEVable, it is +  // never considered uniform. +  // TODO: Is this really what we want? Even without FP SCEV, we may want some +  // trivially loop-invariant FP values to be considered uniform. +  if (!SE->isSCEVable(V->getType())) +    return false; +  return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop)); +} + +// FIXME: this function is currently a duplicate of the one in +// LoopVectorize.cpp. +static Instruction *getFirstInst(Instruction *FirstInst, Value *V, +                                 Instruction *Loc) { +  if (FirstInst) +    return FirstInst; +  if (Instruction *I = dyn_cast<Instruction>(V)) +    return I->getParent() == Loc->getParent() ? I : nullptr; +  return nullptr; +} + +namespace { + +/// IR Values for the lower and upper bounds of a pointer evolution.  We +/// need to use value-handles because SCEV expansion can invalidate previously +/// expanded values.  Thus expansion of a pointer can invalidate the bounds for +/// a previous one. +struct PointerBounds { +  TrackingVH<Value> Start; +  TrackingVH<Value> End; +}; + +} // end anonymous namespace + +/// Expand code for the lower and upper bound of the pointer group \p CG +/// in \p TheLoop.  \return the values for the bounds. +static PointerBounds +expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop, +             Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE, +             const RuntimePointerChecking &PtrRtChecking) { +  Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue; +  const SCEV *Sc = SE->getSCEV(Ptr); + +  unsigned AS = Ptr->getType()->getPointerAddressSpace(); +  LLVMContext &Ctx = Loc->getContext(); + +  // Use this type for pointer arithmetic. +  Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS); + +  if (SE->isLoopInvariant(Sc, TheLoop)) { +    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:" +                      << *Ptr << "\n"); +    // Ptr could be in the loop body. If so, expand a new one at the correct +    // location. +    Instruction *Inst = dyn_cast<Instruction>(Ptr); +    Value *NewPtr = (Inst && TheLoop->contains(Inst)) +                        ? Exp.expandCodeFor(Sc, PtrArithTy, Loc) +                        : Ptr; +    // We must return a half-open range, which means incrementing Sc. +    const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy)); +    Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc); +    return {NewPtr, NewPtrPlusOne}; +  } else { +    Value *Start = nullptr, *End = nullptr; +    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); +    Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); +    End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); +    LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High +                      << "\n"); +    return {Start, End}; +  } +} + +/// Turns a collection of checks into a collection of expanded upper and +/// lower bounds for both pointers in the check. 
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds( +    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks, +    Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp, +    const RuntimePointerChecking &PtrRtChecking) { +  SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds; + +  // Here we're relying on the SCEV Expander's cache to only emit code for the +  // same bounds once. +  transform( +      PointerChecks, std::back_inserter(ChecksWithBounds), +      [&](const RuntimePointerChecking::PointerCheck &Check) { +        PointerBounds +          First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking), +          Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking); +        return std::make_pair(First, Second); +      }); + +  return ChecksWithBounds; +} + +std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks( +    Instruction *Loc, +    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks) +    const { +  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); +  auto *SE = PSE->getSE(); +  SCEVExpander Exp(*SE, DL, "induction"); +  auto ExpandedChecks = +      expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, *PtrRtChecking); + +  LLVMContext &Ctx = Loc->getContext(); +  Instruction *FirstInst = nullptr; +  IRBuilder<> ChkBuilder(Loc); +  // Our instructions might fold to a constant. +  Value *MemoryRuntimeCheck = nullptr; + +  for (const auto &Check : ExpandedChecks) { +    const PointerBounds &A = Check.first, &B = Check.second; +    // Check if two pointers (A and B) conflict where conflict is computed as: +    // start(A) <= end(B) && start(B) <= end(A) +    unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); +    unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); + +    assert((AS0 == B.End->getType()->getPointerAddressSpace()) && +           (AS1 == A.End->getType()->getPointerAddressSpace()) && +           "Trying to bounds check pointers with different address spaces"); + +    Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); +    Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); + +    Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); +    Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); +    Value *End0 =   ChkBuilder.CreateBitCast(A.End,   PtrArithTy1, "bc"); +    Value *End1 =   ChkBuilder.CreateBitCast(B.End,   PtrArithTy0, "bc"); + +    // [A|B].Start points to the first accessed byte under base [A|B]. +    // [A|B].End points to the last accessed byte, plus one. 
+    // There is no conflict when the intervals are disjoint:
+    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+    //
+    // bound0 = (B.Start < A.End)
+    // bound1 = (A.Start < B.End)
+    //  IsConflict = bound0 & bound1
+    Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+    FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
+    Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+    FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
+    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+    FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+    if (MemoryRuntimeCheck) {
+      IsConflict =
+          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
+      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+    }
+    MemoryRuntimeCheck = IsConflict;
+  }
+
+  if (!MemoryRuntimeCheck)
+    return std::make_pair(nullptr, nullptr);
+
+  // We have to do this trickery because the IRBuilder might fold the check to a
+  // constant expression in which case there is no Instruction anchored in
+  // the block.
+  Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
+                                                 ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(Check, "memcheck.conflict");
+  FirstInst = getFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
+}
+
+std::pair<Instruction *, Instruction *>
+LoopAccessInfo::addRuntimeChecks(Instruction *Loc) const {
+  if (!PtrRtChecking->Need)
+    return std::make_pair(nullptr, nullptr);
+
+  return addRuntimeChecks(Loc, PtrRtChecking->getChecks());
+}
+
+void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
+  Value *Ptr = nullptr;
+  if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
+    Ptr = LI->getPointerOperand();
+  else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess))
+    Ptr = SI->getPointerOperand();
+  else
+    return;
+
+  Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+  if (!Stride)
+    return;
+
+  LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
+                       "versioning:");
+  LLVM_DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+
+  // Avoid adding the "Stride == 1" predicate when we know that
+  // Stride >= Trip-Count. Such a predicate will effectively optimize a single
+  // or zero iteration loop, as Trip-Count <= Stride == 1.
+  //
+  // TODO: We are currently not making a very informed decision on when it is
+  // beneficial to apply stride versioning. It might make more sense for the
+  // users of this analysis (such as the vectorizer) to trigger it, based on
+  // their specific cost considerations; for example, in cases where stride
+  // versioning does not help resolving memory accesses/dependences, the
+  // vectorizer should evaluate the cost of the runtime test, and the benefit
+  // of various possible stride specializations, considering the alternatives
+  // of using gather/scatters (if available).
+
+  const SCEV *StrideExpr = PSE->getSCEV(Stride);
+  const SCEV *BETakenCount = PSE->getBackedgeTakenCount();
+
+  // Match the types so we can compare the stride and the BETakenCount.
+  // The Stride can be positive/negative, so we sign extend Stride;
+  // The backedge-taken count is non-negative, so we zero extend BETakenCount.
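// For illustration, the comparison below written over plain integers: since
// TripCount == BackedgeTakenCount + 1, "Stride >= TripCount" is the same as
// "Stride - BackedgeTakenCount > 0", and in that case a "Stride == 1"
// predicate would only describe a loop that runs at most once. This is a
// standalone sketch assuming a non-negative stride; names are illustrative.
static bool strideVersioningIsPointless(uint64_t Stride,
                                        uint64_t BackedgeTakenCount) {
  return Stride > BackedgeTakenCount; // i.e. Stride - BackedgeTakenCount > 0
}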
+  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); +  uint64_t StrideTypeSize = DL.getTypeAllocSize(StrideExpr->getType()); +  uint64_t BETypeSize = DL.getTypeAllocSize(BETakenCount->getType()); +  const SCEV *CastedStride = StrideExpr; +  const SCEV *CastedBECount = BETakenCount; +  ScalarEvolution *SE = PSE->getSE(); +  if (BETypeSize >= StrideTypeSize) +    CastedStride = SE->getNoopOrSignExtend(StrideExpr, BETakenCount->getType()); +  else +    CastedBECount = SE->getZeroExtendExpr(BETakenCount, StrideExpr->getType()); +  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount); +  // Since TripCount == BackEdgeTakenCount + 1, checking: +  // "Stride >= TripCount" is equivalent to checking: +  // Stride - BETakenCount > 0 +  if (SE->isKnownPositive(StrideMinusBETaken)) { +    LLVM_DEBUG( +        dbgs() << "LAA: Stride>=TripCount; No point in versioning as the " +                  "Stride==1 predicate will imply that the loop executes " +                  "at most once.\n"); +    return; +  } +  LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version."); + +  SymbolicStrides[Ptr] = Stride; +  StrideSet.insert(Stride); +} + +LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, +                               const TargetLibraryInfo *TLI, AliasAnalysis *AA, +                               DominatorTree *DT, LoopInfo *LI) +    : PSE(llvm::make_unique<PredicatedScalarEvolution>(*SE, *L)), +      PtrRtChecking(llvm::make_unique<RuntimePointerChecking>(SE)), +      DepChecker(llvm::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L), +      NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false), +      StoreToLoopInvariantAddress(false) { +  if (canAnalyzeLoop()) +    analyzeLoop(AA, LI, TLI, DT); +} + +void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { +  if (CanVecMem) { +    OS.indent(Depth) << "Memory dependences are safe"; +    if (MaxSafeDepDistBytes != -1ULL) +      OS << " with a maximum dependence distance of " << MaxSafeDepDistBytes +         << " bytes"; +    if (PtrRtChecking->Need) +      OS << " with run-time checks"; +    OS << "\n"; +  } + +  if (Report) +    OS.indent(Depth) << "Report: " << Report->getMsg() << "\n"; + +  if (auto *Dependences = DepChecker->getDependences()) { +    OS.indent(Depth) << "Dependences:\n"; +    for (auto &Dep : *Dependences) { +      Dep.print(OS, Depth + 2, DepChecker->getMemoryInstructions()); +      OS << "\n"; +    } +  } else +    OS.indent(Depth) << "Too many dependences, not recorded\n"; + +  // List the pair of accesses need run-time checks to prove independence. +  PtrRtChecking->print(OS, Depth); +  OS << "\n"; + +  OS.indent(Depth) << "Store to invariant address was " +                   << (StoreToLoopInvariantAddress ? 
"" : "not ") +                   << "found in loop.\n"; + +  OS.indent(Depth) << "SCEV assumptions:\n"; +  PSE->getUnionPredicate().print(OS, Depth); + +  OS << "\n"; + +  OS.indent(Depth) << "Expressions re-written:\n"; +  PSE->print(OS, Depth); +} + +const LoopAccessInfo &LoopAccessLegacyAnalysis::getInfo(Loop *L) { +  auto &LAI = LoopAccessInfoMap[L]; + +  if (!LAI) +    LAI = llvm::make_unique<LoopAccessInfo>(L, SE, TLI, AA, DT, LI); + +  return *LAI.get(); +} + +void LoopAccessLegacyAnalysis::print(raw_ostream &OS, const Module *M) const { +  LoopAccessLegacyAnalysis &LAA = *const_cast<LoopAccessLegacyAnalysis *>(this); + +  for (Loop *TopLevelLoop : *LI) +    for (Loop *L : depth_first(TopLevelLoop)) { +      OS.indent(2) << L->getHeader()->getName() << ":\n"; +      auto &LAI = LAA.getInfo(L); +      LAI.print(OS, 4); +    } +} + +bool LoopAccessLegacyAnalysis::runOnFunction(Function &F) { +  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); +  auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); +  TLI = TLIP ? &TLIP->getTLI() : nullptr; +  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); +  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + +  return false; +} + +void LoopAccessLegacyAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { +    AU.addRequired<ScalarEvolutionWrapperPass>(); +    AU.addRequired<AAResultsWrapperPass>(); +    AU.addRequired<DominatorTreeWrapperPass>(); +    AU.addRequired<LoopInfoWrapperPass>(); + +    AU.setPreservesAll(); +} + +char LoopAccessLegacyAnalysis::ID = 0; +static const char laa_name[] = "Loop Access Analysis"; +#define LAA_NAME "loop-accesses" + +INITIALIZE_PASS_BEGIN(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) + +AnalysisKey LoopAccessAnalysis::Key; + +LoopAccessInfo LoopAccessAnalysis::run(Loop &L, LoopAnalysisManager &AM, +                                       LoopStandardAnalysisResults &AR) { +  return LoopAccessInfo(&L, &AR.SE, &AR.TLI, &AR.AA, &AR.DT, &AR.LI); +} + +namespace llvm { + +  Pass *createLAAPass() { +    return new LoopAccessLegacyAnalysis(); +  } + +} // end namespace llvm diff --git a/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp b/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp new file mode 100644 index 000000000000..074023a7e1e2 --- /dev/null +++ b/contrib/llvm/lib/Analysis/LoopAnalysisManager.cpp @@ -0,0 +1,159 @@ +//===- LoopAnalysisManager.cpp - Loop analysis management -----------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +namespace llvm { +/// Enables memory ssa as a dependency for loop passes in legacy pass manager. +cl::opt<bool> EnableMSSALoopDependency( +    "enable-mssa-loop-dependency", cl::Hidden, cl::init(false), +    cl::desc("Enable MemorySSA dependency for loop pass manager")); + +// Explicit template instantiations and specialization definitions for core +// template typedefs. +template class AllAnalysesOn<Loop>; +template class AnalysisManager<Loop, LoopStandardAnalysisResults &>; +template class InnerAnalysisManagerProxy<LoopAnalysisManager, Function>; +template class OuterAnalysisManagerProxy<FunctionAnalysisManager, Loop, +                                         LoopStandardAnalysisResults &>; + +bool LoopAnalysisManagerFunctionProxy::Result::invalidate( +    Function &F, const PreservedAnalyses &PA, +    FunctionAnalysisManager::Invalidator &Inv) { +  // First compute the sequence of IR units covered by this proxy. We will want +  // to visit this in postorder, but because this is a tree structure we can do +  // this by building a preorder sequence and walking it backwards. We also +  // want siblings in forward program order to match the LoopPassManager so we +  // get the preorder with siblings reversed. +  SmallVector<Loop *, 4> PreOrderLoops = LI->getLoopsInReverseSiblingPreorder(); + +  // If this proxy or the loop info is going to be invalidated, we also need +  // to clear all the keys coming from that analysis. We also completely blow +  // away the loop analyses if any of the standard analyses provided by the +  // loop pass manager go away so that loop analyses can freely use these +  // without worrying about declaring dependencies on them etc. +  // FIXME: It isn't clear if this is the right tradeoff. We could instead make +  // loop analyses declare any dependencies on these and use the more general +  // invalidation logic below to act on that. +  auto PAC = PA.getChecker<LoopAnalysisManagerFunctionProxy>(); +  bool invalidateMemorySSAAnalysis = false; +  if (EnableMSSALoopDependency) +    invalidateMemorySSAAnalysis = Inv.invalidate<MemorySSAAnalysis>(F, PA); +  if (!(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || +      Inv.invalidate<AAManager>(F, PA) || +      Inv.invalidate<AssumptionAnalysis>(F, PA) || +      Inv.invalidate<DominatorTreeAnalysis>(F, PA) || +      Inv.invalidate<LoopAnalysis>(F, PA) || +      Inv.invalidate<ScalarEvolutionAnalysis>(F, PA) || +      invalidateMemorySSAAnalysis) { +    // Note that the LoopInfo may be stale at this point, however the loop +    // objects themselves remain the only viable keys that could be in the +    // analysis manager's cache. So we just walk the keys and forcibly clear +    // those results. Note that the order doesn't matter here as this will just +    // directly destroy the results without calling methods on them. +    for (Loop *L : PreOrderLoops) { +      // NB! `L` may not be in a good enough state to run Loop::getName. 
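// For illustration, a standalone sketch of the ordering trick described at the
// top of this function: collect a preorder with siblings reversed, then walk
// it backwards to obtain a postorder with siblings in forward program order.
// The Node type and names are illustrative only (assumes <vector>).
struct Node { std::vector<Node *> Children; };
static void reverseSiblingPreorder(Node *N, std::vector<Node *> &Out) {
  Out.push_back(N);
  for (auto It = N->Children.rbegin(), E = N->Children.rend(); It != E; ++It)
    reverseSiblingPreorder(*It, Out);
}
// Applied to each top-level node, last sibling first, a nest {A {A1, A2}, B}
// yields [B, A, A2, A1]; reading that vector back to front visits
// A1, A2, A, B -- children before parents, siblings in program order.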
+      InnerAM->clear(*L, "<possibly invalidated loop>"); +    } + +    // We also need to null out the inner AM so that when the object gets +    // destroyed as invalid we don't try to clear the inner AM again. At that +    // point we won't be able to reliably walk the loops for this function and +    // only clear results associated with those loops the way we do here. +    // FIXME: Making InnerAM null at this point isn't very nice. Most analyses +    // try to remain valid during invalidation. Maybe we should add an +    // `IsClean` flag? +    InnerAM = nullptr; + +    // Now return true to indicate this *is* invalid and a fresh proxy result +    // needs to be built. This is especially important given the null InnerAM. +    return true; +  } + +  // Directly check if the relevant set is preserved so we can short circuit +  // invalidating loops. +  bool AreLoopAnalysesPreserved = +      PA.allAnalysesInSetPreserved<AllAnalysesOn<Loop>>(); + +  // Since we have a valid LoopInfo we can actually leave the cached results in +  // the analysis manager associated with the Loop keys, but we need to +  // propagate any necessary invalidation logic into them. We'd like to +  // invalidate things in roughly the same order as they were put into the +  // cache and so we walk the preorder list in reverse to form a valid +  // postorder. +  for (Loop *L : reverse(PreOrderLoops)) { +    Optional<PreservedAnalyses> InnerPA; + +    // Check to see whether the preserved set needs to be adjusted based on +    // function-level analysis invalidation triggering deferred invalidation +    // for this loop. +    if (auto *OuterProxy = +            InnerAM->getCachedResult<FunctionAnalysisManagerLoopProxy>(*L)) +      for (const auto &OuterInvalidationPair : +           OuterProxy->getOuterInvalidations()) { +        AnalysisKey *OuterAnalysisID = OuterInvalidationPair.first; +        const auto &InnerAnalysisIDs = OuterInvalidationPair.second; +        if (Inv.invalidate(OuterAnalysisID, F, PA)) { +          if (!InnerPA) +            InnerPA = PA; +          for (AnalysisKey *InnerAnalysisID : InnerAnalysisIDs) +            InnerPA->abandon(InnerAnalysisID); +        } +      } + +    // Check if we needed a custom PA set. If so we'll need to run the inner +    // invalidation. +    if (InnerPA) { +      InnerAM->invalidate(*L, *InnerPA); +      continue; +    } + +    // Otherwise we only need to do invalidation if the original PA set didn't +    // preserve all Loop analyses. +    if (!AreLoopAnalysesPreserved) +      InnerAM->invalidate(*L, PA); +  } + +  // Return false to indicate that this result is still a valid proxy. +  return false; +} + +template <> +LoopAnalysisManagerFunctionProxy::Result +LoopAnalysisManagerFunctionProxy::run(Function &F, +                                      FunctionAnalysisManager &AM) { +  return Result(*InnerAM, AM.getResult<LoopAnalysis>(F)); +} +} + +PreservedAnalyses llvm::getLoopPassPreservedAnalyses() { +  PreservedAnalyses PA; +  PA.preserve<DominatorTreeAnalysis>(); +  PA.preserve<LoopAnalysis>(); +  PA.preserve<LoopAnalysisManagerFunctionProxy>(); +  PA.preserve<ScalarEvolutionAnalysis>(); +  // FIXME: Uncomment this when all loop passes preserve MemorySSA +  // PA.preserve<MemorySSAAnalysis>(); +  // FIXME: What we really want to do here is preserve an AA category, but that +  // concept doesn't exist yet. 
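// For illustration, a hedged sketch of the typical consumer of this helper: a
// new-pass-manager loop transform returns the standard preserved set when it
// changed the loop and "all" otherwise. "MyLoopPass" is a hypothetical name,
// and PassInfoMixin/LPMUpdater come from headers this file does not include.
struct MyLoopPass : PassInfoMixin<MyLoopPass> {
  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
                        LoopStandardAnalysisResults &AR, LPMUpdater &U) {
    bool Changed = false;
    // ... transform L, setting Changed ...
    return Changed ? getLoopPassPreservedAnalyses() : PreservedAnalyses::all();
  }
};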
+  PA.preserve<AAManager>(); +  PA.preserve<BasicAA>(); +  PA.preserve<GlobalsAA>(); +  PA.preserve<SCEVAA>(); +  return PA; +} diff --git a/contrib/llvm/lib/Analysis/LoopInfo.cpp b/contrib/llvm/lib/Analysis/LoopInfo.cpp new file mode 100644 index 000000000000..3f78456b3586 --- /dev/null +++ b/contrib/llvm/lib/Analysis/LoopInfo.cpp @@ -0,0 +1,770 @@ +//===- LoopInfo.cpp - Natural Loop Calculator -----------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the LoopInfo class that is used to identify natural loops +// and determine the loop depth of various nodes of the CFG.  Note that the +// loops identified may actually be several natural loops that share the same +// header node... not just a single natural loop. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/LoopInfoImpl.h" +#include "llvm/Analysis/LoopIterator.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +using namespace llvm; + +// Explicitly instantiate methods in LoopInfoImpl.h for IR-level Loops. +template class llvm::LoopBase<BasicBlock, Loop>; +template class llvm::LoopInfoBase<BasicBlock, Loop>; + +// Always verify loopinfo if expensive checking is enabled. +#ifdef EXPENSIVE_CHECKS +bool llvm::VerifyLoopInfo = true; +#else +bool llvm::VerifyLoopInfo = false; +#endif +static cl::opt<bool, true> +    VerifyLoopInfoX("verify-loop-info", cl::location(VerifyLoopInfo), +                    cl::Hidden, cl::desc("Verify loop info (time consuming)")); + +//===----------------------------------------------------------------------===// +// Loop implementation +// + +bool Loop::isLoopInvariant(const Value *V) const { +  if (const Instruction *I = dyn_cast<Instruction>(V)) +    return !contains(I); +  return true; // All non-instructions are loop invariant +} + +bool Loop::hasLoopInvariantOperands(const Instruction *I) const { +  return all_of(I->operands(), [this](Value *V) { return isLoopInvariant(V); }); +} + +bool Loop::makeLoopInvariant(Value *V, bool &Changed, +                             Instruction *InsertPt) const { +  if (Instruction *I = dyn_cast<Instruction>(V)) +    return makeLoopInvariant(I, Changed, InsertPt); +  return true; // All non-instructions are loop-invariant. +} + +bool Loop::makeLoopInvariant(Instruction *I, bool &Changed, +                             Instruction *InsertPt) const { +  // Test if the value is already loop-invariant. +  if (isLoopInvariant(I)) +    return true; +  if (!isSafeToSpeculativelyExecute(I)) +    return false; +  if (I->mayReadFromMemory()) +    return false; +  // EH block instructions are immobile. +  if (I->isEHPad()) +    return false; +  // Determine the insertion point, unless one was given. 
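// For illustration, a hedged usage sketch of the hoisting interface defined
// here: try to make every header instruction loop-invariant and report whether
// anything was actually moved. The helper name is illustrative; the worklist
// copy guards against instructions being moved out of the block mid-iteration.
static bool hoistInvariantsFromHeader(Loop &L) {
  bool Changed = false;
  SmallVector<Instruction *, 16> Worklist;
  for (Instruction &I : *L.getHeader())
    Worklist.push_back(&I);
  for (Instruction *I : Worklist)
    L.makeLoopInvariant(I, Changed);
  return Changed;
}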
+  if (!InsertPt) { +    BasicBlock *Preheader = getLoopPreheader(); +    // Without a preheader, hoisting is not feasible. +    if (!Preheader) +      return false; +    InsertPt = Preheader->getTerminator(); +  } +  // Don't hoist instructions with loop-variant operands. +  for (Value *Operand : I->operands()) +    if (!makeLoopInvariant(Operand, Changed, InsertPt)) +      return false; + +  // Hoist. +  I->moveBefore(InsertPt); + +  // There is possibility of hoisting this instruction above some arbitrary +  // condition. Any metadata defined on it can be control dependent on this +  // condition. Conservatively strip it here so that we don't give any wrong +  // information to the optimizer. +  I->dropUnknownNonDebugMetadata(); + +  Changed = true; +  return true; +} + +PHINode *Loop::getCanonicalInductionVariable() const { +  BasicBlock *H = getHeader(); + +  BasicBlock *Incoming = nullptr, *Backedge = nullptr; +  pred_iterator PI = pred_begin(H); +  assert(PI != pred_end(H) && "Loop must have at least one backedge!"); +  Backedge = *PI++; +  if (PI == pred_end(H)) +    return nullptr; // dead loop +  Incoming = *PI++; +  if (PI != pred_end(H)) +    return nullptr; // multiple backedges? + +  if (contains(Incoming)) { +    if (contains(Backedge)) +      return nullptr; +    std::swap(Incoming, Backedge); +  } else if (!contains(Backedge)) +    return nullptr; + +  // Loop over all of the PHI nodes, looking for a canonical indvar. +  for (BasicBlock::iterator I = H->begin(); isa<PHINode>(I); ++I) { +    PHINode *PN = cast<PHINode>(I); +    if (ConstantInt *CI = +            dyn_cast<ConstantInt>(PN->getIncomingValueForBlock(Incoming))) +      if (CI->isZero()) +        if (Instruction *Inc = +                dyn_cast<Instruction>(PN->getIncomingValueForBlock(Backedge))) +          if (Inc->getOpcode() == Instruction::Add && Inc->getOperand(0) == PN) +            if (ConstantInt *CI = dyn_cast<ConstantInt>(Inc->getOperand(1))) +              if (CI->isOne()) +                return PN; +  } +  return nullptr; +} + +// Check that 'BB' doesn't have any uses outside of the 'L' +static bool isBlockInLCSSAForm(const Loop &L, const BasicBlock &BB, +                               DominatorTree &DT) { +  for (const Instruction &I : BB) { +    // Tokens can't be used in PHI nodes and live-out tokens prevent loop +    // optimizations, so for the purposes of considered LCSSA form, we +    // can ignore them. +    if (I.getType()->isTokenTy()) +      continue; + +    for (const Use &U : I.uses()) { +      const Instruction *UI = cast<Instruction>(U.getUser()); +      const BasicBlock *UserBB = UI->getParent(); +      if (const PHINode *P = dyn_cast<PHINode>(UI)) +        UserBB = P->getIncomingBlock(U); + +      // Check the current block, as a fast-path, before checking whether +      // the use is anywhere in the loop.  Most values are used in the same +      // block they are defined in.  Also, blocks not reachable from the +      // entry are special; uses in them don't need to go through PHIs. +      if (UserBB != &BB && !L.contains(UserBB) && +          DT.isReachableFromEntry(UserBB)) +        return false; +    } +  } +  return true; +} + +bool Loop::isLCSSAForm(DominatorTree &DT) const { +  // For each block we check that it doesn't have any uses outside of this loop. 
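// For illustration, a source-level sketch of the shape that
// getCanonicalInductionVariable (above) recognizes: a header phi that is zero
// on the preheader edge and incremented by one on the backedge. After standard
// lowering, "I" below typically becomes exactly such a phi; names are
// illustrative only.
int sumFirstN(const int *A, int N) {
  int Sum = 0;
  for (int I = 0; I != N; ++I) // starts at 0, steps by 1 -> canonical IV
    Sum += A[I];
  return Sum;
}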
+  return all_of(this->blocks(), [&](const BasicBlock *BB) { +    return isBlockInLCSSAForm(*this, *BB, DT); +  }); +} + +bool Loop::isRecursivelyLCSSAForm(DominatorTree &DT, const LoopInfo &LI) const { +  // For each block we check that it doesn't have any uses outside of its +  // innermost loop. This process will transitively guarantee that the current +  // loop and all of the nested loops are in LCSSA form. +  return all_of(this->blocks(), [&](const BasicBlock *BB) { +    return isBlockInLCSSAForm(*LI.getLoopFor(BB), *BB, DT); +  }); +} + +bool Loop::isLoopSimplifyForm() const { +  // Normal-form loops have a preheader, a single backedge, and all of their +  // exits have all their predecessors inside the loop. +  return getLoopPreheader() && getLoopLatch() && hasDedicatedExits(); +} + +// Routines that reform the loop CFG and split edges often fail on indirectbr. +bool Loop::isSafeToClone() const { +  // Return false if any loop blocks contain indirectbrs, or there are any calls +  // to noduplicate functions. +  for (BasicBlock *BB : this->blocks()) { +    if (isa<IndirectBrInst>(BB->getTerminator())) +      return false; + +    for (Instruction &I : *BB) +      if (auto CS = CallSite(&I)) +        if (CS.cannotDuplicate()) +          return false; +  } +  return true; +} + +MDNode *Loop::getLoopID() const { +  MDNode *LoopID = nullptr; +  if (BasicBlock *Latch = getLoopLatch()) { +    LoopID = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop); +  } else { +    assert(!getLoopLatch() && +           "The loop should have no single latch at this point"); +    // Go through each predecessor of the loop header and check the +    // terminator for the metadata. +    BasicBlock *H = getHeader(); +    for (BasicBlock *BB : this->blocks()) { +      TerminatorInst *TI = BB->getTerminator(); +      MDNode *MD = nullptr; + +      // Check if this terminator branches to the loop header. +      for (BasicBlock *Successor : TI->successors()) { +        if (Successor == H) { +          MD = TI->getMetadata(LLVMContext::MD_loop); +          break; +        } +      } +      if (!MD) +        return nullptr; + +      if (!LoopID) +        LoopID = MD; +      else if (MD != LoopID) +        return nullptr; +    } +  } +  if (!LoopID || LoopID->getNumOperands() == 0 || +      LoopID->getOperand(0) != LoopID) +    return nullptr; +  return LoopID; +} + +void Loop::setLoopID(MDNode *LoopID) const { +  assert(LoopID && "Loop ID should not be null"); +  assert(LoopID->getNumOperands() > 0 && "Loop ID needs at least one operand"); +  assert(LoopID->getOperand(0) == LoopID && "Loop ID should refer to itself"); + +  if (BasicBlock *Latch = getLoopLatch()) { +    Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); +    return; +  } + +  assert(!getLoopLatch() && +         "The loop should have no single latch at this point"); +  BasicBlock *H = getHeader(); +  for (BasicBlock *BB : this->blocks()) { +    TerminatorInst *TI = BB->getTerminator(); +    for (BasicBlock *Successor : TI->successors()) { +      if (Successor == H) +        TI->setMetadata(LLVMContext::MD_loop, LoopID); +    } +  } +} + +void Loop::setLoopAlreadyUnrolled() { +  MDNode *LoopID = getLoopID(); +  // First remove any existing loop unrolling metadata. +  SmallVector<Metadata *, 4> MDs; +  // Reserve first location for self reference to the LoopID metadata node. 
+  MDs.push_back(nullptr); + +  if (LoopID) { +    for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { +      bool IsUnrollMetadata = false; +      MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); +      if (MD) { +        const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); +        IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); +      } +      if (!IsUnrollMetadata) +        MDs.push_back(LoopID->getOperand(i)); +    } +  } + +  // Add unroll(disable) metadata to disable future unrolling. +  LLVMContext &Context = getHeader()->getContext(); +  SmallVector<Metadata *, 1> DisableOperands; +  DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); +  MDNode *DisableNode = MDNode::get(Context, DisableOperands); +  MDs.push_back(DisableNode); + +  MDNode *NewLoopID = MDNode::get(Context, MDs); +  // Set operand 0 to refer to the loop id itself. +  NewLoopID->replaceOperandWith(0, NewLoopID); +  setLoopID(NewLoopID); +} + +bool Loop::isAnnotatedParallel() const { +  MDNode *DesiredLoopIdMetadata = getLoopID(); + +  if (!DesiredLoopIdMetadata) +    return false; + +  // The loop branch contains the parallel loop metadata. In order to ensure +  // that any parallel-loop-unaware optimization pass hasn't added loop-carried +  // dependencies (thus converted the loop back to a sequential loop), check +  // that all the memory instructions in the loop contain parallelism metadata +  // that point to the same unique "loop id metadata" the loop branch does. +  for (BasicBlock *BB : this->blocks()) { +    for (Instruction &I : *BB) { +      if (!I.mayReadOrWriteMemory()) +        continue; + +      // The memory instruction can refer to the loop identifier metadata +      // directly or indirectly through another list metadata (in case of +      // nested parallel loops). The loop identifier metadata refers to +      // itself so we can check both cases with the same routine. +      MDNode *LoopIdMD = +          I.getMetadata(LLVMContext::MD_mem_parallel_loop_access); + +      if (!LoopIdMD) +        return false; + +      bool LoopIdMDFound = false; +      for (const MDOperand &MDOp : LoopIdMD->operands()) { +        if (MDOp == DesiredLoopIdMetadata) { +          LoopIdMDFound = true; +          break; +        } +      } + +      if (!LoopIdMDFound) +        return false; +    } +  } +  return true; +} + +DebugLoc Loop::getStartLoc() const { return getLocRange().getStart(); } + +Loop::LocRange Loop::getLocRange() const { +  // If we have a debug location in the loop ID, then use it. +  if (MDNode *LoopID = getLoopID()) { +    DebugLoc Start; +    // We use the first DebugLoc in the header as the start location of the loop +    // and if there is a second DebugLoc in the header we use it as end location +    // of the loop. +    for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { +      if (DILocation *L = dyn_cast<DILocation>(LoopID->getOperand(i))) { +        if (!Start) +          Start = DebugLoc(L); +        else +          return LocRange(Start, DebugLoc(L)); +      } +    } + +    if (Start) +      return LocRange(Start); +  } + +  // Try the pre-header first. +  if (BasicBlock *PHeadBB = getLoopPreheader()) +    if (DebugLoc DL = PHeadBB->getTerminator()->getDebugLoc()) +      return LocRange(DL); + +  // If we have no pre-header or there are no instructions with debug +  // info in it, try the header. 
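// For illustration, a hedged sketch of the usual consumer of getStartLoc():
// passes anchor their optimization remarks at the loop's starting location.
// The pass and remark names are illustrative, and OptimizationRemarkEmitter
// comes from llvm/Analysis/OptimizationRemarkEmitter.h, which this file does
// not include.
static void noteLoopLocation(OptimizationRemarkEmitter &ORE, Loop &L) {
  ORE.emit([&] {
    return OptimizationRemarkAnalysis("my-pass", "LoopLocation",
                                      L.getStartLoc(), L.getHeader())
           << "loop considered here";
  });
}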
+  if (BasicBlock *HeadBB = getHeader()) +    return LocRange(HeadBB->getTerminator()->getDebugLoc()); + +  return LocRange(); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void Loop::dump() const { print(dbgs()); } + +LLVM_DUMP_METHOD void Loop::dumpVerbose() const { +  print(dbgs(), /*Depth=*/0, /*Verbose=*/true); +} +#endif + +//===----------------------------------------------------------------------===// +// UnloopUpdater implementation +// + +namespace { +/// Find the new parent loop for all blocks within the "unloop" whose last +/// backedges has just been removed. +class UnloopUpdater { +  Loop &Unloop; +  LoopInfo *LI; + +  LoopBlocksDFS DFS; + +  // Map unloop's immediate subloops to their nearest reachable parents. Nested +  // loops within these subloops will not change parents. However, an immediate +  // subloop's new parent will be the nearest loop reachable from either its own +  // exits *or* any of its nested loop's exits. +  DenseMap<Loop *, Loop *> SubloopParents; + +  // Flag the presence of an irreducible backedge whose destination is a block +  // directly contained by the original unloop. +  bool FoundIB; + +public: +  UnloopUpdater(Loop *UL, LoopInfo *LInfo) +      : Unloop(*UL), LI(LInfo), DFS(UL), FoundIB(false) {} + +  void updateBlockParents(); + +  void removeBlocksFromAncestors(); + +  void updateSubloopParents(); + +protected: +  Loop *getNearestLoop(BasicBlock *BB, Loop *BBLoop); +}; +} // end anonymous namespace + +/// Update the parent loop for all blocks that are directly contained within the +/// original "unloop". +void UnloopUpdater::updateBlockParents() { +  if (Unloop.getNumBlocks()) { +    // Perform a post order CFG traversal of all blocks within this loop, +    // propagating the nearest loop from successors to predecessors. +    LoopBlocksTraversal Traversal(DFS, LI); +    for (BasicBlock *POI : Traversal) { + +      Loop *L = LI->getLoopFor(POI); +      Loop *NL = getNearestLoop(POI, L); + +      if (NL != L) { +        // For reducible loops, NL is now an ancestor of Unloop. +        assert((NL != &Unloop && (!NL || NL->contains(&Unloop))) && +               "uninitialized successor"); +        LI->changeLoopFor(POI, NL); +      } else { +        // Or the current block is part of a subloop, in which case its parent +        // is unchanged. +        assert((FoundIB || Unloop.contains(L)) && "uninitialized successor"); +      } +    } +  } +  // Each irreducible loop within the unloop induces a round of iteration using +  // the DFS result cached by Traversal. +  bool Changed = FoundIB; +  for (unsigned NIters = 0; Changed; ++NIters) { +    assert(NIters < Unloop.getNumBlocks() && "runaway iterative algorithm"); + +    // Iterate over the postorder list of blocks, propagating the nearest loop +    // from successors to predecessors as before. +    Changed = false; +    for (LoopBlocksDFS::POIterator POI = DFS.beginPostorder(), +                                   POE = DFS.endPostorder(); +         POI != POE; ++POI) { + +      Loop *L = LI->getLoopFor(*POI); +      Loop *NL = getNearestLoop(*POI, L); +      if (NL != L) { +        assert(NL != &Unloop && (!NL || NL->contains(&Unloop)) && +               "uninitialized successor"); +        LI->changeLoopFor(*POI, NL); +        Changed = true; +      } +    } +  } +} + +/// Remove unloop's blocks from all ancestors below their new parents. 
+void UnloopUpdater::removeBlocksFromAncestors() { +  // Remove all unloop's blocks (including those in nested subloops) from +  // ancestors below the new parent loop. +  for (Loop::block_iterator BI = Unloop.block_begin(), BE = Unloop.block_end(); +       BI != BE; ++BI) { +    Loop *OuterParent = LI->getLoopFor(*BI); +    if (Unloop.contains(OuterParent)) { +      while (OuterParent->getParentLoop() != &Unloop) +        OuterParent = OuterParent->getParentLoop(); +      OuterParent = SubloopParents[OuterParent]; +    } +    // Remove blocks from former Ancestors except Unloop itself which will be +    // deleted. +    for (Loop *OldParent = Unloop.getParentLoop(); OldParent != OuterParent; +         OldParent = OldParent->getParentLoop()) { +      assert(OldParent && "new loop is not an ancestor of the original"); +      OldParent->removeBlockFromLoop(*BI); +    } +  } +} + +/// Update the parent loop for all subloops directly nested within unloop. +void UnloopUpdater::updateSubloopParents() { +  while (!Unloop.empty()) { +    Loop *Subloop = *std::prev(Unloop.end()); +    Unloop.removeChildLoop(std::prev(Unloop.end())); + +    assert(SubloopParents.count(Subloop) && "DFS failed to visit subloop"); +    if (Loop *Parent = SubloopParents[Subloop]) +      Parent->addChildLoop(Subloop); +    else +      LI->addTopLevelLoop(Subloop); +  } +} + +/// Return the nearest parent loop among this block's successors. If a successor +/// is a subloop header, consider its parent to be the nearest parent of the +/// subloop's exits. +/// +/// For subloop blocks, simply update SubloopParents and return NULL. +Loop *UnloopUpdater::getNearestLoop(BasicBlock *BB, Loop *BBLoop) { + +  // Initially for blocks directly contained by Unloop, NearLoop == Unloop and +  // is considered uninitialized. +  Loop *NearLoop = BBLoop; + +  Loop *Subloop = nullptr; +  if (NearLoop != &Unloop && Unloop.contains(NearLoop)) { +    Subloop = NearLoop; +    // Find the subloop ancestor that is directly contained within Unloop. +    while (Subloop->getParentLoop() != &Unloop) { +      Subloop = Subloop->getParentLoop(); +      assert(Subloop && "subloop is not an ancestor of the original loop"); +    } +    // Get the current nearest parent of the Subloop exits, initially Unloop. +    NearLoop = SubloopParents.insert({Subloop, &Unloop}).first->second; +  } + +  succ_iterator I = succ_begin(BB), E = succ_end(BB); +  if (I == E) { +    assert(!Subloop && "subloop blocks must have a successor"); +    NearLoop = nullptr; // unloop blocks may now exit the function. +  } +  for (; I != E; ++I) { +    if (*I == BB) +      continue; // self loops are uninteresting + +    Loop *L = LI->getLoopFor(*I); +    if (L == &Unloop) { +      // This successor has not been processed. This path must lead to an +      // irreducible backedge. +      assert((FoundIB || !DFS.hasPostorder(*I)) && "should have seen IB"); +      FoundIB = true; +    } +    if (L != &Unloop && Unloop.contains(L)) { +      // Successor is in a subloop. +      if (Subloop) +        continue; // Branching within subloops. Ignore it. + +      // BB branches from the original into a subloop header. +      assert(L->getParentLoop() == &Unloop && "cannot skip into nested loops"); + +      // Get the current nearest parent of the Subloop's exits. +      L = SubloopParents[L]; +      // L could be Unloop if the only exit was an irreducible backedge. +    } +    if (L == &Unloop) { +      continue; +    } +    // Handle critical edges from Unloop into a sibling loop. 
+    if (L && !L->contains(&Unloop)) { +      L = L->getParentLoop(); +    } +    // Remember the nearest parent loop among successors or subloop exits. +    if (NearLoop == &Unloop || !NearLoop || NearLoop->contains(L)) +      NearLoop = L; +  } +  if (Subloop) { +    SubloopParents[Subloop] = NearLoop; +    return BBLoop; +  } +  return NearLoop; +} + +LoopInfo::LoopInfo(const DomTreeBase<BasicBlock> &DomTree) { analyze(DomTree); } + +bool LoopInfo::invalidate(Function &F, const PreservedAnalyses &PA, +                          FunctionAnalysisManager::Invalidator &) { +  // Check whether the analysis, all analyses on functions, or the function's +  // CFG have been preserved. +  auto PAC = PA.getChecker<LoopAnalysis>(); +  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || +           PAC.preservedSet<CFGAnalyses>()); +} + +void LoopInfo::erase(Loop *Unloop) { +  assert(!Unloop->isInvalid() && "Loop has already been erased!"); + +  auto InvalidateOnExit = make_scope_exit([&]() { destroy(Unloop); }); + +  // First handle the special case of no parent loop to simplify the algorithm. +  if (!Unloop->getParentLoop()) { +    // Since BBLoop had no parent, Unloop blocks are no longer in a loop. +    for (Loop::block_iterator I = Unloop->block_begin(), +                              E = Unloop->block_end(); +         I != E; ++I) { + +      // Don't reparent blocks in subloops. +      if (getLoopFor(*I) != Unloop) +        continue; + +      // Blocks no longer have a parent but are still referenced by Unloop until +      // the Unloop object is deleted. +      changeLoopFor(*I, nullptr); +    } + +    // Remove the loop from the top-level LoopInfo object. +    for (iterator I = begin();; ++I) { +      assert(I != end() && "Couldn't find loop"); +      if (*I == Unloop) { +        removeLoop(I); +        break; +      } +    } + +    // Move all of the subloops to the top-level. +    while (!Unloop->empty()) +      addTopLevelLoop(Unloop->removeChildLoop(std::prev(Unloop->end()))); + +    return; +  } + +  // Update the parent loop for all blocks within the loop. Blocks within +  // subloops will not change parents. +  UnloopUpdater Updater(Unloop, this); +  Updater.updateBlockParents(); + +  // Remove blocks from former ancestor loops. +  Updater.removeBlocksFromAncestors(); + +  // Add direct subloops as children in their new parent loop. +  Updater.updateSubloopParents(); + +  // Remove unloop from its parent loop. +  Loop *ParentLoop = Unloop->getParentLoop(); +  for (Loop::iterator I = ParentLoop->begin();; ++I) { +    assert(I != ParentLoop->end() && "Couldn't find loop"); +    if (*I == Unloop) { +      ParentLoop->removeChildLoop(I); +      break; +    } +  } +} + +AnalysisKey LoopAnalysis::Key; + +LoopInfo LoopAnalysis::run(Function &F, FunctionAnalysisManager &AM) { +  // FIXME: Currently we create a LoopInfo from scratch for every function. +  // This may prove to be too wasteful due to deallocating and re-allocating +  // memory each time for the underlying map and vector datastructures. At some +  // point it may prove worthwhile to use a freelist and recycle LoopInfo +  // objects. I don't want to add that kind of complexity until the scope of +  // the problem is better understood. 
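// For illustration, a hedged sketch of computing loop info by hand, outside
// any pass manager, the same way the run() method below does: build a
// dominator tree for the function and hand it to LoopInfo. "F" is assumed to
// be a valid function definition; names are illustrative.
static void printTopLevelLoopHeaders(Function &F, raw_ostream &OS) {
  DominatorTree DT(F);
  LoopInfo LI(DT);
  for (Loop *L : LI)
    OS << "top-level loop, header: " << L->getHeader()->getName() << "\n";
}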
+  LoopInfo LI; +  LI.analyze(AM.getResult<DominatorTreeAnalysis>(F)); +  return LI; +} + +PreservedAnalyses LoopPrinterPass::run(Function &F, +                                       FunctionAnalysisManager &AM) { +  AM.getResult<LoopAnalysis>(F).print(OS); +  return PreservedAnalyses::all(); +} + +void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) { + +  if (forcePrintModuleIR()) { +    // handling -print-module-scope +    OS << Banner << " (loop: "; +    L.getHeader()->printAsOperand(OS, false); +    OS << ")\n"; + +    // printing whole module +    OS << *L.getHeader()->getModule(); +    return; +  } + +  OS << Banner; + +  auto *PreHeader = L.getLoopPreheader(); +  if (PreHeader) { +    OS << "\n; Preheader:"; +    PreHeader->print(OS); +    OS << "\n; Loop:"; +  } + +  for (auto *Block : L.blocks()) +    if (Block) +      Block->print(OS); +    else +      OS << "Printing <null> block"; + +  SmallVector<BasicBlock *, 8> ExitBlocks; +  L.getExitBlocks(ExitBlocks); +  if (!ExitBlocks.empty()) { +    OS << "\n; Exit blocks"; +    for (auto *Block : ExitBlocks) +      if (Block) +        Block->print(OS); +      else +        OS << "Printing <null> block"; +  } +} + +//===----------------------------------------------------------------------===// +// LoopInfo implementation +// + +char LoopInfoWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopInfoWrapperPass, "loops", "Natural Loop Information", +                      true, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(LoopInfoWrapperPass, "loops", "Natural Loop Information", +                    true, true) + +bool LoopInfoWrapperPass::runOnFunction(Function &) { +  releaseMemory(); +  LI.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); +  return false; +} + +void LoopInfoWrapperPass::verifyAnalysis() const { +  // LoopInfoWrapperPass is a FunctionPass, but verifying every loop in the +  // function each time verifyAnalysis is called is very expensive. The +  // -verify-loop-info option can enable this. In order to perform some +  // checking by default, LoopPass has been taught to call verifyLoop manually +  // during loop pass sequences. +  if (VerifyLoopInfo) { +    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +    LI.verify(DT); +  } +} + +void LoopInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<DominatorTreeWrapperPass>(); +} + +void LoopInfoWrapperPass::print(raw_ostream &OS, const Module *) const { +  LI.print(OS); +} + +PreservedAnalyses LoopVerifierPass::run(Function &F, +                                        FunctionAnalysisManager &AM) { +  LoopInfo &LI = AM.getResult<LoopAnalysis>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  LI.verify(DT); +  return PreservedAnalyses::all(); +} + +//===----------------------------------------------------------------------===// +// LoopBlocksDFS implementation +// + +/// Traverse the loop blocks and store the DFS result. +/// Useful for clients that just want the final DFS result and don't need to +/// visit blocks during the initial traversal. 
+void LoopBlocksDFS::perform(LoopInfo *LI) { +  LoopBlocksTraversal Traversal(*this, LI); +  for (LoopBlocksTraversal::POTIterator POI = Traversal.begin(), +                                        POE = Traversal.end(); +       POI != POE; ++POI) +    ; +} diff --git a/contrib/llvm/lib/Analysis/LoopPass.cpp b/contrib/llvm/lib/Analysis/LoopPass.cpp new file mode 100644 index 000000000000..07a151ce0fce --- /dev/null +++ b/contrib/llvm/lib/Analysis/LoopPass.cpp @@ -0,0 +1,390 @@ +//===- LoopPass.cpp - Loop Pass and Loop Pass Manager ---------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements LoopPass and LPPassManager. All loop optimization +// and transformation passes are derived from LoopPass. LPPassManager is +// responsible for managing LoopPasses. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/OptBisect.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-pass-manager" + +namespace { + +/// PrintLoopPass - Print a Function corresponding to a Loop. +/// +class PrintLoopPassWrapper : public LoopPass { +  raw_ostream &OS; +  std::string Banner; + +public: +  static char ID; +  PrintLoopPassWrapper() : LoopPass(ID), OS(dbgs()) {} +  PrintLoopPassWrapper(raw_ostream &OS, const std::string &Banner) +      : LoopPass(ID), OS(OS), Banner(Banner) {} + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesAll(); +  } + +  bool runOnLoop(Loop *L, LPPassManager &) override { +    auto BBI = llvm::find_if(L->blocks(), [](BasicBlock *BB) { return BB; }); +    if (BBI != L->blocks().end() && +        isFunctionInPrintList((*BBI)->getParent()->getName())) { +      printLoop(*L, OS, Banner); +    } +    return false; +  } + +  StringRef getPassName() const override { return "Print Loop IR"; } +}; + +char PrintLoopPassWrapper::ID = 0; +} + +//===----------------------------------------------------------------------===// +// LPPassManager +// + +char LPPassManager::ID = 0; + +LPPassManager::LPPassManager() +  : FunctionPass(ID), PMDataManager() { +  LI = nullptr; +  CurrentLoop = nullptr; +} + +// Insert loop into loop nest (LoopInfo) and loop queue (LQ). +void LPPassManager::addLoop(Loop &L) { +  if (!L.getParentLoop()) { +    // This is the top level loop. +    LQ.push_front(&L); +    return; +  } + +  // Insert L into the loop queue after the parent loop. +  for (auto I = LQ.begin(), E = LQ.end(); I != E; ++I) { +    if (*I == L.getParentLoop()) { +      // deque does not support insert after. +      ++I; +      LQ.insert(I, 1, &L); +      return; +    } +  } +} + +/// cloneBasicBlockSimpleAnalysis - Invoke cloneBasicBlockAnalysis hook for +/// all loop passes. 
+void LPPassManager::cloneBasicBlockSimpleAnalysis(BasicBlock *From, +                                                  BasicBlock *To, Loop *L) { +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    LoopPass *LP = getContainedPass(Index); +    LP->cloneBasicBlockAnalysis(From, To, L); +  } +} + +/// deleteSimpleAnalysisValue - Invoke deleteAnalysisValue hook for all passes. +void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) { +  if (BasicBlock *BB = dyn_cast<BasicBlock>(V)) { +    for (Instruction &I : *BB) { +      deleteSimpleAnalysisValue(&I, L); +    } +  } +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    LoopPass *LP = getContainedPass(Index); +    LP->deleteAnalysisValue(V, L); +  } +} + +/// Invoke deleteAnalysisLoop hook for all passes. +void LPPassManager::deleteSimpleAnalysisLoop(Loop *L) { +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    LoopPass *LP = getContainedPass(Index); +    LP->deleteAnalysisLoop(L); +  } +} + + +// Recurse through all subloops and all loops  into LQ. +static void addLoopIntoQueue(Loop *L, std::deque<Loop *> &LQ) { +  LQ.push_back(L); +  for (Loop *I : reverse(*L)) +    addLoopIntoQueue(I, LQ); +} + +/// Pass Manager itself does not invalidate any analysis info. +void LPPassManager::getAnalysisUsage(AnalysisUsage &Info) const { +  // LPPassManager needs LoopInfo. In the long term LoopInfo class will +  // become part of LPPassManager. +  Info.addRequired<LoopInfoWrapperPass>(); +  Info.addRequired<DominatorTreeWrapperPass>(); +  Info.setPreservesAll(); +} + +void LPPassManager::markLoopAsDeleted(Loop &L) { +  assert((&L == CurrentLoop || CurrentLoop->contains(&L)) && +         "Must not delete loop outside the current loop tree!"); +  // If this loop appears elsewhere within the queue, we also need to remove it +  // there. However, we have to be careful to not remove the back of the queue +  // as that is assumed to match the current loop. +  assert(LQ.back() == CurrentLoop && "Loop queue back isn't the current loop!"); +  LQ.erase(std::remove(LQ.begin(), LQ.end(), &L), LQ.end()); + +  if (&L == CurrentLoop) { +    CurrentLoopDeleted = true; +    // Add this loop back onto the back of the queue to preserve our invariants. +    LQ.push_back(&L); +  } +} + +/// run - Execute all of the passes scheduled for execution.  Keep track of +/// whether any of the passes modifies the function, and if so, return true. +bool LPPassManager::runOnFunction(Function &F) { +  auto &LIWP = getAnalysis<LoopInfoWrapperPass>(); +  LI = &LIWP.getLoopInfo(); +  Module &M = *F.getParent(); +#if 0 +  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +#endif +  bool Changed = false; + +  // Collect inherited analysis from Module level pass manager. +  populateInheritedAnalysis(TPM->activeStack); + +  // Populate the loop queue in reverse program order. There is no clear need to +  // process sibling loops in either forward or reverse order. There may be some +  // advantage in deleting uses in a later loop before optimizing the +  // definitions in an earlier loop. If we find a clear reason to process in +  // forward order, then a forward variant of LoopPassManager should be created. +  // +  // Note that LoopInfo::iterator visits loops in reverse program +  // order. Here, reverse_iterator gives us a forward order, and the LoopQueue +  // reverses the order a third time by popping from the back. 
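// For illustration, a worked example of the ordering described above, using
// strings in place of loops. If the loops in program order are A (containing
// A1) and B, LoopInfo iteration yields B then A, reverse(*LI) visits A then B,
// and the queue is filled as [A, A1, B]; popping from the back therefore
// processes B, then A1, then A -- inner loops before their parents. The
// function below is a standalone sketch (assumes <deque>).
static void queueOrderSketch() {
  std::deque<const char *> LQ = {"A", "A1", "B"}; // push order produced above
  while (!LQ.empty()) {
    // process LQ.back(): "B", then "A1", then "A"
    LQ.pop_back();
  }
}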
+  for (Loop *L : reverse(*LI)) +    addLoopIntoQueue(L, LQ); + +  if (LQ.empty()) // No loops, skip calling finalizers +    return false; + +  // Initialization +  for (Loop *L : LQ) { +    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +      LoopPass *P = getContainedPass(Index); +      Changed |= P->doInitialization(L, *this); +    } +  } + +  // Walk Loops +  unsigned InstrCount = 0; +  bool EmitICRemark = M.shouldEmitInstrCountChangedRemark(); +  while (!LQ.empty()) { +    CurrentLoopDeleted = false; +    CurrentLoop = LQ.back(); + +    // Run all passes on the current Loop. +    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +      LoopPass *P = getContainedPass(Index); + +      dumpPassInfo(P, EXECUTION_MSG, ON_LOOP_MSG, +                   CurrentLoop->getHeader()->getName()); +      dumpRequiredSet(P); + +      initializeAnalysisImpl(P); + +      { +        PassManagerPrettyStackEntry X(P, *CurrentLoop->getHeader()); +        TimeRegion PassTimer(getPassTimer(P)); +        if (EmitICRemark) +          InstrCount = initSizeRemarkInfo(M); +        Changed |= P->runOnLoop(CurrentLoop, *this); +        if (EmitICRemark) +          emitInstrCountChangedRemark(P, M, InstrCount); +      } + +      if (Changed) +        dumpPassInfo(P, MODIFICATION_MSG, ON_LOOP_MSG, +                     CurrentLoopDeleted ? "<deleted loop>" +                                        : CurrentLoop->getName()); +      dumpPreservedSet(P); + +      if (CurrentLoopDeleted) { +        // Notify passes that the loop is being deleted. +        deleteSimpleAnalysisLoop(CurrentLoop); +      } else { +        // Manually check that this loop is still healthy. This is done +        // instead of relying on LoopInfo::verifyLoop since LoopInfo +        // is a function pass and it's really expensive to verify every +        // loop in the function every time. That level of checking can be +        // enabled with the -verify-loop-info option. +        { +          TimeRegion PassTimer(getPassTimer(&LIWP)); +          CurrentLoop->verifyLoop(); +        } +        // Here we apply same reasoning as in the above case. Only difference +        // is that LPPassManager might run passes which do not require LCSSA +        // form (LoopPassPrinter for example). We should skip verification for +        // such passes. +        // FIXME: Loop-sink currently break LCSSA. Fix it and reenable the +        // verification! +#if 0 +        if (mustPreserveAnalysisID(LCSSAVerificationPass::ID)) +          assert(CurrentLoop->isRecursivelyLCSSAForm(*DT, *LI)); +#endif + +        // Then call the regular verifyAnalysis functions. +        verifyPreservedAnalysis(P); + +        F.getContext().yield(); +      } + +      removeNotPreservedAnalysis(P); +      recordAvailableAnalysis(P); +      removeDeadPasses(P, +                       CurrentLoopDeleted ? "<deleted>" +                                          : CurrentLoop->getHeader()->getName(), +                       ON_LOOP_MSG); + +      if (CurrentLoopDeleted) +        // Do not run other passes on this loop. +        break; +    } + +    // If the loop was deleted, release all the loop passes. This frees up +    // some memory, and avoids trouble with the pass manager trying to call +    // verifyAnalysis on them. 
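// For illustration, a hedged sketch of the kind of pass this manager drives: a
// minimal legacy LoopPass only needs runOnLoop plus an analysis-usage
// declaration. The class name is hypothetical and the usual pass registration
// boilerplate is omitted.
namespace {
struct CountBlocksLoopPass : public LoopPass {
  static char ID;
  CountBlocksLoopPass() : LoopPass(ID) {}

  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
    if (skipLoop(L)) // honors opt-bisect and optnone; see skipLoop below
      return false;
    LLVM_DEBUG(dbgs() << "loop " << L->getName() << " has "
                      << L->getNumBlocks() << " blocks\n");
    return false; // analysis only: the IR is never modified
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }
};
} // end anonymous namespace
char CountBlocksLoopPass::ID = 0;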
+    if (CurrentLoopDeleted) { +      for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +        Pass *P = getContainedPass(Index); +        freePass(P, "<deleted>", ON_LOOP_MSG); +      } +    } + +    // Pop the loop from queue after running all passes. +    LQ.pop_back(); +  } + +  // Finalization +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    LoopPass *P = getContainedPass(Index); +    Changed |= P->doFinalization(); +  } + +  return Changed; +} + +/// Print passes managed by this manager +void LPPassManager::dumpPassStructure(unsigned Offset) { +  errs().indent(Offset*2) << "Loop Pass Manager\n"; +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    Pass *P = getContainedPass(Index); +    P->dumpPassStructure(Offset + 1); +    dumpLastUses(P, Offset+1); +  } +} + + +//===----------------------------------------------------------------------===// +// LoopPass + +Pass *LoopPass::createPrinterPass(raw_ostream &O, +                                  const std::string &Banner) const { +  return new PrintLoopPassWrapper(O, Banner); +} + +// Check if this pass is suitable for the current LPPassManager, if +// available. This pass P is not suitable for a LPPassManager if P +// is not preserving higher level analysis info used by other +// LPPassManager passes. In such case, pop LPPassManager from the +// stack. This will force assignPassManager() to create new +// LPPassManger as expected. +void LoopPass::preparePassManager(PMStack &PMS) { + +  // Find LPPassManager +  while (!PMS.empty() && +         PMS.top()->getPassManagerType() > PMT_LoopPassManager) +    PMS.pop(); + +  // If this pass is destroying high level information that is used +  // by other passes that are managed by LPM then do not insert +  // this pass in current LPM. Use new LPPassManager. +  if (PMS.top()->getPassManagerType() == PMT_LoopPassManager && +      !PMS.top()->preserveHigherLevelAnalysis(this)) +    PMS.pop(); +} + +/// Assign pass manager to manage this pass. +void LoopPass::assignPassManager(PMStack &PMS, +                                 PassManagerType PreferredType) { +  // Find LPPassManager +  while (!PMS.empty() && +         PMS.top()->getPassManagerType() > PMT_LoopPassManager) +    PMS.pop(); + +  LPPassManager *LPPM; +  if (PMS.top()->getPassManagerType() == PMT_LoopPassManager) +    LPPM = (LPPassManager*)PMS.top(); +  else { +    // Create new Loop Pass Manager if it does not exist. +    assert (!PMS.empty() && "Unable to create Loop Pass Manager"); +    PMDataManager *PMD = PMS.top(); + +    // [1] Create new Loop Pass Manager +    LPPM = new LPPassManager(); +    LPPM->populateInheritedAnalysis(PMS); + +    // [2] Set up new manager's top level manager +    PMTopLevelManager *TPM = PMD->getTopLevelManager(); +    TPM->addIndirectPassManager(LPPM); + +    // [3] Assign manager to manage this new manager. This may create +    // and push new managers into PMS +    Pass *P = LPPM->getAsPass(); +    TPM->schedulePass(P); + +    // [4] Push new manager into PMS +    PMS.push(LPPM); +  } + +  LPPM->add(this); +} + +bool LoopPass::skipLoop(const Loop *L) const { +  const Function *F = L->getHeader()->getParent(); +  if (!F) +    return false; +  // Check the opt bisect limit. +  LLVMContext &Context = F->getContext(); +  if (!Context.getOptPassGate().shouldRunPass(this, *L)) +    return true; +  // Check for the OptimizeNone attribute. 
+  if (F->hasFnAttribute(Attribute::OptimizeNone)) { +    // FIXME: Report this to dbgs() only once per function. +    LLVM_DEBUG(dbgs() << "Skipping pass '" << getPassName() << "' in function " +                      << F->getName() << "\n"); +    // FIXME: Delete loop from pass manager's queue? +    return true; +  } +  return false; +} + +char LCSSAVerificationPass::ID = 0; +INITIALIZE_PASS(LCSSAVerificationPass, "lcssa-verification", "LCSSA Verifier", +                false, false) diff --git a/contrib/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp b/contrib/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp new file mode 100644 index 000000000000..c8b91a7a1a51 --- /dev/null +++ b/contrib/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp @@ -0,0 +1,215 @@ +//===- LoopUnrollAnalyzer.cpp - Unrolling Effect Estimation -----*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements UnrolledInstAnalyzer class. It's used for predicting +// potential effects that loop unrolling might have, such as enabling constant +// propagation and other optimizations. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopUnrollAnalyzer.h" + +using namespace llvm; + +/// Try to simplify instruction \param I using its SCEV expression. +/// +/// The idea is that some AddRec expressions become constants, which then +/// could trigger folding of other instructions. However, that only happens +/// for expressions whose start value is also constant, which isn't always the +/// case. In another common and important case the start value is just some +/// address (i.e. SCEVUnknown) - in this case we compute the offset and save +/// it along with the base address instead. +bool UnrolledInstAnalyzer::simplifyInstWithSCEV(Instruction *I) { +  if (!SE.isSCEVable(I->getType())) +    return false; + +  const SCEV *S = SE.getSCEV(I); +  if (auto *SC = dyn_cast<SCEVConstant>(S)) { +    SimplifiedValues[I] = SC->getValue(); +    return true; +  } + +  auto *AR = dyn_cast<SCEVAddRecExpr>(S); +  if (!AR || AR->getLoop() != L) +    return false; + +  const SCEV *ValueAtIteration = AR->evaluateAtIteration(IterationNumber, SE); +  // Check if the AddRec expression becomes a constant. +  if (auto *SC = dyn_cast<SCEVConstant>(ValueAtIteration)) { +    SimplifiedValues[I] = SC->getValue(); +    return true; +  } + +  // Check if the offset from the base address becomes a constant. +  auto *Base = dyn_cast<SCEVUnknown>(SE.getPointerBase(S)); +  if (!Base) +    return false; +  auto *Offset = +      dyn_cast<SCEVConstant>(SE.getMinusSCEV(ValueAtIteration, Base)); +  if (!Offset) +    return false; +  SimplifiedAddress Address; +  Address.Base = Base->getValue(); +  Address.Offset = Offset->getValue(); +  SimplifiedAddresses[I] = Address; +  return false; +} + +/// Try to simplify binary operator I. +/// +/// TODO: Probably it's worth to hoist the code for estimating the +/// simplifications effects to a separate class, since we have a very similar +/// code in InlineCost already. 
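// Editor's note: a worked-arithmetic sketch of what simplifyInstWithSCEV() above
// does once the iteration number is fixed. For an affine add-recurrence
// {Start,+,Step}, the value at iteration N is Start + N * Step; the real code
// goes through SCEVAddRecExpr::evaluateAtIteration(), this helper is only an
// illustration with plain integers.
#include <cstdint>
static int64_t toyEvalAffineAddRec(int64_t Start, int64_t Step, uint64_t N) {
  return Start + static_cast<int64_t>(N) * Step;
}
// e.g. an induction variable {0,+,4} evaluates to the constant 12 at iteration 3,
// which can in turn let its users fold during unrolling cost analysis.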
+bool UnrolledInstAnalyzer::visitBinaryOperator(BinaryOperator &I) { +  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); +  if (!isa<Constant>(LHS)) +    if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) +      LHS = SimpleLHS; +  if (!isa<Constant>(RHS)) +    if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) +      RHS = SimpleRHS; + +  Value *SimpleV = nullptr; +  const DataLayout &DL = I.getModule()->getDataLayout(); +  if (auto FI = dyn_cast<FPMathOperator>(&I)) +    SimpleV = +        SimplifyFPBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); +  else +    SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL); + +  if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) +    SimplifiedValues[&I] = C; + +  if (SimpleV) +    return true; +  return Base::visitBinaryOperator(I); +} + +/// Try to fold load I. +bool UnrolledInstAnalyzer::visitLoad(LoadInst &I) { +  Value *AddrOp = I.getPointerOperand(); + +  auto AddressIt = SimplifiedAddresses.find(AddrOp); +  if (AddressIt == SimplifiedAddresses.end()) +    return false; +  ConstantInt *SimplifiedAddrOp = AddressIt->second.Offset; + +  auto *GV = dyn_cast<GlobalVariable>(AddressIt->second.Base); +  // We're only interested in loads that can be completely folded to a +  // constant. +  if (!GV || !GV->hasDefinitiveInitializer() || !GV->isConstant()) +    return false; + +  ConstantDataSequential *CDS = +      dyn_cast<ConstantDataSequential>(GV->getInitializer()); +  if (!CDS) +    return false; + +  // We might have a vector load from an array. FIXME: for now we just bail +  // out in this case, but we should be able to resolve and simplify such +  // loads. +  if (CDS->getElementType() != I.getType()) +    return false; + +  unsigned ElemSize = CDS->getElementType()->getPrimitiveSizeInBits() / 8U; +  if (SimplifiedAddrOp->getValue().getActiveBits() > 64) +    return false; +  int64_t SimplifiedAddrOpV = SimplifiedAddrOp->getSExtValue(); +  if (SimplifiedAddrOpV < 0) { +    // FIXME: For now we conservatively ignore out of bound accesses, but +    // we're allowed to perform the optimization in this case. +    return false; +  } +  uint64_t Index = static_cast<uint64_t>(SimplifiedAddrOpV) / ElemSize; +  if (Index >= CDS->getNumElements()) { +    // FIXME: For now we conservatively ignore out of bound accesses, but +    // we're allowed to perform the optimization in this case. +    return false; +  } + +  Constant *CV = CDS->getElementAsConstant(Index); +  assert(CV && "Constant expected."); +  SimplifiedValues[&I] = CV; + +  return true; +} + +/// Try to simplify cast instruction. +bool UnrolledInstAnalyzer::visitCastInst(CastInst &I) { +  // Propagate constants through casts. +  Constant *COp = dyn_cast<Constant>(I.getOperand(0)); +  if (!COp) +    COp = SimplifiedValues.lookup(I.getOperand(0)); + +  // If we know a simplified value for this operand and cast is valid, save the +  // result to SimplifiedValues. +  // The cast can be invalid, because SimplifiedValues contains results of SCEV +  // analysis, which operates on integers (and, e.g., might convert i8* null to +  // i32 0). +  if (COp && CastInst::castIsValid(I.getOpcode(), COp, I.getType())) { +    if (Constant *C = +            ConstantExpr::getCast(I.getOpcode(), COp, I.getType())) { +      SimplifiedValues[&I] = C; +      return true; +    } +  } + +  return Base::visitCastInst(I); +} + +/// Try to simplify cmp instruction. 
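// Editor's note: a hedged sketch of just the index arithmetic visitLoad() above
// performs once it has a (constant global, constant byte offset) address pair.
// The offset and bounds checks follow the code above, with an extra divide-by-zero
// guard; finding the initializer and reading the element are omitted.
#include <cstdint>
static bool toyConstantLoadIndex(int64_t OffsetBytes, uint64_t ElemSizeBytes,
                                 uint64_t NumElements, uint64_t &Index) {
  if (OffsetBytes < 0 || ElemSizeBytes == 0)
    return false;                               // negative offset: bail out
  Index = static_cast<uint64_t>(OffsetBytes) / ElemSizeBytes;
  return Index < NumElements;                   // must hit an existing element
}
// With an i32 array (ElemSizeBytes 4) of 8 elements, OffsetBytes 12 folds to
// element 3; OffsetBytes 40 is out of bounds and the load is left alone.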
+bool UnrolledInstAnalyzer::visitCmpInst(CmpInst &I) { +  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + +  // First try to handle simplified comparisons. +  if (!isa<Constant>(LHS)) +    if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) +      LHS = SimpleLHS; +  if (!isa<Constant>(RHS)) +    if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) +      RHS = SimpleRHS; + +  if (!isa<Constant>(LHS) && !isa<Constant>(RHS)) { +    auto SimplifiedLHS = SimplifiedAddresses.find(LHS); +    if (SimplifiedLHS != SimplifiedAddresses.end()) { +      auto SimplifiedRHS = SimplifiedAddresses.find(RHS); +      if (SimplifiedRHS != SimplifiedAddresses.end()) { +        SimplifiedAddress &LHSAddr = SimplifiedLHS->second; +        SimplifiedAddress &RHSAddr = SimplifiedRHS->second; +        if (LHSAddr.Base == RHSAddr.Base) { +          LHS = LHSAddr.Offset; +          RHS = RHSAddr.Offset; +        } +      } +    } +  } + +  if (Constant *CLHS = dyn_cast<Constant>(LHS)) { +    if (Constant *CRHS = dyn_cast<Constant>(RHS)) { +      if (CLHS->getType() == CRHS->getType()) { +        if (Constant *C = ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) { +          SimplifiedValues[&I] = C; +          return true; +        } +      } +    } +  } + +  return Base::visitCmpInst(I); +} + +bool UnrolledInstAnalyzer::visitPHINode(PHINode &PN) { +  // Run base visitor first. This way we can gather some useful for later +  // analysis information. +  if (Base::visitPHINode(PN)) +    return true; + +  // The loop induction PHI nodes are definitionally free. +  return PN.getParent() == L->getHeader(); +} diff --git a/contrib/llvm/lib/Analysis/MemDepPrinter.cpp b/contrib/llvm/lib/Analysis/MemDepPrinter.cpp new file mode 100644 index 000000000000..5a6bbd7b2ac6 --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemDepPrinter.cpp @@ -0,0 +1,166 @@ +//===- MemDepPrinter.cpp - Printer for MemoryDependenceAnalysis -----------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +namespace { +  struct MemDepPrinter : public FunctionPass { +    const Function *F; + +    enum DepType { +      Clobber = 0, +      Def, +      NonFuncLocal, +      Unknown +    }; + +    static const char *const DepTypeStr[]; + +    typedef PointerIntPair<const Instruction *, 2, DepType> InstTypePair; +    typedef std::pair<InstTypePair, const BasicBlock *> Dep; +    typedef SmallSetVector<Dep, 4> DepSet; +    typedef DenseMap<const Instruction *, DepSet> DepSetMap; +    DepSetMap Deps; + +    static char ID; // Pass identifcation, replacement for typeid +    MemDepPrinter() : FunctionPass(ID) { +      initializeMemDepPrinterPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnFunction(Function &F) override; + +    void print(raw_ostream &OS, const Module * = nullptr) const override; + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.addRequiredTransitive<AAResultsWrapperPass>(); +      AU.addRequiredTransitive<MemoryDependenceWrapperPass>(); +      AU.setPreservesAll(); +    } + +    void releaseMemory() override { +      Deps.clear(); +      F = nullptr; +    } + +  private: +    static InstTypePair getInstTypePair(MemDepResult dep) { +      if (dep.isClobber()) +        return InstTypePair(dep.getInst(), Clobber); +      if (dep.isDef()) +        return InstTypePair(dep.getInst(), Def); +      if (dep.isNonFuncLocal()) +        return InstTypePair(dep.getInst(), NonFuncLocal); +      assert(dep.isUnknown() && "unexpected dependence type"); +      return InstTypePair(dep.getInst(), Unknown); +    } +    static InstTypePair getInstTypePair(const Instruction* inst, DepType type) { +      return InstTypePair(inst, type); +    } +  }; +} + +char MemDepPrinter::ID = 0; +INITIALIZE_PASS_BEGIN(MemDepPrinter, "print-memdeps", +                      "Print MemDeps of function", false, true) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_END(MemDepPrinter, "print-memdeps", +                      "Print MemDeps of function", false, true) + +FunctionPass *llvm::createMemDepPrinter() { +  return new MemDepPrinter(); +} + +const char *const MemDepPrinter::DepTypeStr[] +  = {"Clobber", "Def", "NonFuncLocal", "Unknown"}; + +bool MemDepPrinter::runOnFunction(Function &F) { +  this->F = &F; +  MemoryDependenceResults &MDA = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + +  // All this code uses non-const interfaces because MemDep is not +  // const-friendly, though nothing is actually modified. 
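// Editor's note: the InstTypePair typedef above stores the 2-bit DepType tag in
// the low alignment bits of the Instruction pointer via llvm::PointerIntPair, so
// each recorded dependence costs a single pointer-sized word. A small standalone
// sketch of unpacking with essentially the same instantiation (helper name is
// illustrative only):
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/IR/Instruction.h"
static const llvm::Instruction *
toyUnpackDep(llvm::PointerIntPair<const llvm::Instruction *, 2, unsigned> Packed,
             unsigned &TagOut) {
  TagOut = Packed.getInt();    // the 2-bit tag (e.g. Clobber/Def/...)
  return Packed.getPointer();  // the instruction, with the tag bits masked off
}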
+  for (auto &I : instructions(F)) { +    Instruction *Inst = &I; + +    if (!Inst->mayReadFromMemory() && !Inst->mayWriteToMemory()) +      continue; + +    MemDepResult Res = MDA.getDependency(Inst); +    if (!Res.isNonLocal()) { +      Deps[Inst].insert(std::make_pair(getInstTypePair(Res), +                                       static_cast<BasicBlock *>(nullptr))); +    } else if (auto CS = CallSite(Inst)) { +      const MemoryDependenceResults::NonLocalDepInfo &NLDI = +        MDA.getNonLocalCallDependency(CS); + +      DepSet &InstDeps = Deps[Inst]; +      for (const NonLocalDepEntry &I : NLDI) { +        const MemDepResult &Res = I.getResult(); +        InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); +      } +    } else { +      SmallVector<NonLocalDepResult, 4> NLDI; +      assert( (isa<LoadInst>(Inst) || isa<StoreInst>(Inst) || +               isa<VAArgInst>(Inst)) && "Unknown memory instruction!"); +      MDA.getNonLocalPointerDependency(Inst, NLDI); + +      DepSet &InstDeps = Deps[Inst]; +      for (const NonLocalDepResult &I : NLDI) { +        const MemDepResult &Res = I.getResult(); +        InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); +      } +    } +  } + +  return false; +} + +void MemDepPrinter::print(raw_ostream &OS, const Module *M) const { +  for (const auto &I : instructions(*F)) { +    const Instruction *Inst = &I; + +    DepSetMap::const_iterator DI = Deps.find(Inst); +    if (DI == Deps.end()) +      continue; + +    const DepSet &InstDeps = DI->second; + +    for (const auto &I : InstDeps) { +      const Instruction *DepInst = I.first.getPointer(); +      DepType type = I.first.getInt(); +      const BasicBlock *DepBB = I.second; + +      OS << "    "; +      OS << DepTypeStr[type]; +      if (DepBB) { +        OS << " in block "; +        DepBB->printAsOperand(OS, /*PrintType=*/false, M); +      } +      if (DepInst) { +        OS << " from: "; +        DepInst->print(OS); +      } +      OS << "\n"; +    } + +    Inst->print(OS); +    OS << "\n\n"; +  } +} diff --git a/contrib/llvm/lib/Analysis/MemDerefPrinter.cpp b/contrib/llvm/lib/Analysis/MemDerefPrinter.cpp new file mode 100644 index 000000000000..4a136c5a0c6d --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemDerefPrinter.cpp @@ -0,0 +1,76 @@ +//===- MemDerefPrinter.cpp - Printer for isDereferenceablePointer ---------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +namespace { +  struct MemDerefPrinter : public FunctionPass { +    SmallVector<Value *, 4> Deref; +    SmallPtrSet<Value *, 4> DerefAndAligned; + +    static char ID; // Pass identification, replacement for typeid +    MemDerefPrinter() : FunctionPass(ID) { +      initializeMemDerefPrinterPass(*PassRegistry::getPassRegistry()); +    } +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +    bool runOnFunction(Function &F) override; +    void print(raw_ostream &OS, const Module * = nullptr) const override; +    void releaseMemory() override { +      Deref.clear(); +      DerefAndAligned.clear(); +    } +  }; +} + +char MemDerefPrinter::ID = 0; +INITIALIZE_PASS_BEGIN(MemDerefPrinter, "print-memderefs", +                      "Memory Dereferenciblity of pointers in function", false, true) +INITIALIZE_PASS_END(MemDerefPrinter, "print-memderefs", +                    "Memory Dereferenciblity of pointers in function", false, true) + +FunctionPass *llvm::createMemDerefPrinter() { +  return new MemDerefPrinter(); +} + +bool MemDerefPrinter::runOnFunction(Function &F) { +  const DataLayout &DL = F.getParent()->getDataLayout(); +  for (auto &I: instructions(F)) { +    if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { +      Value *PO = LI->getPointerOperand(); +      if (isDereferenceablePointer(PO, DL)) +        Deref.push_back(PO); +      if (isDereferenceableAndAlignedPointer(PO, LI->getAlignment(), DL)) +        DerefAndAligned.insert(PO); +    } +  } +  return false; +} + +void MemDerefPrinter::print(raw_ostream &OS, const Module *M) const { +  OS << "The following are dereferenceable:\n"; +  for (Value *V: Deref) { +    V->print(OS); +    if (DerefAndAligned.count(V)) +      OS << "\t(aligned)"; +    else +      OS << "\t(unaligned)"; +    OS << "\n\n"; +  } +} diff --git a/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp b/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp new file mode 100644 index 000000000000..686ad294378c --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemoryBuiltins.cpp @@ -0,0 +1,961 @@ +//===- MemoryBuiltins.cpp - Identify calls to memory builtins -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This family of functions identifies calls to builtin functions that allocate +// or free memory. 
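// Editor's note: a hedged usage sketch for the two query families this file
// provides (isAllocationFn() and isFreeCall() are implemented further down in
// this diff); the wrapper name is illustrative, not an LLVM API.
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/Value.h"
static bool toyIsHeapAllocOrFree(const llvm::Value *V,
                                 const llvm::TargetLibraryInfo *TLI) {
  return llvm::isAllocationFn(V, TLI) ||      // malloc/calloc/realloc/new/strdup
         llvm::isFreeCall(V, TLI) != nullptr; // free / operator delete
}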
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetFolder.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/Utils/Local.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "memory-builtins" + +enum AllocType : uint8_t { +  OpNewLike          = 1<<0, // allocates; never returns null +  MallocLike         = 1<<1 | OpNewLike, // allocates; may return null +  CallocLike         = 1<<2, // allocates + bzero +  ReallocLike        = 1<<3, // reallocates +  StrDupLike         = 1<<4, +  MallocOrCallocLike = MallocLike | CallocLike, +  AllocLike          = MallocLike | CallocLike | StrDupLike, +  AnyAlloc           = AllocLike | ReallocLike +}; + +struct AllocFnsTy { +  AllocType AllocTy; +  unsigned NumParams; +  // First and Second size parameters (or -1 if unused) +  int FstParam, SndParam; +}; + +// FIXME: certain users need more information. E.g., SimplifyLibCalls needs to +// know which functions are nounwind, noalias, nocapture parameters, etc. 
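// Editor's note: the AllocType flags above form a small subset lattice: MallocLike
// includes the OpNewLike bit, so a query for MallocLike also accepts operator-new
// style allocators, while CallocLike remains disjoint. The table below is filtered
// in getAllocationDataForFunction() by keeping an entry only when every bit of the
// entry's AllocTy is permitted by the query; a standalone sketch of that test:
#include <cstdint>
static bool toyAllocTypeMatches(uint8_t EntryBits, uint8_t QueryBits) {
  return (EntryBits & QueryBits) == EntryBits; // all entry bits allowed by query
}
// toyAllocTypeMatches(OpNewLike, MallocLike)  -> true
// toyAllocTypeMatches(CallocLike, MallocLike) -> false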
+static const std::pair<LibFunc, AllocFnsTy> AllocationFnData[] = { +  {LibFunc_malloc,              {MallocLike,  1, 0,  -1}}, +  {LibFunc_valloc,              {MallocLike,  1, 0,  -1}}, +  {LibFunc_Znwj,                {OpNewLike,   1, 0,  -1}}, // new(unsigned int) +  {LibFunc_ZnwjRKSt9nothrow_t,  {MallocLike,  2, 0,  -1}}, // new(unsigned int, nothrow) +  {LibFunc_ZnwjSt11align_val_t, {OpNewLike,   2, 0,  -1}}, // new(unsigned int, align_val_t) +  {LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t, // new(unsigned int, align_val_t, nothrow) +                                {MallocLike,  3, 0,  -1}}, +  {LibFunc_Znwm,                {OpNewLike,   1, 0,  -1}}, // new(unsigned long) +  {LibFunc_ZnwmRKSt9nothrow_t,  {MallocLike,  2, 0,  -1}}, // new(unsigned long, nothrow) +  {LibFunc_ZnwmSt11align_val_t, {OpNewLike,   2, 0,  -1}}, // new(unsigned long, align_val_t) +  {LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t, // new(unsigned long, align_val_t, nothrow) +                                {MallocLike,  3, 0,  -1}}, +  {LibFunc_Znaj,                {OpNewLike,   1, 0,  -1}}, // new[](unsigned int) +  {LibFunc_ZnajRKSt9nothrow_t,  {MallocLike,  2, 0,  -1}}, // new[](unsigned int, nothrow) +  {LibFunc_ZnajSt11align_val_t, {OpNewLike,   2, 0,  -1}}, // new[](unsigned int, align_val_t) +  {LibFunc_ZnajSt11align_val_tRKSt9nothrow_t, // new[](unsigned int, align_val_t, nothrow) +                                {MallocLike,  3, 0,  -1}}, +  {LibFunc_Znam,                {OpNewLike,   1, 0,  -1}}, // new[](unsigned long) +  {LibFunc_ZnamRKSt9nothrow_t,  {MallocLike,  2, 0,  -1}}, // new[](unsigned long, nothrow) +  {LibFunc_ZnamSt11align_val_t, {OpNewLike,   2, 0,  -1}}, // new[](unsigned long, align_val_t) +  {LibFunc_ZnamSt11align_val_tRKSt9nothrow_t, // new[](unsigned long, align_val_t, nothrow) +                                 {MallocLike,  3, 0,  -1}}, +  {LibFunc_msvc_new_int,         {OpNewLike,   1, 0,  -1}}, // new(unsigned int) +  {LibFunc_msvc_new_int_nothrow, {MallocLike,  2, 0,  -1}}, // new(unsigned int, nothrow) +  {LibFunc_msvc_new_longlong,         {OpNewLike,   1, 0,  -1}}, // new(unsigned long long) +  {LibFunc_msvc_new_longlong_nothrow, {MallocLike,  2, 0,  -1}}, // new(unsigned long long, nothrow) +  {LibFunc_msvc_new_array_int,         {OpNewLike,   1, 0,  -1}}, // new[](unsigned int) +  {LibFunc_msvc_new_array_int_nothrow, {MallocLike,  2, 0,  -1}}, // new[](unsigned int, nothrow) +  {LibFunc_msvc_new_array_longlong,         {OpNewLike,   1, 0,  -1}}, // new[](unsigned long long) +  {LibFunc_msvc_new_array_longlong_nothrow, {MallocLike,  2, 0,  -1}}, // new[](unsigned long long, nothrow) +  {LibFunc_calloc,              {CallocLike,  2, 0,   1}}, +  {LibFunc_realloc,             {ReallocLike, 2, 1,  -1}}, +  {LibFunc_reallocf,            {ReallocLike, 2, 1,  -1}}, +  {LibFunc_strdup,              {StrDupLike,  1, -1, -1}}, +  {LibFunc_strndup,             {StrDupLike,  2, 1,  -1}} +  // TODO: Handle "int posix_memalign(void **, size_t, size_t)" +}; + +static const Function *getCalledFunction(const Value *V, bool LookThroughBitCast, +                                         bool &IsNoBuiltin) { +  // Don't care about intrinsics in this case. 
+  if (isa<IntrinsicInst>(V)) +    return nullptr; + +  if (LookThroughBitCast) +    V = V->stripPointerCasts(); + +  ImmutableCallSite CS(V); +  if (!CS.getInstruction()) +    return nullptr; + +  IsNoBuiltin = CS.isNoBuiltin(); + +  if (const Function *Callee = CS.getCalledFunction()) +    return Callee; +  return nullptr; +} + +/// Returns the allocation data for the given value if it's either a call to a +/// known allocation function, or a call to a function with the allocsize +/// attribute. +static Optional<AllocFnsTy> +getAllocationDataForFunction(const Function *Callee, AllocType AllocTy, +                             const TargetLibraryInfo *TLI) { +  // Make sure that the function is available. +  StringRef FnName = Callee->getName(); +  LibFunc TLIFn; +  if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn)) +    return None; + +  const auto *Iter = find_if( +      AllocationFnData, [TLIFn](const std::pair<LibFunc, AllocFnsTy> &P) { +        return P.first == TLIFn; +      }); + +  if (Iter == std::end(AllocationFnData)) +    return None; + +  const AllocFnsTy *FnData = &Iter->second; +  if ((FnData->AllocTy & AllocTy) != FnData->AllocTy) +    return None; + +  // Check function prototype. +  int FstParam = FnData->FstParam; +  int SndParam = FnData->SndParam; +  FunctionType *FTy = Callee->getFunctionType(); + +  if (FTy->getReturnType() == Type::getInt8PtrTy(FTy->getContext()) && +      FTy->getNumParams() == FnData->NumParams && +      (FstParam < 0 || +       (FTy->getParamType(FstParam)->isIntegerTy(32) || +        FTy->getParamType(FstParam)->isIntegerTy(64))) && +      (SndParam < 0 || +       FTy->getParamType(SndParam)->isIntegerTy(32) || +       FTy->getParamType(SndParam)->isIntegerTy(64))) +    return *FnData; +  return None; +} + +static Optional<AllocFnsTy> getAllocationData(const Value *V, AllocType AllocTy, +                                              const TargetLibraryInfo *TLI, +                                              bool LookThroughBitCast = false) { +  bool IsNoBuiltinCall; +  if (const Function *Callee = +          getCalledFunction(V, LookThroughBitCast, IsNoBuiltinCall)) +    if (!IsNoBuiltinCall) +      return getAllocationDataForFunction(Callee, AllocTy, TLI); +  return None; +} + +static Optional<AllocFnsTy> getAllocationSize(const Value *V, +                                              const TargetLibraryInfo *TLI) { +  bool IsNoBuiltinCall; +  const Function *Callee = +      getCalledFunction(V, /*LookThroughBitCast=*/false, IsNoBuiltinCall); +  if (!Callee) +    return None; + +  // Prefer to use existing information over allocsize. This will give us an +  // accurate AllocTy. +  if (!IsNoBuiltinCall) +    if (Optional<AllocFnsTy> Data = +            getAllocationDataForFunction(Callee, AnyAlloc, TLI)) +      return Data; + +  Attribute Attr = Callee->getFnAttribute(Attribute::AllocSize); +  if (Attr == Attribute()) +    return None; + +  std::pair<unsigned, Optional<unsigned>> Args = Attr.getAllocSizeArgs(); + +  AllocFnsTy Result; +  // Because allocsize only tells us how many bytes are allocated, we're not +  // really allowed to assume anything, so we use MallocLike. +  Result.AllocTy = MallocLike; +  Result.NumParams = Callee->getNumOperands(); +  Result.FstParam = Args.first; +  Result.SndParam = Args.second.getValueOr(-1); +  return Result; +} + +static bool hasNoAliasAttr(const Value *V, bool LookThroughBitCast) { +  ImmutableCallSite CS(LookThroughBitCast ? 
V->stripPointerCasts() : V); +  return CS && CS.hasRetAttr(Attribute::NoAlias); +} + +/// Tests if a value is a call or invoke to a library function that +/// allocates or reallocates memory (either malloc, calloc, realloc, or strdup +/// like). +bool llvm::isAllocationFn(const Value *V, const TargetLibraryInfo *TLI, +                          bool LookThroughBitCast) { +  return getAllocationData(V, AnyAlloc, TLI, LookThroughBitCast).hasValue(); +} + +/// Tests if a value is a call or invoke to a function that returns a +/// NoAlias pointer (including malloc/calloc/realloc/strdup-like functions). +bool llvm::isNoAliasFn(const Value *V, const TargetLibraryInfo *TLI, +                       bool LookThroughBitCast) { +  // it's safe to consider realloc as noalias since accessing the original +  // pointer is undefined behavior +  return isAllocationFn(V, TLI, LookThroughBitCast) || +         hasNoAliasAttr(V, LookThroughBitCast); +} + +/// Tests if a value is a call or invoke to a library function that +/// allocates uninitialized memory (such as malloc). +bool llvm::isMallocLikeFn(const Value *V, const TargetLibraryInfo *TLI, +                          bool LookThroughBitCast) { +  return getAllocationData(V, MallocLike, TLI, LookThroughBitCast).hasValue(); +} + +/// Tests if a value is a call or invoke to a library function that +/// allocates zero-filled memory (such as calloc). +bool llvm::isCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI, +                          bool LookThroughBitCast) { +  return getAllocationData(V, CallocLike, TLI, LookThroughBitCast).hasValue(); +} + +/// Tests if a value is a call or invoke to a library function that +/// allocates memory similar to malloc or calloc. +bool llvm::isMallocOrCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI, +                                  bool LookThroughBitCast) { +  return getAllocationData(V, MallocOrCallocLike, TLI, +                           LookThroughBitCast).hasValue(); +} + +/// Tests if a value is a call or invoke to a library function that +/// allocates memory (either malloc, calloc, or strdup like). +bool llvm::isAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI, +                         bool LookThroughBitCast) { +  return getAllocationData(V, AllocLike, TLI, LookThroughBitCast).hasValue(); +} + +/// extractMallocCall - Returns the corresponding CallInst if the instruction +/// is a malloc call.  Since CallInst::CreateMalloc() only creates calls, we +/// ignore InvokeInst here. +const CallInst *llvm::extractMallocCall(const Value *I, +                                        const TargetLibraryInfo *TLI) { +  return isMallocLikeFn(I, TLI) ? dyn_cast<CallInst>(I) : nullptr; +} + +static Value *computeArraySize(const CallInst *CI, const DataLayout &DL, +                               const TargetLibraryInfo *TLI, +                               bool LookThroughSExt = false) { +  if (!CI) +    return nullptr; + +  // The size of the malloc's result type must be known to determine array size. +  Type *T = getMallocAllocatedType(CI, TLI); +  if (!T || !T->isSized()) +    return nullptr; + +  unsigned ElementSize = DL.getTypeAllocSize(T); +  if (StructType *ST = dyn_cast<StructType>(T)) +    ElementSize = DL.getStructLayout(ST)->getSizeInBytes(); + +  // If malloc call's arg can be determined to be a multiple of ElementSize, +  // return the multiple.  Otherwise, return NULL. 
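// Editor's note: a worked example of the multiple-of-ElementSize check described
// above, with constant sizes instead of ComputeMultiple(): malloc(24) bitcast to
// i64* (ElementSize 8) has array size 3, while malloc(20) is not an even multiple
// and yields no array size. The helper name and constant-only restriction are
// illustrative, not part of this file.
#include <cstdint>
static bool toyConstantMallocArraySize(uint64_t MallocBytes, uint64_t ElementSize,
                                       uint64_t &ArraySize) {
  if (ElementSize == 0 || MallocBytes % ElementSize != 0)
    return false;                        // not a provable multiple: give up
  ArraySize = MallocBytes / ElementSize;
  return true;
}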
+  Value *MallocArg = CI->getArgOperand(0); +  Value *Multiple = nullptr; +  if (ComputeMultiple(MallocArg, ElementSize, Multiple, LookThroughSExt)) +    return Multiple; + +  return nullptr; +} + +/// getMallocType - Returns the PointerType resulting from the malloc call. +/// The PointerType depends on the number of bitcast uses of the malloc call: +///   0: PointerType is the calls' return type. +///   1: PointerType is the bitcast's result type. +///  >1: Unique PointerType cannot be determined, return NULL. +PointerType *llvm::getMallocType(const CallInst *CI, +                                 const TargetLibraryInfo *TLI) { +  assert(isMallocLikeFn(CI, TLI) && "getMallocType and not malloc call"); + +  PointerType *MallocType = nullptr; +  unsigned NumOfBitCastUses = 0; + +  // Determine if CallInst has a bitcast use. +  for (Value::const_user_iterator UI = CI->user_begin(), E = CI->user_end(); +       UI != E;) +    if (const BitCastInst *BCI = dyn_cast<BitCastInst>(*UI++)) { +      MallocType = cast<PointerType>(BCI->getDestTy()); +      NumOfBitCastUses++; +    } + +  // Malloc call has 1 bitcast use, so type is the bitcast's destination type. +  if (NumOfBitCastUses == 1) +    return MallocType; + +  // Malloc call was not bitcast, so type is the malloc function's return type. +  if (NumOfBitCastUses == 0) +    return cast<PointerType>(CI->getType()); + +  // Type could not be determined. +  return nullptr; +} + +/// getMallocAllocatedType - Returns the Type allocated by malloc call. +/// The Type depends on the number of bitcast uses of the malloc call: +///   0: PointerType is the malloc calls' return type. +///   1: PointerType is the bitcast's result type. +///  >1: Unique PointerType cannot be determined, return NULL. +Type *llvm::getMallocAllocatedType(const CallInst *CI, +                                   const TargetLibraryInfo *TLI) { +  PointerType *PT = getMallocType(CI, TLI); +  return PT ? PT->getElementType() : nullptr; +} + +/// getMallocArraySize - Returns the array size of a malloc call.  If the +/// argument passed to malloc is a multiple of the size of the malloced type, +/// then return that multiple.  For non-array mallocs, the multiple is +/// constant 1.  Otherwise, return NULL for mallocs whose array size cannot be +/// determined. +Value *llvm::getMallocArraySize(CallInst *CI, const DataLayout &DL, +                                const TargetLibraryInfo *TLI, +                                bool LookThroughSExt) { +  assert(isMallocLikeFn(CI, TLI) && "getMallocArraySize and not malloc call"); +  return computeArraySize(CI, DL, TLI, LookThroughSExt); +} + +/// extractCallocCall - Returns the corresponding CallInst if the instruction +/// is a calloc call. +const CallInst *llvm::extractCallocCall(const Value *I, +                                        const TargetLibraryInfo *TLI) { +  return isCallocLikeFn(I, TLI) ? 
cast<CallInst>(I) : nullptr; +} + +/// isFreeCall - Returns non-null if the value is a call to the builtin free() +const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) { +  bool IsNoBuiltinCall; +  const Function *Callee = +      getCalledFunction(I, /*LookThroughBitCast=*/false, IsNoBuiltinCall); +  if (Callee == nullptr || IsNoBuiltinCall) +    return nullptr; + +  StringRef FnName = Callee->getName(); +  LibFunc TLIFn; +  if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn)) +    return nullptr; + +  unsigned ExpectedNumParams; +  if (TLIFn == LibFunc_free || +      TLIFn == LibFunc_ZdlPv || // operator delete(void*) +      TLIFn == LibFunc_ZdaPv || // operator delete[](void*) +      TLIFn == LibFunc_msvc_delete_ptr32 || // operator delete(void*) +      TLIFn == LibFunc_msvc_delete_ptr64 || // operator delete(void*) +      TLIFn == LibFunc_msvc_delete_array_ptr32 || // operator delete[](void*) +      TLIFn == LibFunc_msvc_delete_array_ptr64)   // operator delete[](void*) +    ExpectedNumParams = 1; +  else if (TLIFn == LibFunc_ZdlPvj ||              // delete(void*, uint) +           TLIFn == LibFunc_ZdlPvm ||              // delete(void*, ulong) +           TLIFn == LibFunc_ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) +           TLIFn == LibFunc_ZdlPvSt11align_val_t || // delete(void*, align_val_t) +           TLIFn == LibFunc_ZdaPvj ||              // delete[](void*, uint) +           TLIFn == LibFunc_ZdaPvm ||              // delete[](void*, ulong) +           TLIFn == LibFunc_ZdaPvRKSt9nothrow_t || // delete[](void*, nothrow) +           TLIFn == LibFunc_ZdaPvSt11align_val_t || // delete[](void*, align_val_t) +           TLIFn == LibFunc_msvc_delete_ptr32_int ||      // delete(void*, uint) +           TLIFn == LibFunc_msvc_delete_ptr64_longlong || // delete(void*, ulonglong) +           TLIFn == LibFunc_msvc_delete_ptr32_nothrow || // delete(void*, nothrow) +           TLIFn == LibFunc_msvc_delete_ptr64_nothrow || // delete(void*, nothrow) +           TLIFn == LibFunc_msvc_delete_array_ptr32_int ||      // delete[](void*, uint) +           TLIFn == LibFunc_msvc_delete_array_ptr64_longlong || // delete[](void*, ulonglong) +           TLIFn == LibFunc_msvc_delete_array_ptr32_nothrow || // delete[](void*, nothrow) +           TLIFn == LibFunc_msvc_delete_array_ptr64_nothrow)   // delete[](void*, nothrow) +    ExpectedNumParams = 2; +  else if (TLIFn == LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t || // delete(void*, align_val_t, nothrow) +           TLIFn == LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t) // delete[](void*, align_val_t, nothrow) +    ExpectedNumParams = 3; +  else +    return nullptr; + +  // Check free prototype. +  // FIXME: workaround for PR5130, this will be obsolete when a nobuiltin +  // attribute will exist. +  FunctionType *FTy = Callee->getFunctionType(); +  if (!FTy->getReturnType()->isVoidTy()) +    return nullptr; +  if (FTy->getNumParams() != ExpectedNumParams) +    return nullptr; +  if (FTy->getParamType(0) != Type::getInt8PtrTy(Callee->getContext())) +    return nullptr; + +  return dyn_cast<CallInst>(I); +} + +//===----------------------------------------------------------------------===// +//  Utility functions to compute size of objects. +// +static APInt getSizeWithOverflow(const SizeOffsetType &Data) { +  if (Data.second.isNegative() || Data.first.ult(Data.second)) +    return APInt(Data.first.getBitWidth(), 0); +  return Data.first - Data.second; +} + +/// Compute the size of the object pointed by Ptr. 
Returns true and the +/// object size in Size if successful, and false otherwise. +/// If RoundToAlign is true, then Size is rounded up to the alignment of +/// allocas, byval arguments, and global variables. +bool llvm::getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, +                         const TargetLibraryInfo *TLI, ObjectSizeOpts Opts) { +  ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), Opts); +  SizeOffsetType Data = Visitor.compute(const_cast<Value*>(Ptr)); +  if (!Visitor.bothKnown(Data)) +    return false; + +  Size = getSizeWithOverflow(Data).getZExtValue(); +  return true; +} + +ConstantInt *llvm::lowerObjectSizeCall(IntrinsicInst *ObjectSize, +                                       const DataLayout &DL, +                                       const TargetLibraryInfo *TLI, +                                       bool MustSucceed) { +  assert(ObjectSize->getIntrinsicID() == Intrinsic::objectsize && +         "ObjectSize must be a call to llvm.objectsize!"); + +  bool MaxVal = cast<ConstantInt>(ObjectSize->getArgOperand(1))->isZero(); +  ObjectSizeOpts EvalOptions; +  // Unless we have to fold this to something, try to be as accurate as +  // possible. +  if (MustSucceed) +    EvalOptions.EvalMode = +        MaxVal ? ObjectSizeOpts::Mode::Max : ObjectSizeOpts::Mode::Min; +  else +    EvalOptions.EvalMode = ObjectSizeOpts::Mode::Exact; + +  EvalOptions.NullIsUnknownSize = +      cast<ConstantInt>(ObjectSize->getArgOperand(2))->isOne(); + +  // FIXME: Does it make sense to just return a failure value if the size won't +  // fit in the output and `!MustSucceed`? +  uint64_t Size; +  auto *ResultType = cast<IntegerType>(ObjectSize->getType()); +  if (getObjectSize(ObjectSize->getArgOperand(0), Size, DL, TLI, EvalOptions) && +      isUIntN(ResultType->getBitWidth(), Size)) +    return ConstantInt::get(ResultType, Size); + +  if (!MustSucceed) +    return nullptr; + +  return ConstantInt::get(ResultType, MaxVal ? -1ULL : 0); +} + +STATISTIC(ObjectVisitorArgument, +          "Number of arguments with unsolved size and offset"); +STATISTIC(ObjectVisitorLoad, +          "Number of load instructions with unsolved size and offset"); + +APInt ObjectSizeOffsetVisitor::align(APInt Size, uint64_t Align) { +  if (Options.RoundToAlign && Align) +    return APInt(IntTyBits, alignTo(Size.getZExtValue(), Align)); +  return Size; +} + +ObjectSizeOffsetVisitor::ObjectSizeOffsetVisitor(const DataLayout &DL, +                                                 const TargetLibraryInfo *TLI, +                                                 LLVMContext &Context, +                                                 ObjectSizeOpts Options) +    : DL(DL), TLI(TLI), Options(Options) { +  // Pointer size must be rechecked for each object visited since it could have +  // a different address space. +} + +SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) { +  IntTyBits = DL.getPointerTypeSizeInBits(V->getType()); +  Zero = APInt::getNullValue(IntTyBits); + +  V = V->stripPointerCasts(); +  if (Instruction *I = dyn_cast<Instruction>(V)) { +    // If we have already seen this instruction, bail out. Cycles can happen in +    // unreachable code after constant propagation. 
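// Editor's note: a plain-integer sketch of getSizeWithOverflow() above, which
// getObjectSize() and the Min/Max selection logic below rely on: the usable size
// is "object size minus offset", clamped to zero when the offset is negative or
// points past the end. The real code uses APInt; the values here are examples.
#include <cstdint>
static uint64_t toySizeWithOverflow(uint64_t ObjectSize, int64_t Offset) {
  if (Offset < 0 || static_cast<uint64_t>(Offset) > ObjectSize)
    return 0;                            // nothing provably accessible
  return ObjectSize - static_cast<uint64_t>(Offset);
}
// toySizeWithOverflow(16, 4) == 12; toySizeWithOverflow(16, 24) == 0.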
+    if (!SeenInsts.insert(I).second) +      return unknown(); + +    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) +      return visitGEPOperator(*GEP); +    return visit(*I); +  } +  if (Argument *A = dyn_cast<Argument>(V)) +    return visitArgument(*A); +  if (ConstantPointerNull *P = dyn_cast<ConstantPointerNull>(V)) +    return visitConstantPointerNull(*P); +  if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) +    return visitGlobalAlias(*GA); +  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) +    return visitGlobalVariable(*GV); +  if (UndefValue *UV = dyn_cast<UndefValue>(V)) +    return visitUndefValue(*UV); +  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { +    if (CE->getOpcode() == Instruction::IntToPtr) +      return unknown(); // clueless +    if (CE->getOpcode() == Instruction::GetElementPtr) +      return visitGEPOperator(cast<GEPOperator>(*CE)); +  } + +  LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor::compute() unhandled value: " +                    << *V << '\n'); +  return unknown(); +} + +/// When we're compiling N-bit code, and the user uses parameters that are +/// greater than N bits (e.g. uint64_t on a 32-bit build), we can run into +/// trouble with APInt size issues. This function handles resizing + overflow +/// checks for us. Check and zext or trunc \p I depending on IntTyBits and +/// I's value. +bool ObjectSizeOffsetVisitor::CheckedZextOrTrunc(APInt &I) { +  // More bits than we can handle. Checking the bit width isn't necessary, but +  // it's faster than checking active bits, and should give `false` in the +  // vast majority of cases. +  if (I.getBitWidth() > IntTyBits && I.getActiveBits() > IntTyBits) +    return false; +  if (I.getBitWidth() != IntTyBits) +    I = I.zextOrTrunc(IntTyBits); +  return true; +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitAllocaInst(AllocaInst &I) { +  if (!I.getAllocatedType()->isSized()) +    return unknown(); + +  APInt Size(IntTyBits, DL.getTypeAllocSize(I.getAllocatedType())); +  if (!I.isArrayAllocation()) +    return std::make_pair(align(Size, I.getAlignment()), Zero); + +  Value *ArraySize = I.getArraySize(); +  if (const ConstantInt *C = dyn_cast<ConstantInt>(ArraySize)) { +    APInt NumElems = C->getValue(); +    if (!CheckedZextOrTrunc(NumElems)) +      return unknown(); + +    bool Overflow; +    Size = Size.umul_ov(NumElems, Overflow); +    return Overflow ? unknown() : std::make_pair(align(Size, I.getAlignment()), +                                                 Zero); +  } +  return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitArgument(Argument &A) { +  // No interprocedural analysis is done at the moment. +  if (!A.hasByValOrInAllocaAttr()) { +    ++ObjectVisitorArgument; +    return unknown(); +  } +  PointerType *PT = cast<PointerType>(A.getType()); +  APInt Size(IntTyBits, DL.getTypeAllocSize(PT->getElementType())); +  return std::make_pair(align(Size, A.getParamAlignment()), Zero); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) { +  Optional<AllocFnsTy> FnData = getAllocationSize(CS.getInstruction(), TLI); +  if (!FnData) +    return unknown(); + +  // Handle strdup-like functions separately. +  if (FnData->AllocTy == StrDupLike) { +    APInt Size(IntTyBits, GetStringLength(CS.getArgument(0))); +    if (!Size) +      return unknown(); + +    // Strndup limits strlen. 
+    if (FnData->FstParam > 0) { +      ConstantInt *Arg = +          dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam)); +      if (!Arg) +        return unknown(); + +      APInt MaxSize = Arg->getValue().zextOrSelf(IntTyBits); +      if (Size.ugt(MaxSize)) +        Size = MaxSize + 1; +    } +    return std::make_pair(Size, Zero); +  } + +  ConstantInt *Arg = dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam)); +  if (!Arg) +    return unknown(); + +  APInt Size = Arg->getValue(); +  if (!CheckedZextOrTrunc(Size)) +    return unknown(); + +  // Size is determined by just 1 parameter. +  if (FnData->SndParam < 0) +    return std::make_pair(Size, Zero); + +  Arg = dyn_cast<ConstantInt>(CS.getArgument(FnData->SndParam)); +  if (!Arg) +    return unknown(); + +  APInt NumElems = Arg->getValue(); +  if (!CheckedZextOrTrunc(NumElems)) +    return unknown(); + +  bool Overflow; +  Size = Size.umul_ov(NumElems, Overflow); +  return Overflow ? unknown() : std::make_pair(Size, Zero); + +  // TODO: handle more standard functions (+ wchar cousins): +  // - strdup / strndup +  // - strcpy / strncpy +  // - strcat / strncat +  // - memcpy / memmove +  // - strcat / strncat +  // - memset +} + +SizeOffsetType +ObjectSizeOffsetVisitor::visitConstantPointerNull(ConstantPointerNull& CPN) { +  // If null is unknown, there's nothing we can do. Additionally, non-zero +  // address spaces can make use of null, so we don't presume to know anything +  // about that. +  // +  // TODO: How should this work with address space casts? We currently just drop +  // them on the floor, but it's unclear what we should do when a NULL from +  // addrspace(1) gets casted to addrspace(0) (or vice-versa). +  if (Options.NullIsUnknownSize || CPN.getType()->getAddressSpace()) +    return unknown(); +  return std::make_pair(Zero, Zero); +} + +SizeOffsetType +ObjectSizeOffsetVisitor::visitExtractElementInst(ExtractElementInst&) { +  return unknown(); +} + +SizeOffsetType +ObjectSizeOffsetVisitor::visitExtractValueInst(ExtractValueInst&) { +  // Easy cases were already folded by previous passes. +  return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitGEPOperator(GEPOperator &GEP) { +  SizeOffsetType PtrData = compute(GEP.getPointerOperand()); +  APInt Offset(IntTyBits, 0); +  if (!bothKnown(PtrData) || !GEP.accumulateConstantOffset(DL, Offset)) +    return unknown(); + +  return std::make_pair(PtrData.first, PtrData.second + Offset); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalAlias(GlobalAlias &GA) { +  if (GA.isInterposable()) +    return unknown(); +  return compute(GA.getAliasee()); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalVariable(GlobalVariable &GV){ +  if (!GV.hasDefinitiveInitializer()) +    return unknown(); + +  APInt Size(IntTyBits, DL.getTypeAllocSize(GV.getType()->getElementType())); +  return std::make_pair(align(Size, GV.getAlignment()), Zero); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitIntToPtrInst(IntToPtrInst&) { +  // clueless +  return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitLoadInst(LoadInst&) { +  ++ObjectVisitorLoad; +  return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode&) { +  // too complex to analyze statically. 
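// Editor's note: for calloc-style allocators, visitCallSite() above multiplies the
// two constant size arguments and gives up when the unsigned multiply overflows
// (APInt::umul_ov). A portable plain-integer sketch of the same guard
// (hypothetical helper, constant arguments only):
#include <cstdint>
static bool toyCallocBytes(uint64_t Count, uint64_t ElemSize, uint64_t &Bytes) {
  if (ElemSize != 0 && Count > UINT64_MAX / ElemSize)
    return false;                        // overflow: size is unknown
  Bytes = Count * ElemSize;
  return true;
}
// toyCallocBytes(8, 4, Bytes) sets Bytes = 32, while a product that exceeds
// 64 bits reports failure, matching the unknown() result above.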
+  return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) { +  SizeOffsetType TrueSide  = compute(I.getTrueValue()); +  SizeOffsetType FalseSide = compute(I.getFalseValue()); +  if (bothKnown(TrueSide) && bothKnown(FalseSide)) { +    if (TrueSide == FalseSide) { +        return TrueSide; +    } + +    APInt TrueResult = getSizeWithOverflow(TrueSide); +    APInt FalseResult = getSizeWithOverflow(FalseSide); + +    if (TrueResult == FalseResult) { +      return TrueSide; +    } +    if (Options.EvalMode == ObjectSizeOpts::Mode::Min) { +      if (TrueResult.slt(FalseResult)) +        return TrueSide; +      return FalseSide; +    } +    if (Options.EvalMode == ObjectSizeOpts::Mode::Max) { +      if (TrueResult.sgt(FalseResult)) +        return TrueSide; +      return FalseSide; +    } +  } +  return unknown(); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitUndefValue(UndefValue&) { +  return std::make_pair(Zero, Zero); +} + +SizeOffsetType ObjectSizeOffsetVisitor::visitInstruction(Instruction &I) { +  LLVM_DEBUG(dbgs() << "ObjectSizeOffsetVisitor unknown instruction:" << I +                    << '\n'); +  return unknown(); +} + +ObjectSizeOffsetEvaluator::ObjectSizeOffsetEvaluator( +    const DataLayout &DL, const TargetLibraryInfo *TLI, LLVMContext &Context, +    bool RoundToAlign) +    : DL(DL), TLI(TLI), Context(Context), Builder(Context, TargetFolder(DL)), +      RoundToAlign(RoundToAlign) { +  // IntTy and Zero must be set for each compute() since the address space may +  // be different for later objects. +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute(Value *V) { +  // XXX - Are vectors of pointers possible here? +  IntTy = cast<IntegerType>(DL.getIntPtrType(V->getType())); +  Zero = ConstantInt::get(IntTy, 0); + +  SizeOffsetEvalType Result = compute_(V); + +  if (!bothKnown(Result)) { +    // Erase everything that was computed in this iteration from the cache, so +    // that no dangling references are left behind. We could be a bit smarter if +    // we kept a dependency graph. It's probably not worth the complexity. +    for (const Value *SeenVal : SeenVals) { +      CacheMapTy::iterator CacheIt = CacheMap.find(SeenVal); +      // non-computable results can be safely cached +      if (CacheIt != CacheMap.end() && anyKnown(CacheIt->second)) +        CacheMap.erase(CacheIt); +    } +  } + +  SeenVals.clear(); +  return Result; +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { +  ObjectSizeOpts ObjSizeOptions; +  ObjSizeOptions.RoundToAlign = RoundToAlign; + +  ObjectSizeOffsetVisitor Visitor(DL, TLI, Context, ObjSizeOptions); +  SizeOffsetType Const = Visitor.compute(V); +  if (Visitor.bothKnown(Const)) +    return std::make_pair(ConstantInt::get(Context, Const.first), +                          ConstantInt::get(Context, Const.second)); + +  V = V->stripPointerCasts(); + +  // Check cache. +  CacheMapTy::iterator CacheIt = CacheMap.find(V); +  if (CacheIt != CacheMap.end()) +    return CacheIt->second; + +  // Always generate code immediately before the instruction being +  // processed, so that the generated code dominates the same BBs. +  BuilderTy::InsertPointGuard Guard(Builder); +  if (Instruction *I = dyn_cast<Instruction>(V)) +    Builder.SetInsertPoint(I); + +  // Now compute the size and offset. +  SizeOffsetEvalType Result; + +  // Record the pointers that were handled in this run, so that they can be +  // cleaned later if something fails. 
We also use this set to break cycles that +  // can occur in dead code. +  if (!SeenVals.insert(V).second) { +    Result = unknown(); +  } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { +    Result = visitGEPOperator(*GEP); +  } else if (Instruction *I = dyn_cast<Instruction>(V)) { +    Result = visit(*I); +  } else if (isa<Argument>(V) || +             (isa<ConstantExpr>(V) && +              cast<ConstantExpr>(V)->getOpcode() == Instruction::IntToPtr) || +             isa<GlobalAlias>(V) || +             isa<GlobalVariable>(V)) { +    // Ignore values where we cannot do more than ObjectSizeVisitor. +    Result = unknown(); +  } else { +    LLVM_DEBUG( +        dbgs() << "ObjectSizeOffsetEvaluator::compute() unhandled value: " << *V +               << '\n'); +    Result = unknown(); +  } + +  // Don't reuse CacheIt since it may be invalid at this point. +  CacheMap[V] = Result; +  return Result; +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitAllocaInst(AllocaInst &I) { +  if (!I.getAllocatedType()->isSized()) +    return unknown(); + +  // must be a VLA +  assert(I.isArrayAllocation()); +  Value *ArraySize = I.getArraySize(); +  Value *Size = ConstantInt::get(ArraySize->getType(), +                                 DL.getTypeAllocSize(I.getAllocatedType())); +  Size = Builder.CreateMul(Size, ArraySize); +  return std::make_pair(Size, Zero); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitCallSite(CallSite CS) { +  Optional<AllocFnsTy> FnData = getAllocationSize(CS.getInstruction(), TLI); +  if (!FnData) +    return unknown(); + +  // Handle strdup-like functions separately. +  if (FnData->AllocTy == StrDupLike) { +    // TODO +    return unknown(); +  } + +  Value *FirstArg = CS.getArgument(FnData->FstParam); +  FirstArg = Builder.CreateZExt(FirstArg, IntTy); +  if (FnData->SndParam < 0) +    return std::make_pair(FirstArg, Zero); + +  Value *SecondArg = CS.getArgument(FnData->SndParam); +  SecondArg = Builder.CreateZExt(SecondArg, IntTy); +  Value *Size = Builder.CreateMul(FirstArg, SecondArg); +  return std::make_pair(Size, Zero); + +  // TODO: handle more standard functions (+ wchar cousins): +  // - strdup / strndup +  // - strcpy / strncpy +  // - strcat / strncat +  // - memcpy / memmove +  // - strcat / strncat +  // - memset +} + +SizeOffsetEvalType +ObjectSizeOffsetEvaluator::visitExtractElementInst(ExtractElementInst&) { +  return unknown(); +} + +SizeOffsetEvalType +ObjectSizeOffsetEvaluator::visitExtractValueInst(ExtractValueInst&) { +  return unknown(); +} + +SizeOffsetEvalType +ObjectSizeOffsetEvaluator::visitGEPOperator(GEPOperator &GEP) { +  SizeOffsetEvalType PtrData = compute_(GEP.getPointerOperand()); +  if (!bothKnown(PtrData)) +    return unknown(); + +  Value *Offset = EmitGEPOffset(&Builder, DL, &GEP, /*NoAssumptions=*/true); +  Offset = Builder.CreateAdd(PtrData.second, Offset); +  return std::make_pair(PtrData.first, Offset); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitIntToPtrInst(IntToPtrInst&) { +  // clueless +  return unknown(); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitLoadInst(LoadInst&) { +  return unknown(); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitPHINode(PHINode &PHI) { +  // Create 2 PHIs: one for size and another for offset. +  PHINode *SizePHI   = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues()); +  PHINode *OffsetPHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues()); + +  // Insert right away in the cache to handle recursive PHIs. 
+  CacheMap[&PHI] = std::make_pair(SizePHI, OffsetPHI); + +  // Compute offset/size for each PHI incoming pointer. +  for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) { +    Builder.SetInsertPoint(&*PHI.getIncomingBlock(i)->getFirstInsertionPt()); +    SizeOffsetEvalType EdgeData = compute_(PHI.getIncomingValue(i)); + +    if (!bothKnown(EdgeData)) { +      OffsetPHI->replaceAllUsesWith(UndefValue::get(IntTy)); +      OffsetPHI->eraseFromParent(); +      SizePHI->replaceAllUsesWith(UndefValue::get(IntTy)); +      SizePHI->eraseFromParent(); +      return unknown(); +    } +    SizePHI->addIncoming(EdgeData.first, PHI.getIncomingBlock(i)); +    OffsetPHI->addIncoming(EdgeData.second, PHI.getIncomingBlock(i)); +  } + +  Value *Size = SizePHI, *Offset = OffsetPHI, *Tmp; +  if ((Tmp = SizePHI->hasConstantValue())) { +    Size = Tmp; +    SizePHI->replaceAllUsesWith(Size); +    SizePHI->eraseFromParent(); +  } +  if ((Tmp = OffsetPHI->hasConstantValue())) { +    Offset = Tmp; +    OffsetPHI->replaceAllUsesWith(Offset); +    OffsetPHI->eraseFromParent(); +  } +  return std::make_pair(Size, Offset); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitSelectInst(SelectInst &I) { +  SizeOffsetEvalType TrueSide  = compute_(I.getTrueValue()); +  SizeOffsetEvalType FalseSide = compute_(I.getFalseValue()); + +  if (!bothKnown(TrueSide) || !bothKnown(FalseSide)) +    return unknown(); +  if (TrueSide == FalseSide) +    return TrueSide; + +  Value *Size = Builder.CreateSelect(I.getCondition(), TrueSide.first, +                                     FalseSide.first); +  Value *Offset = Builder.CreateSelect(I.getCondition(), TrueSide.second, +                                       FalseSide.second); +  return std::make_pair(Size, Offset); +} + +SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitInstruction(Instruction &I) { +  LLVM_DEBUG(dbgs() << "ObjectSizeOffsetEvaluator unknown instruction:" << I +                    << '\n'); +  return unknown(); +} diff --git a/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp b/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp new file mode 100644 index 000000000000..feae53c54ecb --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -0,0 +1,1805 @@ +//===- MemoryDependenceAnalysis.cpp - Mem Deps Implementation -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements an analysis that determines, for a given memory +// operation, what preceding memory operations it depends on.  It builds on +// alias analysis information, and tries to provide a lazy, caching interface to +// a common kind of alias information query. 
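// Editor's note: a hedged sketch of the query pattern this analysis serves (the
// same local query MemDepPrinter::runOnFunction() issues earlier in this diff):
// ask for the local dependency of a memory instruction and branch on the result
// kind. The helper name is illustrative, not part of this file.
#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/IR/Instruction.h"
static const llvm::Instruction *
toyLocalDepOrNull(llvm::MemoryDependenceResults &MDA, llvm::Instruction *I) {
  llvm::MemDepResult Res = MDA.getDependency(I);
  if (Res.isDef() || Res.isClobber())
    return Res.getInst();  // the defining/clobbering instruction in this block
  return nullptr;          // non-local or unknown: needs the non-local queries
}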
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/PhiValues.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PredIteratorCache.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "memdep" + +STATISTIC(NumCacheNonLocal, "Number of fully cached non-local responses"); +STATISTIC(NumCacheDirtyNonLocal, "Number of dirty cached non-local responses"); +STATISTIC(NumUncacheNonLocal, "Number of uncached non-local responses"); + +STATISTIC(NumCacheNonLocalPtr, +          "Number of fully cached non-local ptr responses"); +STATISTIC(NumCacheDirtyNonLocalPtr, +          "Number of cached, but dirty, non-local ptr responses"); +STATISTIC(NumUncacheNonLocalPtr, "Number of uncached non-local ptr responses"); +STATISTIC(NumCacheCompleteNonLocalPtr, +          "Number of block queries that were completely cached"); + +// Limit for the number of instructions to scan in a block. + +static cl::opt<unsigned> BlockScanLimit( +    "memdep-block-scan-limit", cl::Hidden, cl::init(100), +    cl::desc("The number of instructions to scan in a block in memory " +             "dependency analysis (default = 100)")); + +static cl::opt<unsigned> +    BlockNumberLimit("memdep-block-number-limit", cl::Hidden, cl::init(1000), +                     cl::desc("The number of blocks to scan during memory " +                              "dependency analysis (default = 1000)")); + +// Limit on the number of memdep results to process. +static const unsigned int NumResultsLimit = 100; + +/// This is a helper function that removes Val from 'Inst's set in ReverseMap. +/// +/// If the set becomes empty, remove Inst's entry. 
+template <typename KeyTy> +static void +RemoveFromReverseMap(DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>> &ReverseMap, +                     Instruction *Inst, KeyTy Val) { +  typename DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>>::iterator InstIt = +      ReverseMap.find(Inst); +  assert(InstIt != ReverseMap.end() && "Reverse map out of sync?"); +  bool Found = InstIt->second.erase(Val); +  assert(Found && "Invalid reverse map!"); +  (void)Found; +  if (InstIt->second.empty()) +    ReverseMap.erase(InstIt); +} + +/// If the given instruction references a specific memory location, fill in Loc +/// with the details, otherwise set Loc.Ptr to null. +/// +/// Returns a ModRefInfo value describing the general behavior of the +/// instruction. +static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, +                              const TargetLibraryInfo &TLI) { +  if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { +    if (LI->isUnordered()) { +      Loc = MemoryLocation::get(LI); +      return ModRefInfo::Ref; +    } +    if (LI->getOrdering() == AtomicOrdering::Monotonic) { +      Loc = MemoryLocation::get(LI); +      return ModRefInfo::ModRef; +    } +    Loc = MemoryLocation(); +    return ModRefInfo::ModRef; +  } + +  if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { +    if (SI->isUnordered()) { +      Loc = MemoryLocation::get(SI); +      return ModRefInfo::Mod; +    } +    if (SI->getOrdering() == AtomicOrdering::Monotonic) { +      Loc = MemoryLocation::get(SI); +      return ModRefInfo::ModRef; +    } +    Loc = MemoryLocation(); +    return ModRefInfo::ModRef; +  } + +  if (const VAArgInst *V = dyn_cast<VAArgInst>(Inst)) { +    Loc = MemoryLocation::get(V); +    return ModRefInfo::ModRef; +  } + +  if (const CallInst *CI = isFreeCall(Inst, &TLI)) { +    // calls to free() deallocate the entire structure +    Loc = MemoryLocation(CI->getArgOperand(0)); +    return ModRefInfo::Mod; +  } + +  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { +    switch (II->getIntrinsicID()) { +    case Intrinsic::lifetime_start: +    case Intrinsic::lifetime_end: +    case Intrinsic::invariant_start: +      Loc = MemoryLocation::getForArgument(II, 1, TLI); +      // These intrinsics don't really modify the memory, but returning Mod +      // will allow them to be handled conservatively. +      return ModRefInfo::Mod; +    case Intrinsic::invariant_end: +      Loc = MemoryLocation::getForArgument(II, 2, TLI); +      // These intrinsics don't really modify the memory, but returning Mod +      // will allow them to be handled conservatively. +      return ModRefInfo::Mod; +    default: +      break; +    } +  } + +  // Otherwise, just do the coarse-grained thing that always works. +  if (Inst->mayWriteToMemory()) +    return ModRefInfo::ModRef; +  if (Inst->mayReadFromMemory()) +    return ModRefInfo::Ref; +  return ModRefInfo::NoModRef; +} + +/// Private helper for finding the local dependencies of a call site. +MemDepResult MemoryDependenceResults::getCallSiteDependencyFrom( +    CallSite CS, bool isReadOnlyCall, BasicBlock::iterator ScanIt, +    BasicBlock *BB) { +  unsigned Limit = BlockScanLimit; + +  // Walk backwards through the block, looking for dependencies. 
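+  // For example (illustrative IR, assuming @f may access %p), querying the
+  // call below reaches the store first; AA reports that the call mod/refs the
+  // stored-to location, so the store is returned as a clobber:
+  //   store i32 1, i32* %p
+  //   call void @f(i32* %p)    ; <-- CS being queried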
+  while (ScanIt != BB->begin()) { +    Instruction *Inst = &*--ScanIt; +    // Debug intrinsics don't cause dependences and should not affect Limit +    if (isa<DbgInfoIntrinsic>(Inst)) +      continue; + +    // Limit the amount of scanning we do so we don't end up with quadratic +    // running time on extreme testcases. +    --Limit; +    if (!Limit) +      return MemDepResult::getUnknown(); + +    // If this inst is a memory op, get the pointer it accessed +    MemoryLocation Loc; +    ModRefInfo MR = GetLocation(Inst, Loc, TLI); +    if (Loc.Ptr) { +      // A simple instruction. +      if (isModOrRefSet(AA.getModRefInfo(CS, Loc))) +        return MemDepResult::getClobber(Inst); +      continue; +    } + +    if (auto InstCS = CallSite(Inst)) { +      // If these two calls do not interfere, look past it. +      if (isNoModRef(AA.getModRefInfo(CS, InstCS))) { +        // If the two calls are the same, return InstCS as a Def, so that +        // CS can be found redundant and eliminated. +        if (isReadOnlyCall && !isModSet(MR) && +            CS.getInstruction()->isIdenticalToWhenDefined(Inst)) +          return MemDepResult::getDef(Inst); + +        // Otherwise if the two calls don't interact (e.g. InstCS is readnone) +        // keep scanning. +        continue; +      } else +        return MemDepResult::getClobber(Inst); +    } + +    // If we could not obtain a pointer for the instruction and the instruction +    // touches memory then assume that this is a dependency. +    if (isModOrRefSet(MR)) +      return MemDepResult::getClobber(Inst); +  } + +  // No dependence found.  If this is the entry block of the function, it is +  // unknown, otherwise it is non-local. +  if (BB != &BB->getParent()->getEntryBlock()) +    return MemDepResult::getNonLocal(); +  return MemDepResult::getNonFuncLocal(); +} + +unsigned MemoryDependenceResults::getLoadLoadClobberFullWidthSize( +    const Value *MemLocBase, int64_t MemLocOffs, unsigned MemLocSize, +    const LoadInst *LI) { +  // We can only extend simple integer loads. +  if (!isa<IntegerType>(LI->getType()) || !LI->isSimple()) +    return 0; + +  // Load widening is hostile to ThreadSanitizer: it may cause false positives +  // or make the reports more cryptic (access sizes are wrong). +  if (LI->getParent()->getParent()->hasFnAttribute(Attribute::SanitizeThread)) +    return 0; + +  const DataLayout &DL = LI->getModule()->getDataLayout(); + +  // Get the base of this load. +  int64_t LIOffs = 0; +  const Value *LIBase = +      GetPointerBaseWithConstantOffset(LI->getPointerOperand(), LIOffs, DL); + +  // If the two pointers are not based on the same pointer, we can't tell that +  // they are related. +  if (LIBase != MemLocBase) +    return 0; + +  // Okay, the two values are based on the same pointer, but returned as +  // no-alias.  This happens when we have things like two byte loads at "P+1" +  // and "P+3".  Check to see if increasing the size of the "LI" load up to its +  // alignment (or the largest native integer type) will allow us to load all +  // the bits required by MemLoc. + +  // If MemLoc is before LI, then no widening of LI will help us out. +  if (MemLocOffs < LIOffs) +    return 0; + +  // Get the alignment of the load in bytes.  We assume that it is safe to load +  // any legal integer up to this size without a problem.  For example, if we're +  // looking at an i8 load on x86-32 that is known 1024 byte aligned, we can +  // widen it up to an i32 load.  If it is known 2-byte aligned, we can widen it +  // to i16. 
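+  // A worked example (illustrative numbers, assuming the target has legal
+  // 16- and 32-bit integers and no sanitizers are enabled): for a
+  // 4-byte-aligned i8 load with LIOffs == 0 and a MemLoc with
+  // MemLocOffs == 2 and MemLocSize == 1, widening the load to 4 bytes covers
+  // bytes [0, 4), which includes MemLoc, so we return 4.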
+  unsigned LoadAlign = LI->getAlignment();
+
+  int64_t MemLocEnd = MemLocOffs + MemLocSize;
+
+  // If no amount of rounding up will let MemLoc fit into LI, then bail out.
+  if (LIOffs + LoadAlign < MemLocEnd)
+    return 0;
+
+  // This is the size of the load to try.  Start with the next larger power of
+  // two.
+  unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits() / 8U;
+  NewLoadByteSize = NextPowerOf2(NewLoadByteSize);
+
+  while (true) {
+    // If this load size is bigger than our known alignment or would not fit
+    // into a native integer register, then we fail.
+    if (NewLoadByteSize > LoadAlign ||
+        !DL.fitsInLegalInteger(NewLoadByteSize * 8))
+      return 0;
+
+    if (LIOffs + NewLoadByteSize > MemLocEnd &&
+        (LI->getParent()->getParent()->hasFnAttribute(
+             Attribute::SanitizeAddress) ||
+         LI->getParent()->getParent()->hasFnAttribute(
+             Attribute::SanitizeHWAddress)))
+      // We will be reading past the location accessed by the original program.
+      // While this is safe in a regular build, Address Safety analysis tools
+      // may start reporting false warnings. So, don't do widening.
+      return 0;
+
+    // If a load of this width would include all of MemLoc, then we succeed.
+    if (LIOffs + NewLoadByteSize >= MemLocEnd)
+      return NewLoadByteSize;
+
+    NewLoadByteSize <<= 1;
+  }
+}
+
+static bool isVolatile(Instruction *Inst) {
+  if (auto *LI = dyn_cast<LoadInst>(Inst))
+    return LI->isVolatile();
+  if (auto *SI = dyn_cast<StoreInst>(Inst))
+    return SI->isVolatile();
+  if (auto *AI = dyn_cast<AtomicCmpXchgInst>(Inst))
+    return AI->isVolatile();
+  return false;
+}
+
+MemDepResult MemoryDependenceResults::getPointerDependencyFrom(
+    const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt,
+    BasicBlock *BB, Instruction *QueryInst, unsigned *Limit) {
+  MemDepResult InvariantGroupDependency = MemDepResult::getUnknown();
+  if (QueryInst != nullptr) {
+    if (auto *LI = dyn_cast<LoadInst>(QueryInst)) {
+      InvariantGroupDependency = getInvariantGroupPointerDependency(LI, BB);
+
+      if (InvariantGroupDependency.isDef())
+        return InvariantGroupDependency;
+    }
+  }
+  MemDepResult SimpleDep = getSimplePointerDependencyFrom(
+      MemLoc, isLoad, ScanIt, BB, QueryInst, Limit);
+  if (SimpleDep.isDef())
+    return SimpleDep;
+  // A non-local invariant group dependency indicates there is a non-local Def
+  // (it only returns NonLocal if it finds a non-local Def), which is better
+  // than a local clobber and everything else.
+  if (InvariantGroupDependency.isNonLocal())
+    return InvariantGroupDependency;
+
+  assert(InvariantGroupDependency.isUnknown() &&
+         "InvariantGroupDependency should be only unknown at this point");
+  return SimpleDep;
+}
+
+MemDepResult
+MemoryDependenceResults::getInvariantGroupPointerDependency(LoadInst *LI,
+                                                            BasicBlock *BB) {
+
+  if (!LI->getMetadata(LLVMContext::MD_invariant_group))
+    return MemDepResult::getUnknown();
+
+  // Take the pointer operand after stripping all casts and all-zero GEPs.
+  // This way we only need to search the cast graph downwards.
+  Value *LoadOperand = LI->getPointerOperand()->stripPointerCasts();
+
+  // It is not safe to walk the use list of a global value, because function
+  // passes aren't allowed to look outside their functions.
+  // FIXME: this could be fixed by filtering instructions from outside
+  // of the current function.
+  if (isa<GlobalValue>(LoadOperand))
+    return MemDepResult::getUnknown();
+
+  // Queue to process all pointers that are equivalent to the load operand.
+  SmallVector<const Value *, 8> LoadOperandsQueue;
+  LoadOperandsQueue.push_back(LoadOperand);
+
+  Instruction *ClosestDependency = nullptr;
+  // The order of instructions in the use list is unpredictable. In order to
+  // always get the same result, we keep the candidate that is closest to LI
+  // in the dominance order.
+  auto GetClosestDependency = [this](Instruction *Best, Instruction *Other) {
+    assert(Other && "Must call it with a non-null instruction");
+    if (Best == nullptr || DT.dominates(Best, Other))
+      return Other;
+    return Best;
+  };
+
+  // FIXME: This loop is O(N^2) because dominates can be O(n) and in the worst
+  // case we will see all the instructions. This should be fixed in MSSA.
+  while (!LoadOperandsQueue.empty()) {
+    const Value *Ptr = LoadOperandsQueue.pop_back_val();
+    assert(Ptr && !isa<GlobalValue>(Ptr) &&
+           "Null or GlobalValue should not be inserted");
+
+    for (const Use &Us : Ptr->uses()) {
+      auto *U = dyn_cast<Instruction>(Us.getUser());
+      if (!U || U == LI || !DT.dominates(U, LI))
+        continue;
+
+      // Bitcasts and all-zero GEPs that use Ptr are equivalent to Ptr. Add
+      // them to the queue to check their users, e.g. U = bitcast Ptr.
+      if (isa<BitCastInst>(U)) {
+        LoadOperandsQueue.push_back(U);
+        continue;
+      }
+      // A GEP with all-zero indices is equivalent to a bitcast.
+      // FIXME: we are not sure if some bitcast should be canonicalized to gep 0
+      // or gep 0 to bitcast because of SROA, so there are 2 forms. When
+      // typeless pointers are ready, both cases will be gone
+      // (and this BFS also won't be needed).
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+        if (GEP->hasAllZeroIndices()) {
+          LoadOperandsQueue.push_back(U);
+          continue;
+        }
+
+      // If we hit a load/store with the same invariant.group metadata (and the
+      // same pointer operand), we can assume that the value pointed to by the
+      // pointer operand didn't change.
+      if ((isa<LoadInst>(U) || isa<StoreInst>(U)) &&
+          U->getMetadata(LLVMContext::MD_invariant_group) != nullptr)
+        ClosestDependency = GetClosestDependency(ClosestDependency, U);
+    }
+  }
+
+  if (!ClosestDependency)
+    return MemDepResult::getUnknown();
+  if (ClosestDependency->getParent() == BB)
+    return MemDepResult::getDef(ClosestDependency);
+  // Def(U) can't be returned here because it is non-local. If no local
+  // dependency is found, return NonLocal, counting on the caller to invoke
+  // getNonLocalPointerDependency, which will return the cached
+  // result.
+  NonLocalDefsCache.try_emplace(
+      LI, NonLocalDepResult(ClosestDependency->getParent(),
+                            MemDepResult::getDef(ClosestDependency), nullptr));
+  ReverseNonLocalDefsCache[ClosestDependency].insert(LI);
+  return MemDepResult::getNonLocal();
+}
+
+MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom(
+    const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt,
+    BasicBlock *BB, Instruction *QueryInst, unsigned *Limit) {
+  bool isInvariantLoad = false;
+
+  if (!Limit) {
+    unsigned DefaultLimit = BlockScanLimit;
+    return getSimplePointerDependencyFrom(MemLoc, isLoad, ScanIt, BB, QueryInst,
+                                          &DefaultLimit);
+  }
+
+  // We must be careful with atomic accesses, as they may allow another thread
+  //   to touch this location, clobbering it. We are conservative: if the
+  //   QueryInst is not a simple (non-atomic) memory access, we automatically
+  //   return getClobber.
+  // If it is simple, we know based on the results of
+  // "Compiler testing via a theory of sound optimisations in the C11/C++11
+  //   memory model" in PLDI 2013, that a non-atomic location can only be
+  //   clobbered between a pair of a release and an acquire action, with no
+  //   access to the location in between.
+  // Here is an example giving the general intuition behind this rule.
+  // In the following code:
+  //   store x 0;
+  //   release action; [1]
+  //   acquire action; [4]
+  //   %val = load x;
+  // It is unsafe to replace %val by 0 because another thread may be running:
+  //   acquire action; [2]
+  //   store x 42;
+  //   release action; [3]
+  // with synchronization from 1 to 2 and from 3 to 4, resulting in %val
+  // being 42. A key property of this program, however, is that if either
+  // 1 or 4 were missing, there would be a race between the store of 42 and
+  // either the store of 0 or the load (making the whole program racy).
+  // The paper mentioned above shows that the same property is respected
+  // by every program that can detect any optimization of that kind: either
+  // it is racy (undefined) or there is a release followed by an acquire
+  // between the pair of accesses under consideration.
+
+  // If the load is invariant, we "know" that it doesn't alias *any* write. We
+  // do want to respect mustalias results since defs are useful for value
+  // forwarding, but any mayalias write can be assumed to be noalias.
+  // Arguably, this logic should be pushed inside AliasAnalysis itself.
+  if (isLoad && QueryInst) {
+    LoadInst *LI = dyn_cast<LoadInst>(QueryInst);
+    if (LI && LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr)
+      isInvariantLoad = true;
+  }
+
+  const DataLayout &DL = BB->getModule()->getDataLayout();
+
+  // Create a numbered basic block to lazily compute and cache instruction
+  // positions inside a BB. This is used to provide fast queries for relative
+  // position between two instructions in a BB and can be used by
+  // AliasAnalysis::callCapturesBefore.
+  OrderedBasicBlock OBB(BB);
+
+  // Return "true" if and only if the instruction I is either a non-simple
+  // load or a non-simple store.
+  auto isNonSimpleLoadOrStore = [](Instruction *I) -> bool {
+    if (auto *LI = dyn_cast<LoadInst>(I))
+      return !LI->isSimple();
+    if (auto *SI = dyn_cast<StoreInst>(I))
+      return !SI->isSimple();
+    return false;
+  };
+
+  // Return "true" if I is not a load and not a store, but it does access
+  // memory.
+  auto isOtherMemAccess = [](Instruction *I) -> bool { +    return !isa<LoadInst>(I) && !isa<StoreInst>(I) && I->mayReadOrWriteMemory(); +  }; + +  // Walk backwards through the basic block, looking for dependencies. +  while (ScanIt != BB->begin()) { +    Instruction *Inst = &*--ScanIt; + +    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) +      // Debug intrinsics don't (and can't) cause dependencies. +      if (isa<DbgInfoIntrinsic>(II)) +        continue; + +    // Limit the amount of scanning we do so we don't end up with quadratic +    // running time on extreme testcases. +    --*Limit; +    if (!*Limit) +      return MemDepResult::getUnknown(); + +    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { +      // If we reach a lifetime begin or end marker, then the query ends here +      // because the value is undefined. +      if (II->getIntrinsicID() == Intrinsic::lifetime_start) { +        // FIXME: This only considers queries directly on the invariant-tagged +        // pointer, not on query pointers that are indexed off of them.  It'd +        // be nice to handle that at some point (the right approach is to use +        // GetPointerBaseWithConstantOffset). +        if (AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), MemLoc)) +          return MemDepResult::getDef(II); +        continue; +      } +    } + +    // Values depend on loads if the pointers are must aliased.  This means +    // that a load depends on another must aliased load from the same value. +    // One exception is atomic loads: a value can depend on an atomic load that +    // it does not alias with when this atomic load indicates that another +    // thread may be accessing the location. +    if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { +      // While volatile access cannot be eliminated, they do not have to clobber +      // non-aliasing locations, as normal accesses, for example, can be safely +      // reordered with volatile accesses. +      if (LI->isVolatile()) { +        if (!QueryInst) +          // Original QueryInst *may* be volatile +          return MemDepResult::getClobber(LI); +        if (isVolatile(QueryInst)) +          // Ordering required if QueryInst is itself volatile +          return MemDepResult::getClobber(LI); +        // Otherwise, volatile doesn't imply any special ordering +      } + +      // Atomic loads have complications involved. +      // A Monotonic (or higher) load is OK if the query inst is itself not +      // atomic. +      // FIXME: This is overly conservative. +      if (LI->isAtomic() && isStrongerThanUnordered(LI->getOrdering())) { +        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) || +            isOtherMemAccess(QueryInst)) +          return MemDepResult::getClobber(LI); +        if (LI->getOrdering() != AtomicOrdering::Monotonic) +          return MemDepResult::getClobber(LI); +      } + +      MemoryLocation LoadLoc = MemoryLocation::get(LI); + +      // If we found a pointer, check if it could be the same as our pointer. +      AliasResult R = AA.alias(LoadLoc, MemLoc); + +      if (isLoad) { +        if (R == NoAlias) +          continue; + +        // Must aliased loads are defs of each other. +        if (R == MustAlias) +          return MemDepResult::getDef(Inst); + +#if 0 // FIXME: Temporarily disabled. 
GVN is cleverly rewriting loads
+      // in terms of clobbering loads, but since it does this by looking
+      // at the clobbering load directly, it doesn't know about any
+      // phi translation that may have happened along the way.
+
+        // If we have a partial alias, then return this as a clobber for the
+        // client to handle.
+        if (R == PartialAlias)
+          return MemDepResult::getClobber(Inst);
+#endif
+
+        // Loads don't clobber other loads, so a may-alias load does not
+        // create a dependence here; keep scanning.
+        continue;
+      }
+
+      // Stores don't depend on other non-aliasing accesses.
+      if (R == NoAlias)
+        continue;
+
+      // Stores don't alias loads from read-only memory.
+      if (AA.pointsToConstantMemory(LoadLoc))
+        continue;
+
+      // Stores depend on may/must aliased loads.
+      return MemDepResult::getDef(Inst);
+    }
+
+    if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+      // Atomic stores have complications involved.
+      // A Monotonic store is OK if the query inst is itself not atomic.
+      // FIXME: This is overly conservative.
+      if (!SI->isUnordered() && SI->isAtomic()) {
+        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+            isOtherMemAccess(QueryInst))
+          return MemDepResult::getClobber(SI);
+        if (SI->getOrdering() != AtomicOrdering::Monotonic)
+          return MemDepResult::getClobber(SI);
+      }
+
+      // FIXME: this is overly conservative.
+      // While volatile accesses cannot be eliminated, they do not have to
+      // clobber non-aliasing locations; normal accesses can, for example, be
+      // safely reordered with volatile accesses.
+      if (SI->isVolatile())
+        if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) ||
+            isOtherMemAccess(QueryInst))
+          return MemDepResult::getClobber(SI);
+
+      // If alias analysis can tell that this store is guaranteed to not modify
+      // the query pointer, ignore it.  Use getModRefInfo to handle cases where
+      // the query pointer points to constant memory etc.
+      if (!isModOrRefSet(AA.getModRefInfo(SI, MemLoc)))
+        continue;
+
+      // Ok, this store might clobber the query pointer.  Check to see if it is
+      // a must alias: in this case, we want to return this as a def.
+      // FIXME: Use ModRefInfo::Must bit from getModRefInfo call above.
+      MemoryLocation StoreLoc = MemoryLocation::get(SI);
+
+      // If we found a pointer, check if it could be the same as our pointer.
+      AliasResult R = AA.alias(StoreLoc, MemLoc);
+
+      if (R == NoAlias)
+        continue;
+      if (R == MustAlias)
+        return MemDepResult::getDef(Inst);
+      if (isInvariantLoad)
+        continue;
+      return MemDepResult::getClobber(Inst);
+    }
+
+    // If this is an allocation, and if we know that the accessed pointer is to
+    // the allocation, return Def.  This means that there is no dependence and
+    // the access can be optimized based on that.  For example, a load could
+    // turn into undef.  Note that we can bypass the allocation itself when
+    // looking for a clobber in many cases; that's an alias property and is
+    // handled by BasicAA.
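+    // For example (illustrative IR), a load from a fresh alloca with no
+    // intervening store depends on the allocation itself as a Def, which lets
+    // a client such as GVN fold the load to undef:
+    //   %p = alloca i32
+    //   %v = load i32, i32* %p   ; <-- query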
+    if (isa<AllocaInst>(Inst) || isNoAliasFn(Inst, &TLI)) {
+      const Value *AccessPtr = GetUnderlyingObject(MemLoc.Ptr, DL);
+      if (AccessPtr == Inst || AA.isMustAlias(Inst, AccessPtr))
+        return MemDepResult::getDef(Inst);
+    }
+
+    if (isInvariantLoad)
+      continue;
+
+    // A release fence requires that all stores complete before it, but does
+    // not prevent the reordering of following loads or stores 'before' the
+    // fence.  As a result, we look past it when finding a dependency for
+    // loads.  DSE uses this to find preceding stores to delete and thus we
+    // can't bypass the fence if the query instruction is a store.
+    if (FenceInst *FI = dyn_cast<FenceInst>(Inst))
+      if (isLoad && FI->getOrdering() == AtomicOrdering::Release)
+        continue;
+
+    // See if this instruction (e.g. a call or vaarg) mod/ref's the pointer.
+    ModRefInfo MR = AA.getModRefInfo(Inst, MemLoc);
+    // If necessary, perform additional analysis.
+    if (isModAndRefSet(MR))
+      MR = AA.callCapturesBefore(Inst, MemLoc, &DT, &OBB);
+    switch (clearMust(MR)) {
+    case ModRefInfo::NoModRef:
+      // If the call has no effect on the queried pointer, just ignore it.
+      continue;
+    case ModRefInfo::Mod:
+      return MemDepResult::getClobber(Inst);
+    case ModRefInfo::Ref:
+      // If the call is known to never store to the pointer, and if this is a
+      // load query, we can safely ignore it (scan past it).
+      if (isLoad)
+        continue;
+      LLVM_FALLTHROUGH;
+    default:
+      // Otherwise, there is a potential dependence.  Return a clobber.
+      return MemDepResult::getClobber(Inst);
+    }
+  }
+
+  // No dependence found.  If this is the entry block of the function, it is
+  // unknown, otherwise it is non-local.
+  if (BB != &BB->getParent()->getEntryBlock())
+    return MemDepResult::getNonLocal();
+  return MemDepResult::getNonFuncLocal();
+}
+
+MemDepResult MemoryDependenceResults::getDependency(Instruction *QueryInst) {
+  Instruction *ScanPos = QueryInst;
+
+  // Check for a cached result.
+  MemDepResult &LocalCache = LocalDeps[QueryInst];
+
+  // If the cached entry is non-dirty, just return it.  Note that this depends
+  // on MemDepResult's default constructing to 'dirty'.
+  if (!LocalCache.isDirty())
+    return LocalCache;
+
+  // Otherwise, if we have a dirty entry, we know we can start the scan at that
+  // instruction, which may save us some work.
+  if (Instruction *Inst = LocalCache.getInst()) {
+    ScanPos = Inst;
+
+    RemoveFromReverseMap(ReverseLocalDeps, Inst, QueryInst);
+  }
+
+  BasicBlock *QueryParent = QueryInst->getParent();
+
+  // Do the scan.
+  if (BasicBlock::iterator(QueryInst) == QueryParent->begin()) {
+    // No dependence found. If this is the entry block of the function, it is
+    // unknown, otherwise it is non-local.
+    if (QueryParent != &QueryParent->getParent()->getEntryBlock())
+      LocalCache = MemDepResult::getNonLocal();
+    else
+      LocalCache = MemDepResult::getNonFuncLocal();
+  } else {
+    MemoryLocation MemLoc;
+    ModRefInfo MR = GetLocation(QueryInst, MemLoc, TLI);
+    if (MemLoc.Ptr) {
+      // If we can do a pointer scan, make it happen.
+      bool isLoad = !isModSet(MR); +      if (auto *II = dyn_cast<IntrinsicInst>(QueryInst)) +        isLoad |= II->getIntrinsicID() == Intrinsic::lifetime_start; + +      LocalCache = getPointerDependencyFrom( +          MemLoc, isLoad, ScanPos->getIterator(), QueryParent, QueryInst); +    } else if (isa<CallInst>(QueryInst) || isa<InvokeInst>(QueryInst)) { +      CallSite QueryCS(QueryInst); +      bool isReadOnly = AA.onlyReadsMemory(QueryCS); +      LocalCache = getCallSiteDependencyFrom( +          QueryCS, isReadOnly, ScanPos->getIterator(), QueryParent); +    } else +      // Non-memory instruction. +      LocalCache = MemDepResult::getUnknown(); +  } + +  // Remember the result! +  if (Instruction *I = LocalCache.getInst()) +    ReverseLocalDeps[I].insert(QueryInst); + +  return LocalCache; +} + +#ifndef NDEBUG +/// This method is used when -debug is specified to verify that cache arrays +/// are properly kept sorted. +static void AssertSorted(MemoryDependenceResults::NonLocalDepInfo &Cache, +                         int Count = -1) { +  if (Count == -1) +    Count = Cache.size(); +  assert(std::is_sorted(Cache.begin(), Cache.begin() + Count) && +         "Cache isn't sorted!"); +} +#endif + +const MemoryDependenceResults::NonLocalDepInfo & +MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { +  assert(getDependency(QueryCS.getInstruction()).isNonLocal() && +         "getNonLocalCallDependency should only be used on calls with " +         "non-local deps!"); +  PerInstNLInfo &CacheP = NonLocalDeps[QueryCS.getInstruction()]; +  NonLocalDepInfo &Cache = CacheP.first; + +  // This is the set of blocks that need to be recomputed.  In the cached case, +  // this can happen due to instructions being deleted etc. In the uncached +  // case, this starts out as the set of predecessors we care about. +  SmallVector<BasicBlock *, 32> DirtyBlocks; + +  if (!Cache.empty()) { +    // Okay, we have a cache entry.  If we know it is not dirty, just return it +    // with no computation. +    if (!CacheP.second) { +      ++NumCacheNonLocal; +      return Cache; +    } + +    // If we already have a partially computed set of results, scan them to +    // determine what is dirty, seeding our initial DirtyBlocks worklist. +    for (auto &Entry : Cache) +      if (Entry.getResult().isDirty()) +        DirtyBlocks.push_back(Entry.getBB()); + +    // Sort the cache so that we can do fast binary search lookups below. +    llvm::sort(Cache.begin(), Cache.end()); + +    ++NumCacheDirtyNonLocal; +    // cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: " +    //     << Cache.size() << " cached: " << *QueryInst; +  } else { +    // Seed DirtyBlocks with each of the preds of QueryInst's block. +    BasicBlock *QueryBB = QueryCS.getInstruction()->getParent(); +    for (BasicBlock *Pred : PredCache.get(QueryBB)) +      DirtyBlocks.push_back(Pred); +    ++NumUncacheNonLocal; +  } + +  // isReadonlyCall - If this is a read-only call, we can be more aggressive. +  bool isReadonlyCall = AA.onlyReadsMemory(QueryCS); + +  SmallPtrSet<BasicBlock *, 32> Visited; + +  unsigned NumSortedEntries = Cache.size(); +  LLVM_DEBUG(AssertSorted(Cache)); + +  // Iterate while we still have blocks to update. +  while (!DirtyBlocks.empty()) { +    BasicBlock *DirtyBB = DirtyBlocks.back(); +    DirtyBlocks.pop_back(); + +    // Already processed this block? 
+    if (!Visited.insert(DirtyBB).second)
+      continue;
+
+    // Do a binary search to see if we already have an entry for this block in
+    // the cache set.  If so, find it.
+    LLVM_DEBUG(AssertSorted(Cache, NumSortedEntries));
+    NonLocalDepInfo::iterator Entry =
+        std::upper_bound(Cache.begin(), Cache.begin() + NumSortedEntries,
+                         NonLocalDepEntry(DirtyBB));
+    if (Entry != Cache.begin() && std::prev(Entry)->getBB() == DirtyBB)
+      --Entry;
+
+    NonLocalDepEntry *ExistingResult = nullptr;
+    if (Entry != Cache.begin() + NumSortedEntries &&
+        Entry->getBB() == DirtyBB) {
+      // If we already have an entry, and if it isn't already dirty, the block
+      // is done.
+      if (!Entry->getResult().isDirty())
+        continue;
+
+      // Otherwise, remember this slot so we can update the value.
+      ExistingResult = &*Entry;
+    }
+
+    // If the dirty entry has a pointer, start scanning from it so we don't have
+    // to rescan the entire block.
+    BasicBlock::iterator ScanPos = DirtyBB->end();
+    if (ExistingResult) {
+      if (Instruction *Inst = ExistingResult->getResult().getInst()) {
+        ScanPos = Inst->getIterator();
+        // We're removing QueryInst's use of Inst.
+        RemoveFromReverseMap(ReverseNonLocalDeps, Inst,
+                             QueryCS.getInstruction());
+      }
+    }
+
+    // Find out if this block has a local dependency for QueryInst.
+    MemDepResult Dep;
+
+    if (ScanPos != DirtyBB->begin()) {
+      Dep =
+          getCallSiteDependencyFrom(QueryCS, isReadonlyCall, ScanPos, DirtyBB);
+    } else if (DirtyBB != &DirtyBB->getParent()->getEntryBlock()) {
+      // No dependence found.  If this is the entry block of the function, it is
+      // unknown, otherwise it is non-local.
+      Dep = MemDepResult::getNonLocal();
+    } else {
+      Dep = MemDepResult::getNonFuncLocal();
+    }
+
+    // If we had a dirty entry for the block, update it.  Otherwise, just add
+    // a new entry.
+    if (ExistingResult)
+      ExistingResult->setResult(Dep);
+    else
+      Cache.push_back(NonLocalDepEntry(DirtyBB, Dep));
+
+    // If the block has a dependency (i.e. it isn't completely transparent to
+    // the value), remember the association!
+    if (!Dep.isNonLocal()) {
+      // Keep the ReverseNonLocalDeps map up to date so we can efficiently
+      // update this when we remove instructions.
+      if (Instruction *Inst = Dep.getInst())
+        ReverseNonLocalDeps[Inst].insert(QueryCS.getInstruction());
+    } else {
+
+      // If the block *is* completely transparent to the call, we need to check
+      // the predecessors of this block.  Add them to our worklist.
+      for (BasicBlock *Pred : PredCache.get(DirtyBB))
+        DirtyBlocks.push_back(Pred);
+    }
+  }
+
+  return Cache;
+}
+
+void MemoryDependenceResults::getNonLocalPointerDependency(
+    Instruction *QueryInst, SmallVectorImpl<NonLocalDepResult> &Result) {
+  const MemoryLocation Loc = MemoryLocation::get(QueryInst);
+  bool isLoad = isa<LoadInst>(QueryInst);
+  BasicBlock *FromBB = QueryInst->getParent();
+  assert(FromBB);
+
+  assert(Loc.Ptr->getType()->isPointerTy() &&
+         "Can't get pointer deps of a non-pointer!");
+  Result.clear();
+  {
+    // Check if there is a cached Def with invariant.group metadata.
+    auto NonLocalDefIt = NonLocalDefsCache.find(QueryInst); +    if (NonLocalDefIt != NonLocalDefsCache.end()) { +      Result.push_back(NonLocalDefIt->second); +      ReverseNonLocalDefsCache[NonLocalDefIt->second.getResult().getInst()] +          .erase(QueryInst); +      NonLocalDefsCache.erase(NonLocalDefIt); +      return; +    } +  } +  // This routine does not expect to deal with volatile instructions. +  // Doing so would require piping through the QueryInst all the way through. +  // TODO: volatiles can't be elided, but they can be reordered with other +  // non-volatile accesses. + +  // We currently give up on any instruction which is ordered, but we do handle +  // atomic instructions which are unordered. +  // TODO: Handle ordered instructions +  auto isOrdered = [](Instruction *Inst) { +    if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { +      return !LI->isUnordered(); +    } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { +      return !SI->isUnordered(); +    } +    return false; +  }; +  if (isVolatile(QueryInst) || isOrdered(QueryInst)) { +    Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(), +                                       const_cast<Value *>(Loc.Ptr))); +    return; +  } +  const DataLayout &DL = FromBB->getModule()->getDataLayout(); +  PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, &AC); + +  // This is the set of blocks we've inspected, and the pointer we consider in +  // each block.  Because of critical edges, we currently bail out if querying +  // a block with multiple different pointers.  This can happen during PHI +  // translation. +  DenseMap<BasicBlock *, Value *> Visited; +  if (getNonLocalPointerDepFromBB(QueryInst, Address, Loc, isLoad, FromBB, +                                   Result, Visited, true)) +    return; +  Result.clear(); +  Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(), +                                     const_cast<Value *>(Loc.Ptr))); +} + +/// Compute the memdep value for BB with Pointer/PointeeSize using either +/// cached information in Cache or by doing a lookup (which may use dirty cache +/// info if available). +/// +/// If we do a lookup, add the result to the cache. +MemDepResult MemoryDependenceResults::GetNonLocalInfoForBlock( +    Instruction *QueryInst, const MemoryLocation &Loc, bool isLoad, +    BasicBlock *BB, NonLocalDepInfo *Cache, unsigned NumSortedEntries) { + +  // Do a binary search to see if we already have an entry for this block in +  // the cache set.  If so, find it. +  NonLocalDepInfo::iterator Entry = std::upper_bound( +      Cache->begin(), Cache->begin() + NumSortedEntries, NonLocalDepEntry(BB)); +  if (Entry != Cache->begin() && (Entry - 1)->getBB() == BB) +    --Entry; + +  NonLocalDepEntry *ExistingResult = nullptr; +  if (Entry != Cache->begin() + NumSortedEntries && Entry->getBB() == BB) +    ExistingResult = &*Entry; + +  // If we have a cached entry, and it is non-dirty, use it as the value for +  // this dependency. +  if (ExistingResult && !ExistingResult->getResult().isDirty()) { +    ++NumCacheNonLocalPtr; +    return ExistingResult->getResult(); +  } + +  // Otherwise, we have to scan for the value.  If we have a dirty cache +  // entry, start scanning from its position, otherwise we scan from the end +  // of the block. 
+  BasicBlock::iterator ScanPos = BB->end(); +  if (ExistingResult && ExistingResult->getResult().getInst()) { +    assert(ExistingResult->getResult().getInst()->getParent() == BB && +           "Instruction invalidated?"); +    ++NumCacheDirtyNonLocalPtr; +    ScanPos = ExistingResult->getResult().getInst()->getIterator(); + +    // Eliminating the dirty entry from 'Cache', so update the reverse info. +    ValueIsLoadPair CacheKey(Loc.Ptr, isLoad); +    RemoveFromReverseMap(ReverseNonLocalPtrDeps, &*ScanPos, CacheKey); +  } else { +    ++NumUncacheNonLocalPtr; +  } + +  // Scan the block for the dependency. +  MemDepResult Dep = +      getPointerDependencyFrom(Loc, isLoad, ScanPos, BB, QueryInst); + +  // If we had a dirty entry for the block, update it.  Otherwise, just add +  // a new entry. +  if (ExistingResult) +    ExistingResult->setResult(Dep); +  else +    Cache->push_back(NonLocalDepEntry(BB, Dep)); + +  // If the block has a dependency (i.e. it isn't completely transparent to +  // the value), remember the reverse association because we just added it +  // to Cache! +  if (!Dep.isDef() && !Dep.isClobber()) +    return Dep; + +  // Keep the ReverseNonLocalPtrDeps map up to date so we can efficiently +  // update MemDep when we remove instructions. +  Instruction *Inst = Dep.getInst(); +  assert(Inst && "Didn't depend on anything?"); +  ValueIsLoadPair CacheKey(Loc.Ptr, isLoad); +  ReverseNonLocalPtrDeps[Inst].insert(CacheKey); +  return Dep; +} + +/// Sort the NonLocalDepInfo cache, given a certain number of elements in the +/// array that are already properly ordered. +/// +/// This is optimized for the case when only a few entries are added. +static void +SortNonLocalDepInfoCache(MemoryDependenceResults::NonLocalDepInfo &Cache, +                         unsigned NumSortedEntries) { +  switch (Cache.size() - NumSortedEntries) { +  case 0: +    // done, no new entries. +    break; +  case 2: { +    // Two new entries, insert the last one into place. +    NonLocalDepEntry Val = Cache.back(); +    Cache.pop_back(); +    MemoryDependenceResults::NonLocalDepInfo::iterator Entry = +        std::upper_bound(Cache.begin(), Cache.end() - 1, Val); +    Cache.insert(Entry, Val); +    LLVM_FALLTHROUGH; +  } +  case 1: +    // One new entry, Just insert the new value at the appropriate position. +    if (Cache.size() != 1) { +      NonLocalDepEntry Val = Cache.back(); +      Cache.pop_back(); +      MemoryDependenceResults::NonLocalDepInfo::iterator Entry = +          std::upper_bound(Cache.begin(), Cache.end(), Val); +      Cache.insert(Entry, Val); +    } +    break; +  default: +    // Added many values, do a full scale sort. +    llvm::sort(Cache.begin(), Cache.end()); +    break; +  } +} + +/// Perform a dependency query based on pointer/pointeesize starting at the end +/// of StartBB. +/// +/// Add any clobber/def results to the results vector and keep track of which +/// blocks are visited in 'Visited'. +/// +/// This has special behavior for the first block queries (when SkipFirstBlock +/// is true).  In this special case, it ignores the contents of the specified +/// block and starts returning dependence info for its predecessors. +/// +/// This function returns true on success, or false to indicate that it could +/// not compute dependence information for some reason.  This should be treated +/// as a clobber dependence on the first instruction in the predecessor block. 
+bool MemoryDependenceResults::getNonLocalPointerDepFromBB( +    Instruction *QueryInst, const PHITransAddr &Pointer, +    const MemoryLocation &Loc, bool isLoad, BasicBlock *StartBB, +    SmallVectorImpl<NonLocalDepResult> &Result, +    DenseMap<BasicBlock *, Value *> &Visited, bool SkipFirstBlock) { +  // Look up the cached info for Pointer. +  ValueIsLoadPair CacheKey(Pointer.getAddr(), isLoad); + +  // Set up a temporary NLPI value. If the map doesn't yet have an entry for +  // CacheKey, this value will be inserted as the associated value. Otherwise, +  // it'll be ignored, and we'll have to check to see if the cached size and +  // aa tags are consistent with the current query. +  NonLocalPointerInfo InitialNLPI; +  InitialNLPI.Size = Loc.Size; +  InitialNLPI.AATags = Loc.AATags; + +  // Get the NLPI for CacheKey, inserting one into the map if it doesn't +  // already have one. +  std::pair<CachedNonLocalPointerInfo::iterator, bool> Pair = +      NonLocalPointerDeps.insert(std::make_pair(CacheKey, InitialNLPI)); +  NonLocalPointerInfo *CacheInfo = &Pair.first->second; + +  // If we already have a cache entry for this CacheKey, we may need to do some +  // work to reconcile the cache entry and the current query. +  if (!Pair.second) { +    if (CacheInfo->Size < Loc.Size) { +      // The query's Size is greater than the cached one. Throw out the +      // cached data and proceed with the query at the greater size. +      CacheInfo->Pair = BBSkipFirstBlockPair(); +      CacheInfo->Size = Loc.Size; +      for (auto &Entry : CacheInfo->NonLocalDeps) +        if (Instruction *Inst = Entry.getResult().getInst()) +          RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); +      CacheInfo->NonLocalDeps.clear(); +    } else if (CacheInfo->Size > Loc.Size) { +      // This query's Size is less than the cached one. Conservatively restart +      // the query using the greater size. +      return getNonLocalPointerDepFromBB( +          QueryInst, Pointer, Loc.getWithNewSize(CacheInfo->Size), isLoad, +          StartBB, Result, Visited, SkipFirstBlock); +    } + +    // If the query's AATags are inconsistent with the cached one, +    // conservatively throw out the cached data and restart the query with +    // no tag if needed. +    if (CacheInfo->AATags != Loc.AATags) { +      if (CacheInfo->AATags) { +        CacheInfo->Pair = BBSkipFirstBlockPair(); +        CacheInfo->AATags = AAMDNodes(); +        for (auto &Entry : CacheInfo->NonLocalDeps) +          if (Instruction *Inst = Entry.getResult().getInst()) +            RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); +        CacheInfo->NonLocalDeps.clear(); +      } +      if (Loc.AATags) +        return getNonLocalPointerDepFromBB( +            QueryInst, Pointer, Loc.getWithoutAATags(), isLoad, StartBB, Result, +            Visited, SkipFirstBlock); +    } +  } + +  NonLocalDepInfo *Cache = &CacheInfo->NonLocalDeps; + +  // If we have valid cached information for exactly the block we are +  // investigating, just return it with no recomputation. +  if (CacheInfo->Pair == BBSkipFirstBlockPair(StartBB, SkipFirstBlock)) { +    // We have a fully cached result for this query then we can just return the +    // cached results and populate the visited set.  However, we have to verify +    // that we don't already have conflicting results for these blocks.  Check +    // to ensure that if a block in the results set is in the visited set that +    // it was for the same pointer query. 
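+    // For instance (hypothetical scenario), if some block in the cached
+    // results was already visited on behalf of a different translated
+    // pointer, the two answers cannot be merged; in that case we return false
+    // below and the caller treats it as a conservative clobber.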
+    if (!Visited.empty()) { +      for (auto &Entry : *Cache) { +        DenseMap<BasicBlock *, Value *>::iterator VI = +            Visited.find(Entry.getBB()); +        if (VI == Visited.end() || VI->second == Pointer.getAddr()) +          continue; + +        // We have a pointer mismatch in a block.  Just return false, saying +        // that something was clobbered in this result.  We could also do a +        // non-fully cached query, but there is little point in doing this. +        return false; +      } +    } + +    Value *Addr = Pointer.getAddr(); +    for (auto &Entry : *Cache) { +      Visited.insert(std::make_pair(Entry.getBB(), Addr)); +      if (Entry.getResult().isNonLocal()) { +        continue; +      } + +      if (DT.isReachableFromEntry(Entry.getBB())) { +        Result.push_back( +            NonLocalDepResult(Entry.getBB(), Entry.getResult(), Addr)); +      } +    } +    ++NumCacheCompleteNonLocalPtr; +    return true; +  } + +  // Otherwise, either this is a new block, a block with an invalid cache +  // pointer or one that we're about to invalidate by putting more info into it +  // than its valid cache info.  If empty, the result will be valid cache info, +  // otherwise it isn't. +  if (Cache->empty()) +    CacheInfo->Pair = BBSkipFirstBlockPair(StartBB, SkipFirstBlock); +  else +    CacheInfo->Pair = BBSkipFirstBlockPair(); + +  SmallVector<BasicBlock *, 32> Worklist; +  Worklist.push_back(StartBB); + +  // PredList used inside loop. +  SmallVector<std::pair<BasicBlock *, PHITransAddr>, 16> PredList; + +  // Keep track of the entries that we know are sorted.  Previously cached +  // entries will all be sorted.  The entries we add we only sort on demand (we +  // don't insert every element into its sorted position).  We know that we +  // won't get any reuse from currently inserted values, because we don't +  // revisit blocks after we insert info for them. +  unsigned NumSortedEntries = Cache->size(); +  unsigned WorklistEntries = BlockNumberLimit; +  bool GotWorklistLimit = false; +  LLVM_DEBUG(AssertSorted(*Cache)); + +  while (!Worklist.empty()) { +    BasicBlock *BB = Worklist.pop_back_val(); + +    // If we do process a large number of blocks it becomes very expensive and +    // likely it isn't worth worrying about +    if (Result.size() > NumResultsLimit) { +      Worklist.clear(); +      // Sort it now (if needed) so that recursive invocations of +      // getNonLocalPointerDepFromBB and other routines that could reuse the +      // cache value will only see properly sorted cache arrays. +      if (Cache && NumSortedEntries != Cache->size()) { +        SortNonLocalDepInfoCache(*Cache, NumSortedEntries); +      } +      // Since we bail out, the "Cache" set won't contain all of the +      // results for the query.  This is ok (we can still use it to accelerate +      // specific block queries) but we can't do the fastpath "return all +      // results from the set".  Clear out the indicator for this. +      CacheInfo->Pair = BBSkipFirstBlockPair(); +      return false; +    } + +    // Skip the first block if we have it. +    if (!SkipFirstBlock) { +      // Analyze the dependency of *Pointer in FromBB.  See if we already have +      // been here. +      assert(Visited.count(BB) && "Should check 'visited' before adding to WL"); + +      // Get the dependency info for Pointer in BB.  If we have cached +      // information, we will use it, otherwise we compute it. 
+      LLVM_DEBUG(AssertSorted(*Cache, NumSortedEntries)); +      MemDepResult Dep = GetNonLocalInfoForBlock(QueryInst, Loc, isLoad, BB, +                                                 Cache, NumSortedEntries); + +      // If we got a Def or Clobber, add this to the list of results. +      if (!Dep.isNonLocal()) { +        if (DT.isReachableFromEntry(BB)) { +          Result.push_back(NonLocalDepResult(BB, Dep, Pointer.getAddr())); +          continue; +        } +      } +    } + +    // If 'Pointer' is an instruction defined in this block, then we need to do +    // phi translation to change it into a value live in the predecessor block. +    // If not, we just add the predecessors to the worklist and scan them with +    // the same Pointer. +    if (!Pointer.NeedsPHITranslationFromBlock(BB)) { +      SkipFirstBlock = false; +      SmallVector<BasicBlock *, 16> NewBlocks; +      for (BasicBlock *Pred : PredCache.get(BB)) { +        // Verify that we haven't looked at this block yet. +        std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes = +            Visited.insert(std::make_pair(Pred, Pointer.getAddr())); +        if (InsertRes.second) { +          // First time we've looked at *PI. +          NewBlocks.push_back(Pred); +          continue; +        } + +        // If we have seen this block before, but it was with a different +        // pointer then we have a phi translation failure and we have to treat +        // this as a clobber. +        if (InsertRes.first->second != Pointer.getAddr()) { +          // Make sure to clean up the Visited map before continuing on to +          // PredTranslationFailure. +          for (unsigned i = 0; i < NewBlocks.size(); i++) +            Visited.erase(NewBlocks[i]); +          goto PredTranslationFailure; +        } +      } +      if (NewBlocks.size() > WorklistEntries) { +        // Make sure to clean up the Visited map before continuing on to +        // PredTranslationFailure. +        for (unsigned i = 0; i < NewBlocks.size(); i++) +          Visited.erase(NewBlocks[i]); +        GotWorklistLimit = true; +        goto PredTranslationFailure; +      } +      WorklistEntries -= NewBlocks.size(); +      Worklist.append(NewBlocks.begin(), NewBlocks.end()); +      continue; +    } + +    // We do need to do phi translation, if we know ahead of time we can't phi +    // translate this value, don't even try. +    if (!Pointer.IsPotentiallyPHITranslatable()) +      goto PredTranslationFailure; + +    // We may have added values to the cache list before this PHI translation. +    // If so, we haven't done anything to ensure that the cache remains sorted. +    // Sort it now (if needed) so that recursive invocations of +    // getNonLocalPointerDepFromBB and other routines that could reuse the cache +    // value will only see properly sorted cache arrays. +    if (Cache && NumSortedEntries != Cache->size()) { +      SortNonLocalDepInfoCache(*Cache, NumSortedEntries); +      NumSortedEntries = Cache->size(); +    } +    Cache = nullptr; + +    PredList.clear(); +    for (BasicBlock *Pred : PredCache.get(BB)) { +      PredList.push_back(std::make_pair(Pred, Pointer)); + +      // Get the PHI translated pointer in this predecessor.  This can fail if +      // not translatable, in which case the getAddr() returns null. 
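+      // For example (illustrative IR), if the pointer being tracked in BB is
+      //   %a = phi i8* [ %x, %pred1 ], [ %y, %pred2 ]
+      // then PHI translation into %pred1 yields %x and into %pred2 yields %y.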
+      PHITransAddr &PredPointer = PredList.back().second;
+      PredPointer.PHITranslateValue(BB, Pred, &DT, /*MustDominate=*/false);
+      Value *PredPtrVal = PredPointer.getAddr();
+
+      // Check to see if we have already visited this pred block with another
+      // pointer.  If so, we can't do this lookup.  This failure can occur
+      // with PHI translation when a critical edge exists and the PHI node in
+      // the successor translates to a pointer value different than the
+      // pointer the block was first analyzed with.
+      std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes =
+          Visited.insert(std::make_pair(Pred, PredPtrVal));
+
+      if (!InsertRes.second) {
+        // We found the pred; take it off the list of preds to visit.
+        PredList.pop_back();
+
+        // If the predecessor was visited with PredPtrVal, then we already did
+        // the analysis and can ignore it.
+        if (InsertRes.first->second == PredPtrVal)
+          continue;
+
+        // Otherwise, the block was previously analyzed with a different
+        // pointer.  We can't represent the result of this case, so we just
+        // treat this as a phi translation failure.
+
+        // Make sure to clean up the Visited map before continuing on to
+        // PredTranslationFailure.
+        for (unsigned i = 0, n = PredList.size(); i < n; ++i)
+          Visited.erase(PredList[i].first);
+
+        goto PredTranslationFailure;
+      }
+    }
+
+    // Actually process results here; this needs to be a separate loop to avoid
+    // calling getNonLocalPointerDepFromBB for blocks we don't want to return
+    // any results for.  (getNonLocalPointerDepFromBB will modify our
+    // data structures in ways the code after the PredTranslationFailure label
+    // doesn't expect.)
+    for (unsigned i = 0, n = PredList.size(); i < n; ++i) {
+      BasicBlock *Pred = PredList[i].first;
+      PHITransAddr &PredPointer = PredList[i].second;
+      Value *PredPtrVal = PredPointer.getAddr();
+
+      bool CanTranslate = true;
+      // If PHI translation was unable to find an available pointer in this
+      // predecessor, then we have to assume that the pointer is clobbered in
+      // that predecessor.  We can still do PRE of the load, which would insert
+      // a computation of the pointer in this predecessor.
+      if (!PredPtrVal)
+        CanTranslate = false;
+
+      // FIXME: it is entirely possible that PHI translating will end up with
+      // the same value.  Consider PHI translating something like:
+      // X = phi [x, bb1], [y, bb2].  PHI translating for bb1 doesn't *need*
+      // to recurse here, pedantically speaking.
+
+      // If getNonLocalPointerDepFromBB fails here, that means the cached
+      // result conflicted with the Visited list; we have to conservatively
+      // assume it is unknown, but this also does not block PRE of the load.
+      if (!CanTranslate ||
+          !getNonLocalPointerDepFromBB(QueryInst, PredPointer,
+                                       Loc.getWithNewPtr(PredPtrVal), isLoad,
+                                       Pred, Result, Visited)) {
+        // Add the entry to the Result list.
+        NonLocalDepResult Entry(Pred, MemDepResult::getUnknown(), PredPtrVal);
+        Result.push_back(Entry);
+
+        // Since we had a phi translation failure, the cache for CacheKey won't
+        // include all of the entries that we need to immediately satisfy future
+        // queries.
Mark this in NonLocalPointerDeps by setting the +        // BBSkipFirstBlockPair pointer to null.  This requires reuse of the +        // cached value to do more work but not miss the phi trans failure. +        NonLocalPointerInfo &NLPI = NonLocalPointerDeps[CacheKey]; +        NLPI.Pair = BBSkipFirstBlockPair(); +        continue; +      } +    } + +    // Refresh the CacheInfo/Cache pointer so that it isn't invalidated. +    CacheInfo = &NonLocalPointerDeps[CacheKey]; +    Cache = &CacheInfo->NonLocalDeps; +    NumSortedEntries = Cache->size(); + +    // Since we did phi translation, the "Cache" set won't contain all of the +    // results for the query.  This is ok (we can still use it to accelerate +    // specific block queries) but we can't do the fastpath "return all +    // results from the set"  Clear out the indicator for this. +    CacheInfo->Pair = BBSkipFirstBlockPair(); +    SkipFirstBlock = false; +    continue; + +  PredTranslationFailure: +    // The following code is "failure"; we can't produce a sane translation +    // for the given block.  It assumes that we haven't modified any of +    // our datastructures while processing the current block. + +    if (!Cache) { +      // Refresh the CacheInfo/Cache pointer if it got invalidated. +      CacheInfo = &NonLocalPointerDeps[CacheKey]; +      Cache = &CacheInfo->NonLocalDeps; +      NumSortedEntries = Cache->size(); +    } + +    // Since we failed phi translation, the "Cache" set won't contain all of the +    // results for the query.  This is ok (we can still use it to accelerate +    // specific block queries) but we can't do the fastpath "return all +    // results from the set".  Clear out the indicator for this. +    CacheInfo->Pair = BBSkipFirstBlockPair(); + +    // If *nothing* works, mark the pointer as unknown. +    // +    // If this is the magic first block, return this as a clobber of the whole +    // incoming value.  Since we can't phi translate to one of the predecessors, +    // we have to bail out. +    if (SkipFirstBlock) +      return false; + +    bool foundBlock = false; +    for (NonLocalDepEntry &I : llvm::reverse(*Cache)) { +      if (I.getBB() != BB) +        continue; + +      assert((GotWorklistLimit || I.getResult().isNonLocal() || +              !DT.isReachableFromEntry(BB)) && +             "Should only be here with transparent block"); +      foundBlock = true; +      I.setResult(MemDepResult::getUnknown()); +      Result.push_back( +          NonLocalDepResult(I.getBB(), I.getResult(), Pointer.getAddr())); +      break; +    } +    (void)foundBlock; (void)GotWorklistLimit; +    assert((foundBlock || GotWorklistLimit) && "Current block not in cache?"); +  } + +  // Okay, we're done now.  If we added new values to the cache, re-sort it. +  SortNonLocalDepInfoCache(*Cache, NumSortedEntries); +  LLVM_DEBUG(AssertSorted(*Cache)); +  return true; +} + +/// If P exists in CachedNonLocalPointerInfo or NonLocalDefsCache, remove it. +void MemoryDependenceResults::RemoveCachedNonLocalPointerDependencies( +    ValueIsLoadPair P) { + +  // Most of the time this cache is empty. 
+  if (!NonLocalDefsCache.empty()) { +    auto it = NonLocalDefsCache.find(P.getPointer()); +    if (it != NonLocalDefsCache.end()) { +      RemoveFromReverseMap(ReverseNonLocalDefsCache, +                           it->second.getResult().getInst(), P.getPointer()); +      NonLocalDefsCache.erase(it); +    } + +    if (auto *I = dyn_cast<Instruction>(P.getPointer())) { +      auto toRemoveIt = ReverseNonLocalDefsCache.find(I); +      if (toRemoveIt != ReverseNonLocalDefsCache.end()) { +        for (const auto &entry : toRemoveIt->second) +          NonLocalDefsCache.erase(entry); +        ReverseNonLocalDefsCache.erase(toRemoveIt); +      } +    } +  } + +  CachedNonLocalPointerInfo::iterator It = NonLocalPointerDeps.find(P); +  if (It == NonLocalPointerDeps.end()) +    return; + +  // Remove all of the entries in the BB->val map.  This involves removing +  // instructions from the reverse map. +  NonLocalDepInfo &PInfo = It->second.NonLocalDeps; + +  for (unsigned i = 0, e = PInfo.size(); i != e; ++i) { +    Instruction *Target = PInfo[i].getResult().getInst(); +    if (!Target) +      continue; // Ignore non-local dep results. +    assert(Target->getParent() == PInfo[i].getBB()); + +    // Eliminating the dirty entry from 'Cache', so update the reverse info. +    RemoveFromReverseMap(ReverseNonLocalPtrDeps, Target, P); +  } + +  // Remove P from NonLocalPointerDeps (which deletes NonLocalDepInfo). +  NonLocalPointerDeps.erase(It); +} + +void MemoryDependenceResults::invalidateCachedPointerInfo(Value *Ptr) { +  // If Ptr isn't really a pointer, just ignore it. +  if (!Ptr->getType()->isPointerTy()) +    return; +  // Flush store info for the pointer. +  RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, false)); +  // Flush load info for the pointer. +  RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, true)); +  // Invalidate phis that use the pointer. +  PV.invalidateValue(Ptr); +} + +void MemoryDependenceResults::invalidateCachedPredecessors() { +  PredCache.clear(); +} + +void MemoryDependenceResults::removeInstruction(Instruction *RemInst) { +  // Walk through the Non-local dependencies, removing this one as the value +  // for any cached queries. +  NonLocalDepMapType::iterator NLDI = NonLocalDeps.find(RemInst); +  if (NLDI != NonLocalDeps.end()) { +    NonLocalDepInfo &BlockMap = NLDI->second.first; +    for (auto &Entry : BlockMap) +      if (Instruction *Inst = Entry.getResult().getInst()) +        RemoveFromReverseMap(ReverseNonLocalDeps, Inst, RemInst); +    NonLocalDeps.erase(NLDI); +  } + +  // If we have a cached local dependence query for this instruction, remove it. +  LocalDepMapType::iterator LocalDepEntry = LocalDeps.find(RemInst); +  if (LocalDepEntry != LocalDeps.end()) { +    // Remove us from DepInst's reverse set now that the local dep info is gone. +    if (Instruction *Inst = LocalDepEntry->second.getInst()) +      RemoveFromReverseMap(ReverseLocalDeps, Inst, RemInst); + +    // Remove this local dependency info. +    LocalDeps.erase(LocalDepEntry); +  } + +  // If we have any cached pointer dependencies on this instruction, remove +  // them.  If the instruction has non-pointer type, then it can't be a pointer +  // base. + +  // Remove it from both the load info and the store info.  The instruction +  // can't be in either of these maps if it is non-pointer. 
+  if (RemInst->getType()->isPointerTy()) {
+    RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, false));
+    RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(RemInst, true));
+  }
+
+  // Loop over all of the things that depend on the instruction we're removing.
+  SmallVector<std::pair<Instruction *, Instruction *>, 8> ReverseDepsToAdd;
+
+  // If we find RemInst as a clobber or Def in any of the maps for other values,
+  // we need to replace its entry with a dirty version of the instruction after
+  // it.  If RemInst is a terminator, we use a null dirty value.
+  //
+  // Using a dirty version of the instruction after RemInst saves having to scan
+  // the entire block to get to this point.
+  MemDepResult NewDirtyVal;
+  if (!RemInst->isTerminator())
+    NewDirtyVal = MemDepResult::getDirty(&*++RemInst->getIterator());
+
+  ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst);
+  if (ReverseDepIt != ReverseLocalDeps.end()) {
+    // RemInst can't be the terminator if it has local stuff depending on it.
+    assert(!ReverseDepIt->second.empty() && !isa<TerminatorInst>(RemInst) &&
+           "Nothing can locally depend on a terminator");
+
+    for (Instruction *InstDependingOnRemInst : ReverseDepIt->second) {
+      assert(InstDependingOnRemInst != RemInst &&
+             "Already removed our local dep info");
+
+      LocalDeps[InstDependingOnRemInst] = NewDirtyVal;
+
+      // Make sure to remember that new things depend on NewDirtyVal's
+      // instruction.
+      assert(NewDirtyVal.getInst() &&
+             "There is no way something else can have "
+             "a local dep on this if it is a terminator!");
+      ReverseDepsToAdd.push_back(
+          std::make_pair(NewDirtyVal.getInst(), InstDependingOnRemInst));
+    }
+
+    ReverseLocalDeps.erase(ReverseDepIt);
+
+    // Add new reverse deps after scanning the set, to avoid invalidating the
+    // 'ReverseDeps' reference.
+    while (!ReverseDepsToAdd.empty()) {
+      ReverseLocalDeps[ReverseDepsToAdd.back().first].insert(
+          ReverseDepsToAdd.back().second);
+      ReverseDepsToAdd.pop_back();
+    }
+  }
+
+  ReverseDepIt = ReverseNonLocalDeps.find(RemInst);
+  if (ReverseDepIt != ReverseNonLocalDeps.end()) {
+    for (Instruction *I : ReverseDepIt->second) {
+      assert(I != RemInst && "Already removed NonLocalDep info for RemInst");
+
+      PerInstNLInfo &INLD = NonLocalDeps[I];
+      // The information is now dirty!
+      INLD.second = true;
+
+      for (auto &Entry : INLD.first) {
+        if (Entry.getResult().getInst() != RemInst)
+          continue;
+
+        // Convert to a dirty entry for the subsequent instruction.
+        Entry.setResult(NewDirtyVal);
+
+        if (Instruction *NextI = NewDirtyVal.getInst())
+          ReverseDepsToAdd.push_back(std::make_pair(NextI, I));
+      }
+    }
+
+    ReverseNonLocalDeps.erase(ReverseDepIt);
+
+    // Add new reverse deps after scanning the set, to avoid invalidating it.
+    while (!ReverseDepsToAdd.empty()) {
+      ReverseNonLocalDeps[ReverseDepsToAdd.back().first].insert(
+          ReverseDepsToAdd.back().second);
+      ReverseDepsToAdd.pop_back();
+    }
+  }
+
+  // If the instruction is in ReverseNonLocalPtrDeps then it appears as a
+  // value in the NonLocalPointerDeps info.
+  ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt = +      ReverseNonLocalPtrDeps.find(RemInst); +  if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) { +    SmallVector<std::pair<Instruction *, ValueIsLoadPair>, 8> +        ReversePtrDepsToAdd; + +    for (ValueIsLoadPair P : ReversePtrDepIt->second) { +      assert(P.getPointer() != RemInst && +             "Already removed NonLocalPointerDeps info for RemInst"); + +      NonLocalDepInfo &NLPDI = NonLocalPointerDeps[P].NonLocalDeps; + +      // The cache is not valid for any specific block anymore. +      NonLocalPointerDeps[P].Pair = BBSkipFirstBlockPair(); + +      // Update any entries for RemInst to use the instruction after it. +      for (auto &Entry : NLPDI) { +        if (Entry.getResult().getInst() != RemInst) +          continue; + +        // Convert to a dirty entry for the subsequent instruction. +        Entry.setResult(NewDirtyVal); + +        if (Instruction *NewDirtyInst = NewDirtyVal.getInst()) +          ReversePtrDepsToAdd.push_back(std::make_pair(NewDirtyInst, P)); +      } + +      // Re-sort the NonLocalDepInfo.  Changing the dirty entry to its +      // subsequent value may invalidate the sortedness. +      llvm::sort(NLPDI.begin(), NLPDI.end()); +    } + +    ReverseNonLocalPtrDeps.erase(ReversePtrDepIt); + +    while (!ReversePtrDepsToAdd.empty()) { +      ReverseNonLocalPtrDeps[ReversePtrDepsToAdd.back().first].insert( +          ReversePtrDepsToAdd.back().second); +      ReversePtrDepsToAdd.pop_back(); +    } +  } + +  // Invalidate phis that use the removed instruction. +  PV.invalidateValue(RemInst); + +  assert(!NonLocalDeps.count(RemInst) && "RemInst got reinserted?"); +  LLVM_DEBUG(verifyRemoved(RemInst)); +} + +/// Verify that the specified instruction does not occur in our internal data +/// structures. +/// +/// This function verifies by asserting in debug builds. 
+void MemoryDependenceResults::verifyRemoved(Instruction *D) const { +#ifndef NDEBUG +  for (const auto &DepKV : LocalDeps) { +    assert(DepKV.first != D && "Inst occurs in data structures"); +    assert(DepKV.second.getInst() != D && "Inst occurs in data structures"); +  } + +  for (const auto &DepKV : NonLocalPointerDeps) { +    assert(DepKV.first.getPointer() != D && "Inst occurs in NLPD map key"); +    for (const auto &Entry : DepKV.second.NonLocalDeps) +      assert(Entry.getResult().getInst() != D && "Inst occurs as NLPD value"); +  } + +  for (const auto &DepKV : NonLocalDeps) { +    assert(DepKV.first != D && "Inst occurs in data structures"); +    const PerInstNLInfo &INLD = DepKV.second; +    for (const auto &Entry : INLD.first) +      assert(Entry.getResult().getInst() != D && +             "Inst occurs in data structures"); +  } + +  for (const auto &DepKV : ReverseLocalDeps) { +    assert(DepKV.first != D && "Inst occurs in data structures"); +    for (Instruction *Inst : DepKV.second) +      assert(Inst != D && "Inst occurs in data structures"); +  } + +  for (const auto &DepKV : ReverseNonLocalDeps) { +    assert(DepKV.first != D && "Inst occurs in data structures"); +    for (Instruction *Inst : DepKV.second) +      assert(Inst != D && "Inst occurs in data structures"); +  } + +  for (const auto &DepKV : ReverseNonLocalPtrDeps) { +    assert(DepKV.first != D && "Inst occurs in rev NLPD map"); + +    for (ValueIsLoadPair P : DepKV.second) +      assert(P != ValueIsLoadPair(D, false) && P != ValueIsLoadPair(D, true) && +             "Inst occurs in ReverseNonLocalPtrDeps map"); +  } +#endif +} + +AnalysisKey MemoryDependenceAnalysis::Key; + +MemoryDependenceResults +MemoryDependenceAnalysis::run(Function &F, FunctionAnalysisManager &AM) { +  auto &AA = AM.getResult<AAManager>(F); +  auto &AC = AM.getResult<AssumptionAnalysis>(F); +  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &PV = AM.getResult<PhiValuesAnalysis>(F); +  return MemoryDependenceResults(AA, AC, TLI, DT, PV); +} + +char MemoryDependenceWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(MemoryDependenceWrapperPass, "memdep", +                      "Memory Dependence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PhiValuesWrapperPass) +INITIALIZE_PASS_END(MemoryDependenceWrapperPass, "memdep", +                    "Memory Dependence Analysis", false, true) + +MemoryDependenceWrapperPass::MemoryDependenceWrapperPass() : FunctionPass(ID) { +  initializeMemoryDependenceWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +MemoryDependenceWrapperPass::~MemoryDependenceWrapperPass() = default; + +void MemoryDependenceWrapperPass::releaseMemory() { +  MemDep.reset(); +} + +void MemoryDependenceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<AssumptionCacheTracker>(); +  AU.addRequired<DominatorTreeWrapperPass>(); +  AU.addRequired<PhiValuesWrapperPass>(); +  AU.addRequiredTransitive<AAResultsWrapperPass>(); +  AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); +} + +bool MemoryDependenceResults::invalidate(Function &F, const PreservedAnalyses &PA, +                               FunctionAnalysisManager::Invalidator &Inv) { +  // Check whether our analysis is preserved. 
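+  // As a rough sketch of the expected protocol (assuming a transformation
+  // pass that keeps memory dependences intact), the caller would report:
+  //   PreservedAnalyses PA = PreservedAnalyses::none();
+  //   PA.preserve<MemoryDependenceAnalysis>();
+  // in which case the checker obtained below treats this result as preserved.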
+  auto PAC = PA.getChecker<MemoryDependenceAnalysis>(); +  if (!PAC.preserved() && !PAC.preservedSet<AllAnalysesOn<Function>>()) +    // If not, give up now. +    return true; + +  // Check whether the analyses we depend on became invalid for any reason. +  if (Inv.invalidate<AAManager>(F, PA) || +      Inv.invalidate<AssumptionAnalysis>(F, PA) || +      Inv.invalidate<DominatorTreeAnalysis>(F, PA) || +      Inv.invalidate<PhiValuesAnalysis>(F, PA)) +    return true; + +  // Otherwise this analysis result remains valid. +  return false; +} + +unsigned MemoryDependenceResults::getDefaultBlockScanLimit() const { +  return BlockScanLimit; +} + +bool MemoryDependenceWrapperPass::runOnFunction(Function &F) { +  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +  auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); +  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  auto &PV = getAnalysis<PhiValuesWrapperPass>().getResult(); +  MemDep.emplace(AA, AC, TLI, DT, PV); +  return false; +} diff --git a/contrib/llvm/lib/Analysis/MemoryLocation.cpp b/contrib/llvm/lib/Analysis/MemoryLocation.cpp new file mode 100644 index 000000000000..55924db284ec --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemoryLocation.cpp @@ -0,0 +1,174 @@ +//===- MemoryLocation.cpp - Memory location descriptions -------------------==// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +using namespace llvm; + +MemoryLocation MemoryLocation::get(const LoadInst *LI) { +  AAMDNodes AATags; +  LI->getAAMetadata(AATags); +  const auto &DL = LI->getModule()->getDataLayout(); + +  return MemoryLocation(LI->getPointerOperand(), +                        DL.getTypeStoreSize(LI->getType()), AATags); +} + +MemoryLocation MemoryLocation::get(const StoreInst *SI) { +  AAMDNodes AATags; +  SI->getAAMetadata(AATags); +  const auto &DL = SI->getModule()->getDataLayout(); + +  return MemoryLocation(SI->getPointerOperand(), +                        DL.getTypeStoreSize(SI->getValueOperand()->getType()), +                        AATags); +} + +MemoryLocation MemoryLocation::get(const VAArgInst *VI) { +  AAMDNodes AATags; +  VI->getAAMetadata(AATags); + +  return MemoryLocation(VI->getPointerOperand(), UnknownSize, AATags); +} + +MemoryLocation MemoryLocation::get(const AtomicCmpXchgInst *CXI) { +  AAMDNodes AATags; +  CXI->getAAMetadata(AATags); +  const auto &DL = CXI->getModule()->getDataLayout(); + +  return MemoryLocation( +      CXI->getPointerOperand(), +      DL.getTypeStoreSize(CXI->getCompareOperand()->getType()), AATags); +} + +MemoryLocation MemoryLocation::get(const AtomicRMWInst *RMWI) { +  AAMDNodes AATags; +  RMWI->getAAMetadata(AATags); +  const auto &DL = RMWI->getModule()->getDataLayout(); + +  return MemoryLocation(RMWI->getPointerOperand(), +                        DL.getTypeStoreSize(RMWI->getValOperand()->getType()), +                        AATags); +} + +MemoryLocation 
MemoryLocation::getForSource(const MemTransferInst *MTI) { +  return getForSource(cast<AnyMemTransferInst>(MTI)); +} + +MemoryLocation MemoryLocation::getForSource(const AtomicMemTransferInst *MTI) { +  return getForSource(cast<AnyMemTransferInst>(MTI)); +} + +MemoryLocation MemoryLocation::getForSource(const AnyMemTransferInst *MTI) { +  uint64_t Size = UnknownSize; +  if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) +    Size = C->getValue().getZExtValue(); + +  // memcpy/memmove can have AA tags. For memcpy, they apply +  // to both the source and the destination. +  AAMDNodes AATags; +  MTI->getAAMetadata(AATags); + +  return MemoryLocation(MTI->getRawSource(), Size, AATags); +} + +MemoryLocation MemoryLocation::getForDest(const MemIntrinsic *MI) { +  return getForDest(cast<AnyMemIntrinsic>(MI)); +} + +MemoryLocation MemoryLocation::getForDest(const AtomicMemIntrinsic *MI) { +  return getForDest(cast<AnyMemIntrinsic>(MI)); +} + +MemoryLocation MemoryLocation::getForDest(const AnyMemIntrinsic *MI) { +  uint64_t Size = UnknownSize; +  if (ConstantInt *C = dyn_cast<ConstantInt>(MI->getLength())) +    Size = C->getValue().getZExtValue(); + +  // memcpy/memmove can have AA tags. For memcpy, they apply +  // to both the source and the destination. +  AAMDNodes AATags; +  MI->getAAMetadata(AATags); + +  return MemoryLocation(MI->getRawDest(), Size, AATags); +} + +MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, +                                              unsigned ArgIdx, +                                              const TargetLibraryInfo &TLI) { +  AAMDNodes AATags; +  CS->getAAMetadata(AATags); +  const Value *Arg = CS.getArgument(ArgIdx); + +  // We may be able to produce an exact size for known intrinsics. +  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { +    const DataLayout &DL = II->getModule()->getDataLayout(); + +    switch (II->getIntrinsicID()) { +    default: +      break; +    case Intrinsic::memset: +    case Intrinsic::memcpy: +    case Intrinsic::memmove: +      assert((ArgIdx == 0 || ArgIdx == 1) && +             "Invalid argument index for memory intrinsic"); +      if (ConstantInt *LenCI = dyn_cast<ConstantInt>(II->getArgOperand(2))) +        return MemoryLocation(Arg, LenCI->getZExtValue(), AATags); +      break; + +    case Intrinsic::lifetime_start: +    case Intrinsic::lifetime_end: +    case Intrinsic::invariant_start: +      assert(ArgIdx == 1 && "Invalid argument index"); +      return MemoryLocation( +          Arg, cast<ConstantInt>(II->getArgOperand(0))->getZExtValue(), AATags); + +    case Intrinsic::invariant_end: +      assert(ArgIdx == 2 && "Invalid argument index"); +      return MemoryLocation( +          Arg, cast<ConstantInt>(II->getArgOperand(1))->getZExtValue(), AATags); + +    case Intrinsic::arm_neon_vld1: +      assert(ArgIdx == 0 && "Invalid argument index"); +      // LLVM's vld1 and vst1 intrinsics currently only support a single +      // vector register. +      return MemoryLocation(Arg, DL.getTypeStoreSize(II->getType()), AATags); + +    case Intrinsic::arm_neon_vst1: +      assert(ArgIdx == 0 && "Invalid argument index"); +      return MemoryLocation( +          Arg, DL.getTypeStoreSize(II->getArgOperand(1)->getType()), AATags); +    } +  } + +  // We can bound the aliasing properties of memset_pattern16 just as we can +  // for memcpy/memset.  
This is particularly important because the +  // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 +  // whenever possible. +  LibFunc F; +  if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && +      F == LibFunc_memset_pattern16 && TLI.has(F)) { +    assert((ArgIdx == 0 || ArgIdx == 1) && +           "Invalid argument index for memset_pattern16"); +    if (ArgIdx == 1) +      return MemoryLocation(Arg, 16, AATags); +    if (const ConstantInt *LenCI = dyn_cast<ConstantInt>(CS.getArgument(2))) +      return MemoryLocation(Arg, LenCI->getZExtValue(), AATags); +  } +  // FIXME: Handle memset_pattern4 and memset_pattern8 also. + +  return MemoryLocation(CS.getArgument(ArgIdx), UnknownSize, AATags); +} diff --git a/contrib/llvm/lib/Analysis/MemorySSA.cpp b/contrib/llvm/lib/Analysis/MemorySSA.cpp new file mode 100644 index 000000000000..b38c0c4f1439 --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemorySSA.cpp @@ -0,0 +1,2191 @@ +//===- MemorySSA.cpp - Memory SSA Builder ---------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the MemorySSA class. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/AssemblyAnnotationWriter.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Use.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <memory> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "memoryssa" + +INITIALIZE_PASS_BEGIN(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, +                      true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(MemorySSAWrapperPass, "memoryssa", "Memory SSA", false, +                    true) + +INITIALIZE_PASS_BEGIN(MemorySSAPrinterLegacyPass, "print-memoryssa", +                      "Memory SSA Printer", false, false) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_END(MemorySSAPrinterLegacyPass, "print-memoryssa", 
+                    "Memory SSA Printer", false, false)
+
+static cl::opt<unsigned> MaxCheckLimit(
+    "memssa-check-limit", cl::Hidden, cl::init(100),
+    cl::desc("The maximum number of stores/phis MemorySSA "
+             "will consider trying to walk past (default = 100)"));
+
+static cl::opt<bool>
+    VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden,
+                    cl::desc("Verify MemorySSA in legacy printer pass."));
+
+namespace llvm {
+
+/// An assembly annotator class to print Memory SSA information in
+/// comments.
+class MemorySSAAnnotatedWriter : public AssemblyAnnotationWriter {
+  friend class MemorySSA;
+
+  const MemorySSA *MSSA;
+
+public:
+  MemorySSAAnnotatedWriter(const MemorySSA *M) : MSSA(M) {}
+
+  void emitBasicBlockStartAnnot(const BasicBlock *BB,
+                                formatted_raw_ostream &OS) override {
+    if (MemoryAccess *MA = MSSA->getMemoryAccess(BB))
+      OS << "; " << *MA << "\n";
+  }
+
+  void emitInstructionAnnot(const Instruction *I,
+                            formatted_raw_ostream &OS) override {
+    if (MemoryAccess *MA = MSSA->getMemoryAccess(I))
+      OS << "; " << *MA << "\n";
+  }
+};
+
+} // end namespace llvm
+
+namespace {
+
+/// Our current alias analysis API differentiates heavily between calls and
+/// non-calls, and functions called on one usually assert on the other.
+/// This class encapsulates the distinction to simplify other code that wants
+/// "Memory affecting instructions and related data" to use as a key.
+/// For example, this class is used as a DenseMap key in the use optimizer.
+class MemoryLocOrCall {
+public:
+  bool IsCall = false;
+
+  MemoryLocOrCall() = default;
+  MemoryLocOrCall(MemoryUseOrDef *MUD)
+      : MemoryLocOrCall(MUD->getMemoryInst()) {}
+  MemoryLocOrCall(const MemoryUseOrDef *MUD)
+      : MemoryLocOrCall(MUD->getMemoryInst()) {}
+
+  MemoryLocOrCall(Instruction *Inst) {
+    if (ImmutableCallSite(Inst)) {
+      IsCall = true;
+      CS = ImmutableCallSite(Inst);
+    } else {
+      IsCall = false;
+      // There is no such thing as a MemoryLocation for a fence inst, and it is
+      // unique in that regard.
+      if (!isa<FenceInst>(Inst)) +        Loc = MemoryLocation::get(Inst); +    } +  } + +  explicit MemoryLocOrCall(const MemoryLocation &Loc) : Loc(Loc) {} + +  ImmutableCallSite getCS() const { +    assert(IsCall); +    return CS; +  } + +  MemoryLocation getLoc() const { +    assert(!IsCall); +    return Loc; +  } + +  bool operator==(const MemoryLocOrCall &Other) const { +    if (IsCall != Other.IsCall) +      return false; + +    if (!IsCall) +      return Loc == Other.Loc; + +    if (CS.getCalledValue() != Other.CS.getCalledValue()) +      return false; + +    return CS.arg_size() == Other.CS.arg_size() && +           std::equal(CS.arg_begin(), CS.arg_end(), Other.CS.arg_begin()); +  } + +private: +  union { +    ImmutableCallSite CS; +    MemoryLocation Loc; +  }; +}; + +} // end anonymous namespace + +namespace llvm { + +template <> struct DenseMapInfo<MemoryLocOrCall> { +  static inline MemoryLocOrCall getEmptyKey() { +    return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getEmptyKey()); +  } + +  static inline MemoryLocOrCall getTombstoneKey() { +    return MemoryLocOrCall(DenseMapInfo<MemoryLocation>::getTombstoneKey()); +  } + +  static unsigned getHashValue(const MemoryLocOrCall &MLOC) { +    if (!MLOC.IsCall) +      return hash_combine( +          MLOC.IsCall, +          DenseMapInfo<MemoryLocation>::getHashValue(MLOC.getLoc())); + +    hash_code hash = +        hash_combine(MLOC.IsCall, DenseMapInfo<const Value *>::getHashValue( +                                      MLOC.getCS().getCalledValue())); + +    for (const Value *Arg : MLOC.getCS().args()) +      hash = hash_combine(hash, DenseMapInfo<const Value *>::getHashValue(Arg)); +    return hash; +  } + +  static bool isEqual(const MemoryLocOrCall &LHS, const MemoryLocOrCall &RHS) { +    return LHS == RHS; +  } +}; + +} // end namespace llvm + +/// This does one-way checks to see if Use could theoretically be hoisted above +/// MayClobber. This will not check the other way around. +/// +/// This assumes that, for the purposes of MemorySSA, Use comes directly after +/// MayClobber, with no potentially clobbering operations in between them. +/// (Where potentially clobbering ops are memory barriers, aliased stores, etc.) +static bool areLoadsReorderable(const LoadInst *Use, +                                const LoadInst *MayClobber) { +  bool VolatileUse = Use->isVolatile(); +  bool VolatileClobber = MayClobber->isVolatile(); +  // Volatile operations may never be reordered with other volatile operations. +  if (VolatileUse && VolatileClobber) +    return false; +  // Otherwise, volatile doesn't matter here. From the language reference: +  // 'optimizers may change the order of volatile operations relative to +  // non-volatile operations.'" + +  // If a load is seq_cst, it cannot be moved above other loads. If its ordering +  // is weaker, it can be moved above other loads. We just need to be sure that +  // MayClobber isn't an acquire load, because loads can't be moved above +  // acquire loads. +  // +  // Note that this explicitly *does* allow the free reordering of monotonic (or +  // weaker) loads of the same address. 
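+  //
+  // A rough illustration (hypothetical IR; MayClobber comes first in program
+  // order, Use directly after it):
+  //
+  //   %c = load atomic i32, i32* %p acquire, align 4    ; MayClobber
+  //   %u = load atomic i32, i32* %p monotonic, align 4  ; Use
+  //
+  // Here Use may not be hoisted above the acquire load, so we return false;
+  // if both loads were monotonic or weaker, we would return true.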
+  bool SeqCstUse = Use->getOrdering() == AtomicOrdering::SequentiallyConsistent; +  bool MayClobberIsAcquire = isAtLeastOrStrongerThan(MayClobber->getOrdering(), +                                                     AtomicOrdering::Acquire); +  return !(SeqCstUse || MayClobberIsAcquire); +} + +namespace { + +struct ClobberAlias { +  bool IsClobber; +  Optional<AliasResult> AR; +}; + +} // end anonymous namespace + +// Return a pair of {IsClobber (bool), AR (AliasResult)}. It relies on AR being +// ignored if IsClobber = false. +static ClobberAlias instructionClobbersQuery(MemoryDef *MD, +                                             const MemoryLocation &UseLoc, +                                             const Instruction *UseInst, +                                             AliasAnalysis &AA) { +  Instruction *DefInst = MD->getMemoryInst(); +  assert(DefInst && "Defining instruction not actually an instruction"); +  ImmutableCallSite UseCS(UseInst); +  Optional<AliasResult> AR; + +  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { +    // These intrinsics will show up as affecting memory, but they are just +    // markers, mostly. +    // +    // FIXME: We probably don't actually want MemorySSA to model these at all +    // (including creating MemoryAccesses for them): we just end up inventing +    // clobbers where they don't really exist at all. Please see D43269 for +    // context. +    switch (II->getIntrinsicID()) { +    case Intrinsic::lifetime_start: +      if (UseCS) +        return {false, NoAlias}; +      AR = AA.alias(MemoryLocation(II->getArgOperand(1)), UseLoc); +      return {AR != NoAlias, AR}; +    case Intrinsic::lifetime_end: +    case Intrinsic::invariant_start: +    case Intrinsic::invariant_end: +    case Intrinsic::assume: +      return {false, NoAlias}; +    default: +      break; +    } +  } + +  if (UseCS) { +    ModRefInfo I = AA.getModRefInfo(DefInst, UseCS); +    AR = isMustSet(I) ? MustAlias : MayAlias; +    return {isModOrRefSet(I), AR}; +  } + +  if (auto *DefLoad = dyn_cast<LoadInst>(DefInst)) +    if (auto *UseLoad = dyn_cast<LoadInst>(UseInst)) +      return {!areLoadsReorderable(UseLoad, DefLoad), MayAlias}; + +  ModRefInfo I = AA.getModRefInfo(DefInst, UseLoc); +  AR = isMustSet(I) ? MustAlias : MayAlias; +  return {isModSet(I), AR}; +} + +static ClobberAlias instructionClobbersQuery(MemoryDef *MD, +                                             const MemoryUseOrDef *MU, +                                             const MemoryLocOrCall &UseMLOC, +                                             AliasAnalysis &AA) { +  // FIXME: This is a temporary hack to allow a single instructionClobbersQuery +  // to exist while MemoryLocOrCall is pushed through places. +  if (UseMLOC.IsCall) +    return instructionClobbersQuery(MD, MemoryLocation(), MU->getMemoryInst(), +                                    AA); +  return instructionClobbersQuery(MD, UseMLOC.getLoc(), MU->getMemoryInst(), +                                  AA); +} + +// Return true when MD may alias MU, return false otherwise. +bool MemorySSAUtil::defClobbersUseOrDef(MemoryDef *MD, const MemoryUseOrDef *MU, +                                        AliasAnalysis &AA) { +  return instructionClobbersQuery(MD, MU, MemoryLocOrCall(MU), AA).IsClobber; +} + +namespace { + +struct UpwardsMemoryQuery { +  // True if our original query started off as a call +  bool IsCall = false; +  // The pointer location we started the query with. This will be empty if +  // IsCall is true. 
+  MemoryLocation StartingLoc;
+  // This is the instruction we were querying about.
+  const Instruction *Inst = nullptr;
+  // The MemoryAccess we actually got called with, used to test local domination.
+  const MemoryAccess *OriginalAccess = nullptr;
+  Optional<AliasResult> AR = MayAlias;
+
+  UpwardsMemoryQuery() = default;
+
+  UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access)
+      : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) {
+    if (!IsCall)
+      StartingLoc = MemoryLocation::get(Inst);
+  }
+};
+
+} // end anonymous namespace
+
+static bool lifetimeEndsAt(MemoryDef *MD, const MemoryLocation &Loc,
+                           AliasAnalysis &AA) {
+  Instruction *Inst = MD->getMemoryInst();
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::lifetime_end:
+      return AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), Loc);
+    default:
+      return false;
+    }
+  }
+  return false;
+}
+
+static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA,
+                                                   const Instruction *I) {
+  // If the memory can't be changed, then loads of the memory can't be
+  // clobbered.
+  return isa<LoadInst>(I) && (I->getMetadata(LLVMContext::MD_invariant_load) ||
+                              AA.pointsToConstantMemory(cast<LoadInst>(I)->
+                                                          getPointerOperand()));
+}
+
+/// Verifies that `Start` is clobbered by `ClobberAt`, and that nothing
+/// in between `Start` and `ClobberAt` clobbers `Start`.
+///
+/// This is meant to be as simple and self-contained as possible. Because it
+/// uses no cache, etc., it can be relatively expensive.
+///
+/// \param Start     The MemoryAccess that we want to walk from.
+/// \param ClobberAt A clobber for Start.
+/// \param StartLoc  The MemoryLocation for Start.
+/// \param MSSA      The MemorySSA instance that Start and ClobberAt belong to.
+/// \param Query     The UpwardsMemoryQuery we used for our search.
+/// \param AA        The AliasAnalysis we used for our search.
+static void LLVM_ATTRIBUTE_UNUSED
+checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt,
+                   const MemoryLocation &StartLoc, const MemorySSA &MSSA,
+                   const UpwardsMemoryQuery &Query, AliasAnalysis &AA) {
+  assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?");
+
+  if (MSSA.isLiveOnEntryDef(Start)) {
+    assert(MSSA.isLiveOnEntryDef(ClobberAt) &&
+           "liveOnEntry must clobber itself");
+    return;
+  }
+
+  bool FoundClobber = false;
+  DenseSet<MemoryAccessPair> VisitedPhis;
+  SmallVector<MemoryAccessPair, 8> Worklist;
+  Worklist.emplace_back(Start, StartLoc);
+  // Walk all paths from Start to ClobberAt, while looking for clobbers. If one
+  // is found, complain.
+  while (!Worklist.empty()) {
+    MemoryAccessPair MAP = Worklist.pop_back_val();
+    // All we care about is that nothing from Start to ClobberAt clobbers Start.
+    // We learn nothing from revisiting nodes.
+    if (!VisitedPhis.insert(MAP).second)
+      continue;
+
+    for (MemoryAccess *MA : def_chain(MAP.first)) {
+      if (MA == ClobberAt) {
+        if (auto *MD = dyn_cast<MemoryDef>(MA)) {
+          // instructionClobbersQuery isn't essentially free, so don't use `|=`,
+          // since it won't let us short-circuit.
+          // +          // Also, note that this can't be hoisted out of the `Worklist` loop, +          // since MD may only act as a clobber for 1 of N MemoryLocations. +          FoundClobber = FoundClobber || MSSA.isLiveOnEntryDef(MD); +          if (!FoundClobber) { +            ClobberAlias CA = +                instructionClobbersQuery(MD, MAP.second, Query.Inst, AA); +            if (CA.IsClobber) { +              FoundClobber = true; +              // Not used: CA.AR; +            } +          } +        } +        break; +      } + +      // We should never hit liveOnEntry, unless it's the clobber. +      assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?"); + +      if (auto *MD = dyn_cast<MemoryDef>(MA)) { +        (void)MD; +        assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) +                    .IsClobber && +               "Found clobber before reaching ClobberAt!"); +        continue; +      } + +      assert(isa<MemoryPhi>(MA)); +      Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end()); +    } +  } + +  // If ClobberAt is a MemoryPhi, we can assume something above it acted as a +  // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point. +  assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) && +         "ClobberAt never acted as a clobber"); +} + +namespace { + +/// Our algorithm for walking (and trying to optimize) clobbers, all wrapped up +/// in one class. +class ClobberWalker { +  /// Save a few bytes by using unsigned instead of size_t. +  using ListIndex = unsigned; + +  /// Represents a span of contiguous MemoryDefs, potentially ending in a +  /// MemoryPhi. +  struct DefPath { +    MemoryLocation Loc; +    // Note that, because we always walk in reverse, Last will always dominate +    // First. Also note that First and Last are inclusive. +    MemoryAccess *First; +    MemoryAccess *Last; +    Optional<ListIndex> Previous; + +    DefPath(const MemoryLocation &Loc, MemoryAccess *First, MemoryAccess *Last, +            Optional<ListIndex> Previous) +        : Loc(Loc), First(First), Last(Last), Previous(Previous) {} + +    DefPath(const MemoryLocation &Loc, MemoryAccess *Init, +            Optional<ListIndex> Previous) +        : DefPath(Loc, Init, Init, Previous) {} +  }; + +  const MemorySSA &MSSA; +  AliasAnalysis &AA; +  DominatorTree &DT; +  UpwardsMemoryQuery *Query; + +  // Phi optimization bookkeeping +  SmallVector<DefPath, 32> Paths; +  DenseSet<ConstMemoryAccessPair> VisitedPhis; + +  /// Find the nearest def or phi that `From` can legally be optimized to. +  const MemoryAccess *getWalkTarget(const MemoryPhi *From) const { +    assert(From->getNumOperands() && "Phi with no operands?"); + +    BasicBlock *BB = From->getBlock(); +    MemoryAccess *Result = MSSA.getLiveOnEntryDef(); +    DomTreeNode *Node = DT.getNode(BB); +    while ((Node = Node->getIDom())) { +      auto *Defs = MSSA.getBlockDefs(Node->getBlock()); +      if (Defs) +        return &*Defs->rbegin(); +    } +    return Result; +  } + +  /// Result of calling walkToPhiOrClobber. +  struct UpwardsWalkResult { +    /// The "Result" of the walk. Either a clobber, the last thing we walked, or +    /// both. Include alias info when clobber found. +    MemoryAccess *Result; +    bool IsKnownClobber; +    Optional<AliasResult> AR; +  }; + +  /// Walk to the next Phi or Clobber in the def chain starting at Desc.Last. +  /// This will update Desc.Last as it walks. It will (optionally) also stop at +  /// StopAt. 
+  /// +  /// This does not test for whether StopAt is a clobber +  UpwardsWalkResult +  walkToPhiOrClobber(DefPath &Desc, +                     const MemoryAccess *StopAt = nullptr) const { +    assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); + +    for (MemoryAccess *Current : def_chain(Desc.Last)) { +      Desc.Last = Current; +      if (Current == StopAt) +        return {Current, false, MayAlias}; + +      if (auto *MD = dyn_cast<MemoryDef>(Current)) { +        if (MSSA.isLiveOnEntryDef(MD)) +          return {MD, true, MustAlias}; +        ClobberAlias CA = +            instructionClobbersQuery(MD, Desc.Loc, Query->Inst, AA); +        if (CA.IsClobber) +          return {MD, true, CA.AR}; +      } +    } + +    assert(isa<MemoryPhi>(Desc.Last) && +           "Ended at a non-clobber that's not a phi?"); +    return {Desc.Last, false, MayAlias}; +  } + +  void addSearches(MemoryPhi *Phi, SmallVectorImpl<ListIndex> &PausedSearches, +                   ListIndex PriorNode) { +    auto UpwardDefs = make_range(upward_defs_begin({Phi, Paths[PriorNode].Loc}), +                                 upward_defs_end()); +    for (const MemoryAccessPair &P : UpwardDefs) { +      PausedSearches.push_back(Paths.size()); +      Paths.emplace_back(P.second, P.first, PriorNode); +    } +  } + +  /// Represents a search that terminated after finding a clobber. This clobber +  /// may or may not be present in the path of defs from LastNode..SearchStart, +  /// since it may have been retrieved from cache. +  struct TerminatedPath { +    MemoryAccess *Clobber; +    ListIndex LastNode; +  }; + +  /// Get an access that keeps us from optimizing to the given phi. +  /// +  /// PausedSearches is an array of indices into the Paths array. Its incoming +  /// value is the indices of searches that stopped at the last phi optimization +  /// target. It's left in an unspecified state. +  /// +  /// If this returns None, NewPaused is a vector of searches that terminated +  /// at StopWhere. Otherwise, NewPaused is left in an unspecified state. +  Optional<TerminatedPath> +  getBlockingAccess(const MemoryAccess *StopWhere, +                    SmallVectorImpl<ListIndex> &PausedSearches, +                    SmallVectorImpl<ListIndex> &NewPaused, +                    SmallVectorImpl<TerminatedPath> &Terminated) { +    assert(!PausedSearches.empty() && "No searches to continue?"); + +    // BFS vs DFS really doesn't make a difference here, so just do a DFS with +    // PausedSearches as our stack. +    while (!PausedSearches.empty()) { +      ListIndex PathIndex = PausedSearches.pop_back_val(); +      DefPath &Node = Paths[PathIndex]; + +      // If we've already visited this path with this MemoryLocation, we don't +      // need to do so again. +      // +      // NOTE: That we just drop these paths on the ground makes caching +      // behavior sporadic. e.g. given a diamond: +      //  A +      // B C +      //  D +      // +      // ...If we walk D, B, A, C, we'll only cache the result of phi +      // optimization for A, B, and D; C will be skipped because it dies here. +      // This arguably isn't the worst thing ever, since: +      //   - We generally query things in a top-down order, so if we got below D +      //     without needing cache entries for {C, MemLoc}, then chances are +      //     that those cache entries would end up ultimately unused. +      //   - We still cache things for A, so C only needs to walk up a bit. 
+      // If this behavior becomes problematic, we can fix without a ton of extra +      // work. +      if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) +        continue; + +      UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere); +      if (Res.IsKnownClobber) { +        assert(Res.Result != StopWhere); +        // If this wasn't a cache hit, we hit a clobber when walking. That's a +        // failure. +        TerminatedPath Term{Res.Result, PathIndex}; +        if (!MSSA.dominates(Res.Result, StopWhere)) +          return Term; + +        // Otherwise, it's a valid thing to potentially optimize to. +        Terminated.push_back(Term); +        continue; +      } + +      if (Res.Result == StopWhere) { +        // We've hit our target. Save this path off for if we want to continue +        // walking. +        NewPaused.push_back(PathIndex); +        continue; +      } + +      assert(!MSSA.isLiveOnEntryDef(Res.Result) && "liveOnEntry is a clobber"); +      addSearches(cast<MemoryPhi>(Res.Result), PausedSearches, PathIndex); +    } + +    return None; +  } + +  template <typename T, typename Walker> +  struct generic_def_path_iterator +      : public iterator_facade_base<generic_def_path_iterator<T, Walker>, +                                    std::forward_iterator_tag, T *> { +    generic_def_path_iterator() = default; +    generic_def_path_iterator(Walker *W, ListIndex N) : W(W), N(N) {} + +    T &operator*() const { return curNode(); } + +    generic_def_path_iterator &operator++() { +      N = curNode().Previous; +      return *this; +    } + +    bool operator==(const generic_def_path_iterator &O) const { +      if (N.hasValue() != O.N.hasValue()) +        return false; +      return !N.hasValue() || *N == *O.N; +    } + +  private: +    T &curNode() const { return W->Paths[*N]; } + +    Walker *W = nullptr; +    Optional<ListIndex> N = None; +  }; + +  using def_path_iterator = generic_def_path_iterator<DefPath, ClobberWalker>; +  using const_def_path_iterator = +      generic_def_path_iterator<const DefPath, const ClobberWalker>; + +  iterator_range<def_path_iterator> def_path(ListIndex From) { +    return make_range(def_path_iterator(this, From), def_path_iterator()); +  } + +  iterator_range<const_def_path_iterator> const_def_path(ListIndex From) const { +    return make_range(const_def_path_iterator(this, From), +                      const_def_path_iterator()); +  } + +  struct OptznResult { +    /// The path that contains our result. +    TerminatedPath PrimaryClobber; +    /// The paths that we can legally cache back from, but that aren't +    /// necessarily the result of the Phi optimization. +    SmallVector<TerminatedPath, 4> OtherClobbers; +  }; + +  ListIndex defPathIndex(const DefPath &N) const { +    // The assert looks nicer if we don't need to do &N +    const DefPath *NP = &N; +    assert(!Paths.empty() && NP >= &Paths.front() && NP <= &Paths.back() && +           "Out of bounds DefPath!"); +    return NP - &Paths.front(); +  } + +  /// Try to optimize a phi as best as we can. Returns a SmallVector of Paths +  /// that act as legal clobbers. Note that this won't return *all* clobbers. +  /// +  /// Phi optimization algorithm tl;dr: +  ///   - Find the earliest def/phi, A, we can optimize to +  ///   - Find if all paths from the starting memory access ultimately reach A +  ///     - If not, optimization isn't possible. +  ///     - Otherwise, walk from A to another clobber or phi, A'. +  ///       - If A' is a def, we're done. 
+  ///       - If A' is a phi, try to optimize it. +  /// +  /// A path is a series of {MemoryAccess, MemoryLocation} pairs. A path +  /// terminates when a MemoryAccess that clobbers said MemoryLocation is found. +  OptznResult tryOptimizePhi(MemoryPhi *Phi, MemoryAccess *Start, +                             const MemoryLocation &Loc) { +    assert(Paths.empty() && VisitedPhis.empty() && +           "Reset the optimization state."); + +    Paths.emplace_back(Loc, Start, Phi, None); +    // Stores how many "valid" optimization nodes we had prior to calling +    // addSearches/getBlockingAccess. Necessary for caching if we had a blocker. +    auto PriorPathsSize = Paths.size(); + +    SmallVector<ListIndex, 16> PausedSearches; +    SmallVector<ListIndex, 8> NewPaused; +    SmallVector<TerminatedPath, 4> TerminatedPaths; + +    addSearches(Phi, PausedSearches, 0); + +    // Moves the TerminatedPath with the "most dominated" Clobber to the end of +    // Paths. +    auto MoveDominatedPathToEnd = [&](SmallVectorImpl<TerminatedPath> &Paths) { +      assert(!Paths.empty() && "Need a path to move"); +      auto Dom = Paths.begin(); +      for (auto I = std::next(Dom), E = Paths.end(); I != E; ++I) +        if (!MSSA.dominates(I->Clobber, Dom->Clobber)) +          Dom = I; +      auto Last = Paths.end() - 1; +      if (Last != Dom) +        std::iter_swap(Last, Dom); +    }; + +    MemoryPhi *Current = Phi; +    while (true) { +      assert(!MSSA.isLiveOnEntryDef(Current) && +             "liveOnEntry wasn't treated as a clobber?"); + +      const auto *Target = getWalkTarget(Current); +      // If a TerminatedPath doesn't dominate Target, then it wasn't a legal +      // optimization for the prior phi. +      assert(all_of(TerminatedPaths, [&](const TerminatedPath &P) { +        return MSSA.dominates(P.Clobber, Target); +      })); + +      // FIXME: This is broken, because the Blocker may be reported to be +      // liveOnEntry, and we'll happily wait for that to disappear (read: never) +      // For the moment, this is fine, since we do nothing with blocker info. +      if (Optional<TerminatedPath> Blocker = getBlockingAccess( +              Target, PausedSearches, NewPaused, TerminatedPaths)) { + +        // Find the node we started at. We can't search based on N->Last, since +        // we may have gone around a loop with a different MemoryLocation. +        auto Iter = find_if(def_path(Blocker->LastNode), [&](const DefPath &N) { +          return defPathIndex(N) < PriorPathsSize; +        }); +        assert(Iter != def_path_iterator()); + +        DefPath &CurNode = *Iter; +        assert(CurNode.Last == Current); + +        // Two things: +        // A. We can't reliably cache all of NewPaused back. Consider a case +        //    where we have two paths in NewPaused; one of which can't optimize +        //    above this phi, whereas the other can. If we cache the second path +        //    back, we'll end up with suboptimal cache entries. We can handle +        //    cases like this a bit better when we either try to find all +        //    clobbers that block phi optimization, or when our cache starts +        //    supporting unfinished searches. +        // B. 
We can't reliably cache TerminatedPaths back here without doing +        //    extra checks; consider a case like: +        //       T +        //      / \ +        //     D   C +        //      \ / +        //       S +        //    Where T is our target, C is a node with a clobber on it, D is a +        //    diamond (with a clobber *only* on the left or right node, N), and +        //    S is our start. Say we walk to D, through the node opposite N +        //    (read: ignoring the clobber), and see a cache entry in the top +        //    node of D. That cache entry gets put into TerminatedPaths. We then +        //    walk up to C (N is later in our worklist), find the clobber, and +        //    quit. If we append TerminatedPaths to OtherClobbers, we'll cache +        //    the bottom part of D to the cached clobber, ignoring the clobber +        //    in N. Again, this problem goes away if we start tracking all +        //    blockers for a given phi optimization. +        TerminatedPath Result{CurNode.Last, defPathIndex(CurNode)}; +        return {Result, {}}; +      } + +      // If there's nothing left to search, then all paths led to valid clobbers +      // that we got from our cache; pick the nearest to the start, and allow +      // the rest to be cached back. +      if (NewPaused.empty()) { +        MoveDominatedPathToEnd(TerminatedPaths); +        TerminatedPath Result = TerminatedPaths.pop_back_val(); +        return {Result, std::move(TerminatedPaths)}; +      } + +      MemoryAccess *DefChainEnd = nullptr; +      SmallVector<TerminatedPath, 4> Clobbers; +      for (ListIndex Paused : NewPaused) { +        UpwardsWalkResult WR = walkToPhiOrClobber(Paths[Paused]); +        if (WR.IsKnownClobber) +          Clobbers.push_back({WR.Result, Paused}); +        else +          // Micro-opt: If we hit the end of the chain, save it. +          DefChainEnd = WR.Result; +      } + +      if (!TerminatedPaths.empty()) { +        // If we couldn't find the dominating phi/liveOnEntry in the above loop, +        // do it now. +        if (!DefChainEnd) +          for (auto *MA : def_chain(const_cast<MemoryAccess *>(Target))) +            DefChainEnd = MA; + +        // If any of the terminated paths don't dominate the phi we'll try to +        // optimize, we need to figure out what they are and quit. +        const BasicBlock *ChainBB = DefChainEnd->getBlock(); +        for (const TerminatedPath &TP : TerminatedPaths) { +          // Because we know that DefChainEnd is as "high" as we can go, we +          // don't need local dominance checks; BB dominance is sufficient. +          if (DT.dominates(ChainBB, TP.Clobber->getBlock())) +            Clobbers.push_back(TP); +        } +      } + +      // If we have clobbers in the def chain, find the one closest to Current +      // and quit. +      if (!Clobbers.empty()) { +        MoveDominatedPathToEnd(Clobbers); +        TerminatedPath Result = Clobbers.pop_back_val(); +        return {Result, std::move(Clobbers)}; +      } + +      assert(all_of(NewPaused, +                    [&](ListIndex I) { return Paths[I].Last == DefChainEnd; })); + +      // Because liveOnEntry is a clobber, this must be a phi. 
+      auto *DefChainPhi = cast<MemoryPhi>(DefChainEnd); + +      PriorPathsSize = Paths.size(); +      PausedSearches.clear(); +      for (ListIndex I : NewPaused) +        addSearches(DefChainPhi, PausedSearches, I); +      NewPaused.clear(); + +      Current = DefChainPhi; +    } +  } + +  void verifyOptResult(const OptznResult &R) const { +    assert(all_of(R.OtherClobbers, [&](const TerminatedPath &P) { +      return MSSA.dominates(P.Clobber, R.PrimaryClobber.Clobber); +    })); +  } + +  void resetPhiOptznState() { +    Paths.clear(); +    VisitedPhis.clear(); +  } + +public: +  ClobberWalker(const MemorySSA &MSSA, AliasAnalysis &AA, DominatorTree &DT) +      : MSSA(MSSA), AA(AA), DT(DT) {} + +  /// Finds the nearest clobber for the given query, optimizing phis if +  /// possible. +  MemoryAccess *findClobber(MemoryAccess *Start, UpwardsMemoryQuery &Q) { +    Query = &Q; + +    MemoryAccess *Current = Start; +    // This walker pretends uses don't exist. If we're handed one, silently grab +    // its def. (This has the nice side-effect of ensuring we never cache uses) +    if (auto *MU = dyn_cast<MemoryUse>(Start)) +      Current = MU->getDefiningAccess(); + +    DefPath FirstDesc(Q.StartingLoc, Current, Current, None); +    // Fast path for the overly-common case (no crazy phi optimization +    // necessary) +    UpwardsWalkResult WalkResult = walkToPhiOrClobber(FirstDesc); +    MemoryAccess *Result; +    if (WalkResult.IsKnownClobber) { +      Result = WalkResult.Result; +      Q.AR = WalkResult.AR; +    } else { +      OptznResult OptRes = tryOptimizePhi(cast<MemoryPhi>(FirstDesc.Last), +                                          Current, Q.StartingLoc); +      verifyOptResult(OptRes); +      resetPhiOptznState(); +      Result = OptRes.PrimaryClobber.Clobber; +    } + +#ifdef EXPENSIVE_CHECKS +    checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); +#endif +    return Result; +  } + +  void verify(const MemorySSA *MSSA) { assert(MSSA == &this->MSSA); } +}; + +struct RenamePassData { +  DomTreeNode *DTN; +  DomTreeNode::const_iterator ChildIt; +  MemoryAccess *IncomingVal; + +  RenamePassData(DomTreeNode *D, DomTreeNode::const_iterator It, +                 MemoryAccess *M) +      : DTN(D), ChildIt(It), IncomingVal(M) {} + +  void swap(RenamePassData &RHS) { +    std::swap(DTN, RHS.DTN); +    std::swap(ChildIt, RHS.ChildIt); +    std::swap(IncomingVal, RHS.IncomingVal); +  } +}; + +} // end anonymous namespace + +namespace llvm { + +/// A MemorySSAWalker that does AA walks to disambiguate accesses. It no +/// longer does caching on its own, but the name has been retained for the +/// moment. 
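+///
+/// As a rough usage sketch (assuming an already-built MemorySSA &MSSA and a
+/// MemoryUseOrDef *MA for some load or store; the names are illustrative):
+///
+///   MemorySSAWalker *W = MSSA.getWalker();
+///   MemoryAccess *Clobber = W->getClobberingMemoryAccess(MA);
+///
+/// The result is liveOnEntry, a MemoryDef, or a MemoryPhi that the walk could
+/// not see past.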
+class MemorySSA::CachingWalker final : public MemorySSAWalker { +  ClobberWalker Walker; + +  MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); + +public: +  CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); +  ~CachingWalker() override = default; + +  using MemorySSAWalker::getClobberingMemoryAccess; + +  MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override; +  MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, +                                          const MemoryLocation &) override; +  void invalidateInfo(MemoryAccess *) override; + +  void verify(const MemorySSA *MSSA) override { +    MemorySSAWalker::verify(MSSA); +    Walker.verify(MSSA); +  } +}; + +} // end namespace llvm + +void MemorySSA::renameSuccessorPhis(BasicBlock *BB, MemoryAccess *IncomingVal, +                                    bool RenameAllUses) { +  // Pass through values to our successors +  for (const BasicBlock *S : successors(BB)) { +    auto It = PerBlockAccesses.find(S); +    // Rename the phi nodes in our successor block +    if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) +      continue; +    AccessList *Accesses = It->second.get(); +    auto *Phi = cast<MemoryPhi>(&Accesses->front()); +    if (RenameAllUses) { +      int PhiIndex = Phi->getBasicBlockIndex(BB); +      assert(PhiIndex != -1 && "Incomplete phi during partial rename"); +      Phi->setIncomingValue(PhiIndex, IncomingVal); +    } else +      Phi->addIncoming(IncomingVal, BB); +  } +} + +/// Rename a single basic block into MemorySSA form. +/// Uses the standard SSA renaming algorithm. +/// \returns The new incoming value. +MemoryAccess *MemorySSA::renameBlock(BasicBlock *BB, MemoryAccess *IncomingVal, +                                     bool RenameAllUses) { +  auto It = PerBlockAccesses.find(BB); +  // Skip most processing if the list is empty. +  if (It != PerBlockAccesses.end()) { +    AccessList *Accesses = It->second.get(); +    for (MemoryAccess &L : *Accesses) { +      if (MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&L)) { +        if (MUD->getDefiningAccess() == nullptr || RenameAllUses) +          MUD->setDefiningAccess(IncomingVal); +        if (isa<MemoryDef>(&L)) +          IncomingVal = &L; +      } else { +        IncomingVal = &L; +      } +    } +  } +  return IncomingVal; +} + +/// This is the standard SSA renaming algorithm. +/// +/// We walk the dominator tree in preorder, renaming accesses, and then filling +/// in phi nodes in our successors. +void MemorySSA::renamePass(DomTreeNode *Root, MemoryAccess *IncomingVal, +                           SmallPtrSetImpl<BasicBlock *> &Visited, +                           bool SkipVisited, bool RenameAllUses) { +  SmallVector<RenamePassData, 32> WorkStack; +  // Skip everything if we already renamed this block and we are skipping. +  // Note: You can't sink this into the if, because we need it to occur +  // regardless of whether we skip blocks or not. 
+  bool AlreadyVisited = !Visited.insert(Root->getBlock()).second; +  if (SkipVisited && AlreadyVisited) +    return; + +  IncomingVal = renameBlock(Root->getBlock(), IncomingVal, RenameAllUses); +  renameSuccessorPhis(Root->getBlock(), IncomingVal, RenameAllUses); +  WorkStack.push_back({Root, Root->begin(), IncomingVal}); + +  while (!WorkStack.empty()) { +    DomTreeNode *Node = WorkStack.back().DTN; +    DomTreeNode::const_iterator ChildIt = WorkStack.back().ChildIt; +    IncomingVal = WorkStack.back().IncomingVal; + +    if (ChildIt == Node->end()) { +      WorkStack.pop_back(); +    } else { +      DomTreeNode *Child = *ChildIt; +      ++WorkStack.back().ChildIt; +      BasicBlock *BB = Child->getBlock(); +      // Note: You can't sink this into the if, because we need it to occur +      // regardless of whether we skip blocks or not. +      AlreadyVisited = !Visited.insert(BB).second; +      if (SkipVisited && AlreadyVisited) { +        // We already visited this during our renaming, which can happen when +        // being asked to rename multiple blocks. Figure out the incoming val, +        // which is the last def. +        // Incoming value can only change if there is a block def, and in that +        // case, it's the last block def in the list. +        if (auto *BlockDefs = getWritableBlockDefs(BB)) +          IncomingVal = &*BlockDefs->rbegin(); +      } else +        IncomingVal = renameBlock(BB, IncomingVal, RenameAllUses); +      renameSuccessorPhis(BB, IncomingVal, RenameAllUses); +      WorkStack.push_back({Child, Child->begin(), IncomingVal}); +    } +  } +} + +/// This handles unreachable block accesses by deleting phi nodes in +/// unreachable blocks, and marking all other unreachable MemoryAccess's as +/// being uses of the live on entry definition. +void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { +  assert(!DT->isReachableFromEntry(BB) && +         "Reachable block found while handling unreachable blocks"); + +  // Make sure phi nodes in our reachable successors end up with a +  // LiveOnEntryDef for our incoming edge, even though our block is forward +  // unreachable.  We could just disconnect these blocks from the CFG fully, +  // but we do not right now. +  for (const BasicBlock *S : successors(BB)) { +    if (!DT->isReachableFromEntry(S)) +      continue; +    auto It = PerBlockAccesses.find(S); +    // Rename the phi nodes in our successor block +    if (It == PerBlockAccesses.end() || !isa<MemoryPhi>(It->second->front())) +      continue; +    AccessList *Accesses = It->second.get(); +    auto *Phi = cast<MemoryPhi>(&Accesses->front()); +    Phi->addIncoming(LiveOnEntryDef.get(), BB); +  } + +  auto It = PerBlockAccesses.find(BB); +  if (It == PerBlockAccesses.end()) +    return; + +  auto &Accesses = It->second; +  for (auto AI = Accesses->begin(), AE = Accesses->end(); AI != AE;) { +    auto Next = std::next(AI); +    // If we have a phi, just remove it. We are going to replace all +    // users with live on entry. 
+    if (auto *UseOrDef = dyn_cast<MemoryUseOrDef>(AI)) +      UseOrDef->setDefiningAccess(LiveOnEntryDef.get()); +    else +      Accesses->erase(AI); +    AI = Next; +  } +} + +MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) +    : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), +      NextID(0) { +  buildMemorySSA(); +} + +MemorySSA::~MemorySSA() { +  // Drop all our references +  for (const auto &Pair : PerBlockAccesses) +    for (MemoryAccess &MA : *Pair.second) +      MA.dropAllReferences(); +} + +MemorySSA::AccessList *MemorySSA::getOrCreateAccessList(const BasicBlock *BB) { +  auto Res = PerBlockAccesses.insert(std::make_pair(BB, nullptr)); + +  if (Res.second) +    Res.first->second = llvm::make_unique<AccessList>(); +  return Res.first->second.get(); +} + +MemorySSA::DefsList *MemorySSA::getOrCreateDefsList(const BasicBlock *BB) { +  auto Res = PerBlockDefs.insert(std::make_pair(BB, nullptr)); + +  if (Res.second) +    Res.first->second = llvm::make_unique<DefsList>(); +  return Res.first->second.get(); +} + +namespace llvm { + +/// This class is a batch walker of all MemoryUse's in the program, and points +/// their defining access at the thing that actually clobbers them.  Because it +/// is a batch walker that touches everything, it does not operate like the +/// other walkers.  This walker is basically performing a top-down SSA renaming +/// pass, where the version stack is used as the cache.  This enables it to be +/// significantly more time and memory efficient than using the regular walker, +/// which is walking bottom-up. +class MemorySSA::OptimizeUses { +public: +  OptimizeUses(MemorySSA *MSSA, MemorySSAWalker *Walker, AliasAnalysis *AA, +               DominatorTree *DT) +      : MSSA(MSSA), Walker(Walker), AA(AA), DT(DT) { +    Walker = MSSA->getWalker(); +  } + +  void optimizeUses(); + +private: +  /// This represents where a given memorylocation is in the stack. +  struct MemlocStackInfo { +    // This essentially is keeping track of versions of the stack. Whenever +    // the stack changes due to pushes or pops, these versions increase. +    unsigned long StackEpoch; +    unsigned long PopEpoch; +    // This is the lower bound of places on the stack to check. It is equal to +    // the place the last stack walk ended. +    // Note: Correctness depends on this being initialized to 0, which densemap +    // does +    unsigned long LowerBound; +    const BasicBlock *LowerBoundBlock; +    // This is where the last walk for this memory location ended. +    unsigned long LastKill; +    bool LastKillValid; +    Optional<AliasResult> AR; +  }; + +  void optimizeUsesInBlock(const BasicBlock *, unsigned long &, unsigned long &, +                           SmallVectorImpl<MemoryAccess *> &, +                           DenseMap<MemoryLocOrCall, MemlocStackInfo> &); + +  MemorySSA *MSSA; +  MemorySSAWalker *Walker; +  AliasAnalysis *AA; +  DominatorTree *DT; +}; + +} // end namespace llvm + +/// Optimize the uses in a given block This is basically the SSA renaming +/// algorithm, with one caveat: We are able to use a single stack for all +/// MemoryUses.  This is because the set of *possible* reaching MemoryDefs is +/// the same for every MemoryUse.  The *actual* clobbering MemoryDef is just +/// going to be some position in that stack of possible ones. +/// +/// We track the stack positions that each MemoryLocation needs +/// to check, and last ended at.  
This is because we only want to check the +/// things that changed since last time.  The same MemoryLocation should +/// get clobbered by the same store (getModRefInfo does not use invariantness or +/// things like this, and if they start, we can modify MemoryLocOrCall to +/// include relevant data) +void MemorySSA::OptimizeUses::optimizeUsesInBlock( +    const BasicBlock *BB, unsigned long &StackEpoch, unsigned long &PopEpoch, +    SmallVectorImpl<MemoryAccess *> &VersionStack, +    DenseMap<MemoryLocOrCall, MemlocStackInfo> &LocStackInfo) { + +  /// If no accesses, nothing to do. +  MemorySSA::AccessList *Accesses = MSSA->getWritableBlockAccesses(BB); +  if (Accesses == nullptr) +    return; + +  // Pop everything that doesn't dominate the current block off the stack, +  // increment the PopEpoch to account for this. +  while (true) { +    assert( +        !VersionStack.empty() && +        "Version stack should have liveOnEntry sentinel dominating everything"); +    BasicBlock *BackBlock = VersionStack.back()->getBlock(); +    if (DT->dominates(BackBlock, BB)) +      break; +    while (VersionStack.back()->getBlock() == BackBlock) +      VersionStack.pop_back(); +    ++PopEpoch; +  } + +  for (MemoryAccess &MA : *Accesses) { +    auto *MU = dyn_cast<MemoryUse>(&MA); +    if (!MU) { +      VersionStack.push_back(&MA); +      ++StackEpoch; +      continue; +    } + +    if (isUseTriviallyOptimizableToLiveOnEntry(*AA, MU->getMemoryInst())) { +      MU->setDefiningAccess(MSSA->getLiveOnEntryDef(), true, None); +      continue; +    } + +    MemoryLocOrCall UseMLOC(MU); +    auto &LocInfo = LocStackInfo[UseMLOC]; +    // If the pop epoch changed, it means we've removed stuff from top of +    // stack due to changing blocks. We may have to reset the lower bound or +    // last kill info. +    if (LocInfo.PopEpoch != PopEpoch) { +      LocInfo.PopEpoch = PopEpoch; +      LocInfo.StackEpoch = StackEpoch; +      // If the lower bound was in something that no longer dominates us, we +      // have to reset it. +      // We can't simply track stack size, because the stack may have had +      // pushes/pops in the meantime. +      // XXX: This is non-optimal, but only is slower cases with heavily +      // branching dominator trees.  To get the optimal number of queries would +      // be to make lowerbound and lastkill a per-loc stack, and pop it until +      // the top of that stack dominates us.  This does not seem worth it ATM. +      // A much cheaper optimization would be to always explore the deepest +      // branch of the dominator tree first. This will guarantee this resets on +      // the smallest set of blocks. +      if (LocInfo.LowerBoundBlock && LocInfo.LowerBoundBlock != BB && +          !DT->dominates(LocInfo.LowerBoundBlock, BB)) { +        // Reset the lower bound of things to check. +        // TODO: Some day we should be able to reset to last kill, rather than +        // 0. +        LocInfo.LowerBound = 0; +        LocInfo.LowerBoundBlock = VersionStack[0]->getBlock(); +        LocInfo.LastKillValid = false; +      } +    } else if (LocInfo.StackEpoch != StackEpoch) { +      // If all that has changed is the StackEpoch, we only have to check the +      // new things on the stack, because we've checked everything before.  In +      // this case, the lower bound of things to check remains the same. 
+      LocInfo.PopEpoch = PopEpoch; +      LocInfo.StackEpoch = StackEpoch; +    } +    if (!LocInfo.LastKillValid) { +      LocInfo.LastKill = VersionStack.size() - 1; +      LocInfo.LastKillValid = true; +      LocInfo.AR = MayAlias; +    } + +    // At this point, we should have corrected last kill and LowerBound to be +    // in bounds. +    assert(LocInfo.LowerBound < VersionStack.size() && +           "Lower bound out of range"); +    assert(LocInfo.LastKill < VersionStack.size() && +           "Last kill info out of range"); +    // In any case, the new upper bound is the top of the stack. +    unsigned long UpperBound = VersionStack.size() - 1; + +    if (UpperBound - LocInfo.LowerBound > MaxCheckLimit) { +      LLVM_DEBUG(dbgs() << "MemorySSA skipping optimization of " << *MU << " (" +                        << *(MU->getMemoryInst()) << ")" +                        << " because there are " +                        << UpperBound - LocInfo.LowerBound +                        << " stores to disambiguate\n"); +      // Because we did not walk, LastKill is no longer valid, as this may +      // have been a kill. +      LocInfo.LastKillValid = false; +      continue; +    } +    bool FoundClobberResult = false; +    while (UpperBound > LocInfo.LowerBound) { +      if (isa<MemoryPhi>(VersionStack[UpperBound])) { +        // For phis, use the walker, see where we ended up, go there +        Instruction *UseInst = MU->getMemoryInst(); +        MemoryAccess *Result = Walker->getClobberingMemoryAccess(UseInst); +        // We are guaranteed to find it or something is wrong +        while (VersionStack[UpperBound] != Result) { +          assert(UpperBound != 0); +          --UpperBound; +        } +        FoundClobberResult = true; +        break; +      } + +      MemoryDef *MD = cast<MemoryDef>(VersionStack[UpperBound]); +      // If the lifetime of the pointer ends at this instruction, it's live on +      // entry. +      if (!UseMLOC.IsCall && lifetimeEndsAt(MD, UseMLOC.getLoc(), *AA)) { +        // Reset UpperBound to liveOnEntryDef's place in the stack +        UpperBound = 0; +        FoundClobberResult = true; +        LocInfo.AR = MustAlias; +        break; +      } +      ClobberAlias CA = instructionClobbersQuery(MD, MU, UseMLOC, *AA); +      if (CA.IsClobber) { +        FoundClobberResult = true; +        LocInfo.AR = CA.AR; +        break; +      } +      --UpperBound; +    } + +    // Note: Phis always have AliasResult AR set to MayAlias ATM. + +    // At the end of this loop, UpperBound is either a clobber, or lower bound +    // PHI walking may cause it to be < LowerBound, and in fact, < LastKill. +    if (FoundClobberResult || UpperBound < LocInfo.LastKill) { +      // We were last killed now by where we got to +      if (MSSA->isLiveOnEntryDef(VersionStack[UpperBound])) +        LocInfo.AR = None; +      MU->setDefiningAccess(VersionStack[UpperBound], true, LocInfo.AR); +      LocInfo.LastKill = UpperBound; +    } else { +      // Otherwise, we checked all the new ones, and now we know we can get to +      // LastKill. +      MU->setDefiningAccess(VersionStack[LocInfo.LastKill], true, LocInfo.AR); +    } +    LocInfo.LowerBound = VersionStack.size() - 1; +    LocInfo.LowerBoundBlock = BB; +  } +} + +/// Optimize uses to point to their actual clobbering definitions. 
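+///
+/// The version stack is seeded with the liveOnEntry def, and the dominator
+/// tree is walked depth-first; optimizeUsesInBlock first pops entries whose
+/// blocks no longer dominate the current block, so the remaining stack is
+/// exactly the chain of defs and phis that dominate it.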
+void MemorySSA::OptimizeUses::optimizeUses() { +  SmallVector<MemoryAccess *, 16> VersionStack; +  DenseMap<MemoryLocOrCall, MemlocStackInfo> LocStackInfo; +  VersionStack.push_back(MSSA->getLiveOnEntryDef()); + +  unsigned long StackEpoch = 1; +  unsigned long PopEpoch = 1; +  // We perform a non-recursive top-down dominator tree walk. +  for (const auto *DomNode : depth_first(DT->getRootNode())) +    optimizeUsesInBlock(DomNode->getBlock(), StackEpoch, PopEpoch, VersionStack, +                        LocStackInfo); +} + +void MemorySSA::placePHINodes( +    const SmallPtrSetImpl<BasicBlock *> &DefiningBlocks) { +  // Determine where our MemoryPhi's should go +  ForwardIDFCalculator IDFs(*DT); +  IDFs.setDefiningBlocks(DefiningBlocks); +  SmallVector<BasicBlock *, 32> IDFBlocks; +  IDFs.calculate(IDFBlocks); + +  // Now place MemoryPhi nodes. +  for (auto &BB : IDFBlocks) +    createMemoryPhi(BB); +} + +void MemorySSA::buildMemorySSA() { +  // We create an access to represent "live on entry", for things like +  // arguments or users of globals, where the memory they use is defined before +  // the beginning of the function. We do not actually insert it into the IR. +  // We do not define a live on exit for the immediate uses, and thus our +  // semantics do *not* imply that something with no immediate uses can simply +  // be removed. +  BasicBlock &StartingPoint = F.getEntryBlock(); +  LiveOnEntryDef.reset(new MemoryDef(F.getContext(), nullptr, nullptr, +                                     &StartingPoint, NextID++)); + +  // We maintain lists of memory accesses per-block, trading memory for time. We +  // could just look up the memory access for every possible instruction in the +  // stream. +  SmallPtrSet<BasicBlock *, 32> DefiningBlocks; +  // Go through each block, figure out where defs occur, and chain together all +  // the accesses. +  for (BasicBlock &B : F) { +    bool InsertIntoDef = false; +    AccessList *Accesses = nullptr; +    DefsList *Defs = nullptr; +    for (Instruction &I : B) { +      MemoryUseOrDef *MUD = createNewAccess(&I); +      if (!MUD) +        continue; + +      if (!Accesses) +        Accesses = getOrCreateAccessList(&B); +      Accesses->push_back(MUD); +      if (isa<MemoryDef>(MUD)) { +        InsertIntoDef = true; +        if (!Defs) +          Defs = getOrCreateDefsList(&B); +        Defs->push_back(*MUD); +      } +    } +    if (InsertIntoDef) +      DefiningBlocks.insert(&B); +  } +  placePHINodes(DefiningBlocks); + +  // Now do regular SSA renaming on the MemoryDef/MemoryUse. Visited will get +  // filled in with all blocks. +  SmallPtrSet<BasicBlock *, 16> Visited; +  renamePass(DT->getRootNode(), LiveOnEntryDef.get(), Visited); + +  CachingWalker *Walker = getWalkerImpl(); + +  OptimizeUses(this, Walker, AA, DT).optimizeUses(); + +  // Mark the uses in unreachable blocks as live on entry, so that they go +  // somewhere. +  for (auto &BB : F) +    if (!Visited.count(&BB)) +      markUnreachableAsLiveOnEntry(&BB); +} + +MemorySSAWalker *MemorySSA::getWalker() { return getWalkerImpl(); } + +MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() { +  if (Walker) +    return Walker.get(); + +  Walker = llvm::make_unique<CachingWalker>(this, AA, DT); +  return Walker.get(); +} + +// This is a helper function used by the creation routines. It places NewAccess +// into the access and defs lists for a given basic block, at the given +// insertion point. 
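+// MemoryUses only ever appear on the access list; MemoryDefs and MemoryPhis
+// are mirrored onto the defs list as well. When inserting at Beginning, a phi
+// goes in front of everything, and other accesses are placed just after any
+// existing phis.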
+void MemorySSA::insertIntoListsForBlock(MemoryAccess *NewAccess, +                                        const BasicBlock *BB, +                                        InsertionPlace Point) { +  auto *Accesses = getOrCreateAccessList(BB); +  if (Point == Beginning) { +    // If it's a phi node, it goes first, otherwise, it goes after any phi +    // nodes. +    if (isa<MemoryPhi>(NewAccess)) { +      Accesses->push_front(NewAccess); +      auto *Defs = getOrCreateDefsList(BB); +      Defs->push_front(*NewAccess); +    } else { +      auto AI = find_if_not( +          *Accesses, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); }); +      Accesses->insert(AI, NewAccess); +      if (!isa<MemoryUse>(NewAccess)) { +        auto *Defs = getOrCreateDefsList(BB); +        auto DI = find_if_not( +            *Defs, [](const MemoryAccess &MA) { return isa<MemoryPhi>(MA); }); +        Defs->insert(DI, *NewAccess); +      } +    } +  } else { +    Accesses->push_back(NewAccess); +    if (!isa<MemoryUse>(NewAccess)) { +      auto *Defs = getOrCreateDefsList(BB); +      Defs->push_back(*NewAccess); +    } +  } +  BlockNumberingValid.erase(BB); +} + +void MemorySSA::insertIntoListsBefore(MemoryAccess *What, const BasicBlock *BB, +                                      AccessList::iterator InsertPt) { +  auto *Accesses = getWritableBlockAccesses(BB); +  bool WasEnd = InsertPt == Accesses->end(); +  Accesses->insert(AccessList::iterator(InsertPt), What); +  if (!isa<MemoryUse>(What)) { +    auto *Defs = getOrCreateDefsList(BB); +    // If we got asked to insert at the end, we have an easy job, just shove it +    // at the end. If we got asked to insert before an existing def, we also get +    // an iterator. If we got asked to insert before a use, we have to hunt for +    // the next def. +    if (WasEnd) { +      Defs->push_back(*What); +    } else if (isa<MemoryDef>(InsertPt)) { +      Defs->insert(InsertPt->getDefsIterator(), *What); +    } else { +      while (InsertPt != Accesses->end() && !isa<MemoryDef>(InsertPt)) +        ++InsertPt; +      // Either we found a def, or we are inserting at the end +      if (InsertPt == Accesses->end()) +        Defs->push_back(*What); +      else +        Defs->insert(InsertPt->getDefsIterator(), *What); +    } +  } +  BlockNumberingValid.erase(BB); +} + +// Move What before Where in the IR.  The end result is that What will belong to +// the right lists and have the right Block set, but will not otherwise be +// correct. It will not have the right defining access, and if it is a def, +// things below it will not properly be updated. 
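+// Callers that need a fully consistent result are expected to do that fixup
+// themselves; MemorySSAUpdater::moveTo, for example, rewrites the uses and
+// re-runs insertDef/insertUse around these primitives.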
+void MemorySSA::moveTo(MemoryUseOrDef *What, BasicBlock *BB,
+                       AccessList::iterator Where) {
+  // Keep it in the lookup tables, remove from the lists
+  removeFromLists(What, false);
+  What->setBlock(BB);
+  insertIntoListsBefore(What, BB, Where);
+}
+
+void MemorySSA::moveTo(MemoryAccess *What, BasicBlock *BB,
+                       InsertionPlace Point) {
+  if (isa<MemoryPhi>(What)) {
+    assert(Point == Beginning &&
+           "Can only move a Phi at the beginning of the block");
+    // Update lookup table entry
+    ValueToMemoryAccess.erase(What->getBlock());
+    bool Inserted = ValueToMemoryAccess.insert({BB, What}).second;
+    (void)Inserted;
+    assert(Inserted && "Cannot move a Phi to a block that already has one");
+  }
+
+  removeFromLists(What, false);
+  What->setBlock(BB);
+  insertIntoListsForBlock(What, BB, Point);
+}
+
+MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) {
+  assert(!getMemoryAccess(BB) && "MemoryPhi already exists for this BB");
+  MemoryPhi *Phi = new MemoryPhi(BB->getContext(), BB, NextID++);
+  // Phis are always placed at the front of the block.
+  insertIntoListsForBlock(Phi, BB, Beginning);
+  ValueToMemoryAccess[BB] = Phi;
+  return Phi;
+}
+
+MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I,
+                                               MemoryAccess *Definition) {
+  assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI");
+  MemoryUseOrDef *NewAccess = createNewAccess(I);
+  assert(
+      NewAccess != nullptr &&
+      "Tried to create a memory access for a non-memory touching instruction");
+  NewAccess->setDefiningAccess(Definition);
+  return NewAccess;
+}
+
+// Return true if the instruction has ordering constraints.
+// Note specifically that this only considers stores and loads
+// because others are still considered ModRef by getModRefInfo.
+static inline bool isOrdered(const Instruction *I) {
+  if (auto *SI = dyn_cast<StoreInst>(I)) {
+    if (!SI->isUnordered())
+      return true;
+  } else if (auto *LI = dyn_cast<LoadInst>(I)) {
+    if (!LI->isUnordered())
+      return true;
+  }
+  return false;
+}
+
+/// Helper function to create new memory accesses
+MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) {
+  // The assume intrinsic has a control dependency which we model by claiming
+  // that it writes arbitrarily. Ignore that fake memory dependency here.
+  // FIXME: Replace this special casing with a more accurate modelling of
+  // assume's control dependency.
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
+    if (II->getIntrinsicID() == Intrinsic::assume)
+      return nullptr;
+
+  // Find out what effect this instruction has on memory.
+  ModRefInfo ModRef = AA->getModRefInfo(I, None);
+  // The isOrdered check is used to ensure that volatiles end up as defs
+  // (atomics end up as ModRef right now anyway).  Until we separate the
+  // ordering chain from the memory chain, this enables people to see at least
+  // some relative ordering to volatiles.  Note that getClobberingMemoryAccess
+  // will still give an answer that bypasses other volatile loads.  TODO:
+  // Separate memory aliasing and ordering into two different chains so that we
+  // can precisely represent both "what memory will this read/write/is clobbered
+  // by" and "what instructions can I move this past".
+  bool Def = isModSet(ModRef) || isOrdered(I);
+  bool Use = isRefSet(ModRef);
+
+  // It's possible for an instruction to not modify memory at all.
During +  // construction, we ignore them. +  if (!Def && !Use) +    return nullptr; + +  MemoryUseOrDef *MUD; +  if (Def) +    MUD = new MemoryDef(I->getContext(), nullptr, I, I->getParent(), NextID++); +  else +    MUD = new MemoryUse(I->getContext(), nullptr, I, I->getParent()); +  ValueToMemoryAccess[I] = MUD; +  return MUD; +} + +/// Returns true if \p Replacer dominates \p Replacee . +bool MemorySSA::dominatesUse(const MemoryAccess *Replacer, +                             const MemoryAccess *Replacee) const { +  if (isa<MemoryUseOrDef>(Replacee)) +    return DT->dominates(Replacer->getBlock(), Replacee->getBlock()); +  const auto *MP = cast<MemoryPhi>(Replacee); +  // For a phi node, the use occurs in the predecessor block of the phi node. +  // Since we may occur multiple times in the phi node, we have to check each +  // operand to ensure Replacer dominates each operand where Replacee occurs. +  for (const Use &Arg : MP->operands()) { +    if (Arg.get() != Replacee && +        !DT->dominates(Replacer->getBlock(), MP->getIncomingBlock(Arg))) +      return false; +  } +  return true; +} + +/// Properly remove \p MA from all of MemorySSA's lookup tables. +void MemorySSA::removeFromLookups(MemoryAccess *MA) { +  assert(MA->use_empty() && +         "Trying to remove memory access that still has uses"); +  BlockNumbering.erase(MA); +  if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) +    MUD->setDefiningAccess(nullptr); +  // Invalidate our walker's cache if necessary +  if (!isa<MemoryUse>(MA)) +    Walker->invalidateInfo(MA); + +  Value *MemoryInst; +  if (const auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) +    MemoryInst = MUD->getMemoryInst(); +  else +    MemoryInst = MA->getBlock(); + +  auto VMA = ValueToMemoryAccess.find(MemoryInst); +  if (VMA->second == MA) +    ValueToMemoryAccess.erase(VMA); +} + +/// Properly remove \p MA from all of MemorySSA's lists. +/// +/// Because of the way the intrusive list and use lists work, it is important to +/// do removal in the right order. +/// ShouldDelete defaults to true, and will cause the memory access to also be +/// deleted, not just removed. +void MemorySSA::removeFromLists(MemoryAccess *MA, bool ShouldDelete) { +  BasicBlock *BB = MA->getBlock(); +  // The access list owns the reference, so we erase it from the non-owning list +  // first. +  if (!isa<MemoryUse>(MA)) { +    auto DefsIt = PerBlockDefs.find(BB); +    std::unique_ptr<DefsList> &Defs = DefsIt->second; +    Defs->remove(*MA); +    if (Defs->empty()) +      PerBlockDefs.erase(DefsIt); +  } + +  // The erase call here will delete it. If we don't want it deleted, we call +  // remove instead. +  auto AccessIt = PerBlockAccesses.find(BB); +  std::unique_ptr<AccessList> &Accesses = AccessIt->second; +  if (ShouldDelete) +    Accesses->erase(MA); +  else +    Accesses->remove(MA); + +  if (Accesses->empty()) { +    PerBlockAccesses.erase(AccessIt); +    BlockNumberingValid.erase(BB); +  } +} + +void MemorySSA::print(raw_ostream &OS) const { +  MemorySSAAnnotatedWriter Writer(this); +  F.print(OS, &Writer); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void MemorySSA::dump() const { print(dbgs()); } +#endif + +void MemorySSA::verifyMemorySSA() const { +  verifyDefUses(F); +  verifyDomination(F); +  verifyOrdering(F); +  verifyDominationNumbers(F); +  Walker->verify(this); +} + +/// Verify that all of the blocks we believe to have valid domination numbers +/// actually have valid domination numbers. 
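+/// The body below is compiled out entirely in NDEBUG builds.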
+void MemorySSA::verifyDominationNumbers(const Function &F) const { +#ifndef NDEBUG +  if (BlockNumberingValid.empty()) +    return; + +  SmallPtrSet<const BasicBlock *, 16> ValidBlocks = BlockNumberingValid; +  for (const BasicBlock &BB : F) { +    if (!ValidBlocks.count(&BB)) +      continue; + +    ValidBlocks.erase(&BB); + +    const AccessList *Accesses = getBlockAccesses(&BB); +    // It's correct to say an empty block has valid numbering. +    if (!Accesses) +      continue; + +    // Block numbering starts at 1. +    unsigned long LastNumber = 0; +    for (const MemoryAccess &MA : *Accesses) { +      auto ThisNumberIter = BlockNumbering.find(&MA); +      assert(ThisNumberIter != BlockNumbering.end() && +             "MemoryAccess has no domination number in a valid block!"); + +      unsigned long ThisNumber = ThisNumberIter->second; +      assert(ThisNumber > LastNumber && +             "Domination numbers should be strictly increasing!"); +      LastNumber = ThisNumber; +    } +  } + +  assert(ValidBlocks.empty() && +         "All valid BasicBlocks should exist in F -- dangling pointers?"); +#endif +} + +/// Verify that the order and existence of MemoryAccesses matches the +/// order and existence of memory affecting instructions. +void MemorySSA::verifyOrdering(Function &F) const { +  // Walk all the blocks, comparing what the lookups think and what the access +  // lists think, as well as the order in the blocks vs the order in the access +  // lists. +  SmallVector<MemoryAccess *, 32> ActualAccesses; +  SmallVector<MemoryAccess *, 32> ActualDefs; +  for (BasicBlock &B : F) { +    const AccessList *AL = getBlockAccesses(&B); +    const auto *DL = getBlockDefs(&B); +    MemoryAccess *Phi = getMemoryAccess(&B); +    if (Phi) { +      ActualAccesses.push_back(Phi); +      ActualDefs.push_back(Phi); +    } + +    for (Instruction &I : B) { +      MemoryAccess *MA = getMemoryAccess(&I); +      assert((!MA || (AL && (isa<MemoryUse>(MA) || DL))) && +             "We have memory affecting instructions " +             "in this block but they are not in the " +             "access list or defs list"); +      if (MA) { +        ActualAccesses.push_back(MA); +        if (isa<MemoryDef>(MA)) +          ActualDefs.push_back(MA); +      } +    } +    // Either we hit the assert, really have no accesses, or we have both +    // accesses and an access list. +    // Same with defs. 
+    if (!AL && !DL)
+      continue;
+    assert(AL->size() == ActualAccesses.size() &&
+           "We don't have the same number of accesses in the block as on the "
+           "access list");
+    assert((DL || ActualDefs.size() == 0) &&
+           "Either we should have a defs list, or we should have no defs");
+    assert((!DL || DL->size() == ActualDefs.size()) &&
+           "We don't have the same number of defs in the block as on the "
+           "def list");
+    auto ALI = AL->begin();
+    auto AAI = ActualAccesses.begin();
+    while (ALI != AL->end() && AAI != ActualAccesses.end()) {
+      assert(&*ALI == *AAI && "Not the same accesses in the same order");
+      ++ALI;
+      ++AAI;
+    }
+    ActualAccesses.clear();
+    if (DL) {
+      auto DLI = DL->begin();
+      auto ADI = ActualDefs.begin();
+      while (DLI != DL->end() && ADI != ActualDefs.end()) {
+        assert(&*DLI == *ADI && "Not the same defs in the same order");
+        ++DLI;
+        ++ADI;
+      }
+    }
+    ActualDefs.clear();
+  }
+}
+
+/// Verify the domination properties of MemorySSA by checking that each
+/// definition dominates all of its uses.
+void MemorySSA::verifyDomination(Function &F) const {
+#ifndef NDEBUG
+  for (BasicBlock &B : F) {
+    // Phi nodes are attached to basic blocks
+    if (MemoryPhi *MP = getMemoryAccess(&B))
+      for (const Use &U : MP->uses())
+        assert(dominates(MP, U) && "Memory PHI does not dominate its uses");
+
+    for (Instruction &I : B) {
+      MemoryAccess *MD = dyn_cast_or_null<MemoryDef>(getMemoryAccess(&I));
+      if (!MD)
+        continue;
+
+      for (const Use &U : MD->uses())
+        assert(dominates(MD, U) && "Memory Def does not dominate its uses");
+    }
+  }
+#endif
+}
+
+/// Verify the def-use lists in MemorySSA, by verifying that \p Use
+/// appears in the use list of \p Def.
+void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { +#ifndef NDEBUG +  // The live on entry use may cause us to get a NULL def here +  if (!Def) +    assert(isLiveOnEntryDef(Use) && +           "Null def but use not point to live on entry def"); +  else +    assert(is_contained(Def->users(), Use) && +           "Did not find use in def's use list"); +#endif +} + +/// Verify the immediate use information, by walking all the memory +/// accesses and verifying that, for each use, it appears in the +/// appropriate def's use list +void MemorySSA::verifyDefUses(Function &F) const { +  for (BasicBlock &B : F) { +    // Phi nodes are attached to basic blocks +    if (MemoryPhi *Phi = getMemoryAccess(&B)) { +      assert(Phi->getNumOperands() == static_cast<unsigned>(std::distance( +                                          pred_begin(&B), pred_end(&B))) && +             "Incomplete MemoryPhi Node"); +      for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) { +        verifyUseInDefs(Phi->getIncomingValue(I), Phi); +        assert(find(predecessors(&B), Phi->getIncomingBlock(I)) != +                   pred_end(&B) && +               "Incoming phi block not a block predecessor"); +      } +    } + +    for (Instruction &I : B) { +      if (MemoryUseOrDef *MA = getMemoryAccess(&I)) { +        verifyUseInDefs(MA->getDefiningAccess(), MA); +      } +    } +  } +} + +MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const { +  return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I)); +} + +MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { +  return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB))); +} + +/// Perform a local numbering on blocks so that instruction ordering can be +/// determined in constant time. +/// TODO: We currently just number in order.  If we numbered by N, we could +/// allow at least N-1 sequences of insertBefore or insertAfter (and at least +/// log2(N) sequences of mixed before and after) without needing to invalidate +/// the numbering. +void MemorySSA::renumberBlock(const BasicBlock *B) const { +  // The pre-increment ensures the numbers really start at 1. +  unsigned long CurrentNumber = 0; +  const AccessList *AL = getBlockAccesses(B); +  assert(AL != nullptr && "Asking to renumber an empty block"); +  for (const auto &I : *AL) +    BlockNumbering[&I] = ++CurrentNumber; +  BlockNumberingValid.insert(B); +} + +/// Determine, for two memory accesses in the same block, +/// whether \p Dominator dominates \p Dominatee. +/// \returns True if \p Dominator dominates \p Dominatee. +bool MemorySSA::locallyDominates(const MemoryAccess *Dominator, +                                 const MemoryAccess *Dominatee) const { +  const BasicBlock *DominatorBlock = Dominator->getBlock(); + +  assert((DominatorBlock == Dominatee->getBlock()) && +         "Asking for local domination when accesses are in different blocks!"); +  // A node dominates itself. +  if (Dominatee == Dominator) +    return true; + +  // When Dominatee is defined on function entry, it is not dominated by another +  // memory access. +  if (isLiveOnEntryDef(Dominatee)) +    return false; + +  // When Dominator is defined on function entry, it dominates the other memory +  // access. 
+  if (isLiveOnEntryDef(Dominator)) +    return true; + +  if (!BlockNumberingValid.count(DominatorBlock)) +    renumberBlock(DominatorBlock); + +  unsigned long DominatorNum = BlockNumbering.lookup(Dominator); +  // All numbers start with 1 +  assert(DominatorNum != 0 && "Block was not numbered properly"); +  unsigned long DominateeNum = BlockNumbering.lookup(Dominatee); +  assert(DominateeNum != 0 && "Block was not numbered properly"); +  return DominatorNum < DominateeNum; +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, +                          const MemoryAccess *Dominatee) const { +  if (Dominator == Dominatee) +    return true; + +  if (isLiveOnEntryDef(Dominatee)) +    return false; + +  if (Dominator->getBlock() != Dominatee->getBlock()) +    return DT->dominates(Dominator->getBlock(), Dominatee->getBlock()); +  return locallyDominates(Dominator, Dominatee); +} + +bool MemorySSA::dominates(const MemoryAccess *Dominator, +                          const Use &Dominatee) const { +  if (MemoryPhi *MP = dyn_cast<MemoryPhi>(Dominatee.getUser())) { +    BasicBlock *UseBB = MP->getIncomingBlock(Dominatee); +    // The def must dominate the incoming block of the phi. +    if (UseBB != Dominator->getBlock()) +      return DT->dominates(Dominator->getBlock(), UseBB); +    // If the UseBB and the DefBB are the same, compare locally. +    return locallyDominates(Dominator, cast<MemoryAccess>(Dominatee)); +  } +  // If it's not a PHI node use, the normal dominates can already handle it. +  return dominates(Dominator, cast<MemoryAccess>(Dominatee.getUser())); +} + +const static char LiveOnEntryStr[] = "liveOnEntry"; + +void MemoryAccess::print(raw_ostream &OS) const { +  switch (getValueID()) { +  case MemoryPhiVal: return static_cast<const MemoryPhi *>(this)->print(OS); +  case MemoryDefVal: return static_cast<const MemoryDef *>(this)->print(OS); +  case MemoryUseVal: return static_cast<const MemoryUse *>(this)->print(OS); +  } +  llvm_unreachable("invalid value id"); +} + +void MemoryDef::print(raw_ostream &OS) const { +  MemoryAccess *UO = getDefiningAccess(); + +  auto printID = [&OS](MemoryAccess *A) { +    if (A && A->getID()) +      OS << A->getID(); +    else +      OS << LiveOnEntryStr; +  }; + +  OS << getID() << " = MemoryDef("; +  printID(UO); +  OS << ")"; + +  if (isOptimized()) { +    OS << "->"; +    printID(getOptimized()); + +    if (Optional<AliasResult> AR = getOptimizedAccessType()) +      OS << " " << *AR; +  } +} + +void MemoryPhi::print(raw_ostream &OS) const { +  bool First = true; +  OS << getID() << " = MemoryPhi("; +  for (const auto &Op : operands()) { +    BasicBlock *BB = getIncomingBlock(Op); +    MemoryAccess *MA = cast<MemoryAccess>(Op); +    if (!First) +      OS << ','; +    else +      First = false; + +    OS << '{'; +    if (BB->hasName()) +      OS << BB->getName(); +    else +      BB->printAsOperand(OS, false); +    OS << ','; +    if (unsigned ID = MA->getID()) +      OS << ID; +    else +      OS << LiveOnEntryStr; +    OS << '}'; +  } +  OS << ')'; +} + +void MemoryUse::print(raw_ostream &OS) const { +  MemoryAccess *UO = getDefiningAccess(); +  OS << "MemoryUse("; +  if (UO && UO->getID()) +    OS << UO->getID(); +  else +    OS << LiveOnEntryStr; +  OS << ')'; + +  if (Optional<AliasResult> AR = getOptimizedAccessType()) +    OS << " " << *AR; +} + +void MemoryAccess::dump() const { +// Cannot completely remove virtual function even in release mode. 
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +  print(dbgs()); +  dbgs() << "\n"; +#endif +} + +char MemorySSAPrinterLegacyPass::ID = 0; + +MemorySSAPrinterLegacyPass::MemorySSAPrinterLegacyPass() : FunctionPass(ID) { +  initializeMemorySSAPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAPrinterLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<MemorySSAWrapperPass>(); +} + +bool MemorySSAPrinterLegacyPass::runOnFunction(Function &F) { +  auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); +  MSSA.print(dbgs()); +  if (VerifyMemorySSA) +    MSSA.verifyMemorySSA(); +  return false; +} + +AnalysisKey MemorySSAAnalysis::Key; + +MemorySSAAnalysis::Result MemorySSAAnalysis::run(Function &F, +                                                 FunctionAnalysisManager &AM) { +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &AA = AM.getResult<AAManager>(F); +  return MemorySSAAnalysis::Result(llvm::make_unique<MemorySSA>(F, &AA, &DT)); +} + +PreservedAnalyses MemorySSAPrinterPass::run(Function &F, +                                            FunctionAnalysisManager &AM) { +  OS << "MemorySSA for function: " << F.getName() << "\n"; +  AM.getResult<MemorySSAAnalysis>(F).getMSSA().print(OS); + +  return PreservedAnalyses::all(); +} + +PreservedAnalyses MemorySSAVerifierPass::run(Function &F, +                                             FunctionAnalysisManager &AM) { +  AM.getResult<MemorySSAAnalysis>(F).getMSSA().verifyMemorySSA(); + +  return PreservedAnalyses::all(); +} + +char MemorySSAWrapperPass::ID = 0; + +MemorySSAWrapperPass::MemorySSAWrapperPass() : FunctionPass(ID) { +  initializeMemorySSAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void MemorySSAWrapperPass::releaseMemory() { MSSA.reset(); } + +void MemorySSAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequiredTransitive<DominatorTreeWrapperPass>(); +  AU.addRequiredTransitive<AAResultsWrapperPass>(); +} + +bool MemorySSAWrapperPass::runOnFunction(Function &F) { +  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +  MSSA.reset(new MemorySSA(F, &AA, &DT)); +  return false; +} + +void MemorySSAWrapperPass::verifyAnalysis() const { MSSA->verifyMemorySSA(); } + +void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { +  MSSA->print(OS); +} + +MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} + +MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, +                                        DominatorTree *D) +    : MemorySSAWalker(M), Walker(*M, *A, *D) {} + +void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { +  if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) +    MUD->resetOptimized(); +} + +/// Walk the use-def chains starting at \p MA and find +/// the MemoryAccess that actually clobbers Loc. 
+/// +/// \returns our clobbering memory access +MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( +    MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { +  return Walker.findClobber(StartingAccess, Q); +} + +MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( +    MemoryAccess *StartingAccess, const MemoryLocation &Loc) { +  if (isa<MemoryPhi>(StartingAccess)) +    return StartingAccess; + +  auto *StartingUseOrDef = cast<MemoryUseOrDef>(StartingAccess); +  if (MSSA->isLiveOnEntryDef(StartingUseOrDef)) +    return StartingUseOrDef; + +  Instruction *I = StartingUseOrDef->getMemoryInst(); + +  // Conservatively, fences are always clobbers, so don't perform the walk if we +  // hit a fence. +  if (!ImmutableCallSite(I) && I->isFenceLike()) +    return StartingUseOrDef; + +  UpwardsMemoryQuery Q; +  Q.OriginalAccess = StartingUseOrDef; +  Q.StartingLoc = Loc; +  Q.Inst = I; +  Q.IsCall = false; + +  // Unlike the other function, do not walk to the def of a def, because we are +  // handed something we already believe is the clobbering access. +  MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) +                                     ? StartingUseOrDef->getDefiningAccess() +                                     : StartingUseOrDef; + +  MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); +  LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); +  LLVM_DEBUG(dbgs() << *StartingUseOrDef << "\n"); +  LLVM_DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); +  LLVM_DEBUG(dbgs() << *Clobber << "\n"); +  return Clobber; +} + +MemoryAccess * +MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { +  auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); +  // If this is a MemoryPhi, we can't do anything. +  if (!StartingAccess) +    return MA; + +  // If this is an already optimized use or def, return the optimized result. +  // Note: Currently, we store the optimized def result in a separate field, +  // since we can't use the defining access. +  if (StartingAccess->isOptimized()) +    return StartingAccess->getOptimized(); + +  const Instruction *I = StartingAccess->getMemoryInst(); +  UpwardsMemoryQuery Q(I, StartingAccess); +  // We can't sanely do anything with a fence, since they conservatively clobber +  // all memory, and have no locations to get pointers from to try to +  // disambiguate. +  if (!Q.IsCall && I->isFenceLike()) +    return StartingAccess; + +  if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) { +    MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); +    StartingAccess->setOptimized(LiveOnEntry); +    StartingAccess->setOptimizedAccessType(None); +    return LiveOnEntry; +  } + +  // Start with the thing we already think clobbers this location +  MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); + +  // At this point, DefiningAccess may be the live on entry def. +  // If it is, we will not get a better result. 
+  if (MSSA->isLiveOnEntryDef(DefiningAccess)) { +    StartingAccess->setOptimized(DefiningAccess); +    StartingAccess->setOptimizedAccessType(None); +    return DefiningAccess; +  } + +  MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); +  LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); +  LLVM_DEBUG(dbgs() << *DefiningAccess << "\n"); +  LLVM_DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); +  LLVM_DEBUG(dbgs() << *Result << "\n"); + +  StartingAccess->setOptimized(Result); +  if (MSSA->isLiveOnEntryDef(Result)) +    StartingAccess->setOptimizedAccessType(None); +  else if (Q.AR == MustAlias) +    StartingAccess->setOptimizedAccessType(MustAlias); + +  return Result; +} + +MemoryAccess * +DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { +  if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) +    return Use->getDefiningAccess(); +  return MA; +} + +MemoryAccess *DoNothingMemorySSAWalker::getClobberingMemoryAccess( +    MemoryAccess *StartingAccess, const MemoryLocation &) { +  if (auto *Use = dyn_cast<MemoryUseOrDef>(StartingAccess)) +    return Use->getDefiningAccess(); +  return StartingAccess; +} + +void MemoryPhi::deleteMe(DerivedUser *Self) { +  delete static_cast<MemoryPhi *>(Self); +} + +void MemoryDef::deleteMe(DerivedUser *Self) { +  delete static_cast<MemoryDef *>(Self); +} + +void MemoryUse::deleteMe(DerivedUser *Self) { +  delete static_cast<MemoryUse *>(Self); +} diff --git a/contrib/llvm/lib/Analysis/MemorySSAUpdater.cpp b/contrib/llvm/lib/Analysis/MemorySSAUpdater.cpp new file mode 100644 index 000000000000..abe2b3c25a58 --- /dev/null +++ b/contrib/llvm/lib/Analysis/MemorySSAUpdater.cpp @@ -0,0 +1,636 @@ +//===-- MemorySSAUpdater.cpp - Memory SSA Updater--------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------===// +// +// This file implements the MemorySSAUpdater class. +// +//===----------------------------------------------------------------===// +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormattedStream.h" +#include <algorithm> + +#define DEBUG_TYPE "memoryssa" +using namespace llvm; + +// This is the marker algorithm from "Simple and Efficient Construction of +// Static Single Assignment Form" +// The simple, non-marker algorithm places phi nodes at any join +// Here, we place markers, and only place phi nodes if they end up necessary. +// They are only necessary if they break a cycle (IE we recursively visit +// ourselves again), or we discover, while getting the value of the operands, +// that there are two or more definitions needing to be merged. +// This still will leave non-minimal form in the case of irreducible control +// flow, where phi nodes may be in cycles with themselves, but unnecessary. 
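+//
+// For illustration (a hypothetical CFG, not taken from any test): with edges
+//   A -> B, A -> C, B -> D, C -> D
+// a may-def in B and another in C force a MemoryPhi in D, because the walk up
+// from D finds two different previous definitions to merge. A straight line of
+// single-predecessor blocks never needs one, and a block that reaches itself
+// only gets a phi because the recursive walk revisits it and must break the
+// cycle with an operand.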
+MemoryAccess *MemorySSAUpdater::getPreviousDefRecursive( +    BasicBlock *BB, +    DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> &CachedPreviousDef) { +  // First, do a cache lookup. Without this cache, certain CFG structures +  // (like a series of if statements) take exponential time to visit. +  auto Cached = CachedPreviousDef.find(BB); +  if (Cached != CachedPreviousDef.end()) { +    return Cached->second; +  } + +  if (BasicBlock *Pred = BB->getSinglePredecessor()) { +    // Single predecessor case, just recurse, we can only have one definition. +    MemoryAccess *Result = getPreviousDefFromEnd(Pred, CachedPreviousDef); +    CachedPreviousDef.insert({BB, Result}); +    return Result; +  } + +  if (VisitedBlocks.count(BB)) { +    // We hit our node again, meaning we had a cycle, we must insert a phi +    // node to break it so we have an operand. The only case this will +    // insert useless phis is if we have irreducible control flow. +    MemoryAccess *Result = MSSA->createMemoryPhi(BB); +    CachedPreviousDef.insert({BB, Result}); +    return Result; +  } + +  if (VisitedBlocks.insert(BB).second) { +    // Mark us visited so we can detect a cycle +    SmallVector<TrackingVH<MemoryAccess>, 8> PhiOps; + +    // Recurse to get the values in our predecessors for placement of a +    // potential phi node. This will insert phi nodes if we cycle in order to +    // break the cycle and have an operand. +    for (auto *Pred : predecessors(BB)) +      PhiOps.push_back(getPreviousDefFromEnd(Pred, CachedPreviousDef)); + +    // Now try to simplify the ops to avoid placing a phi. +    // This may return null if we never created a phi yet, that's okay +    MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MSSA->getMemoryAccess(BB)); + +    // See if we can avoid the phi by simplifying it. +    auto *Result = tryRemoveTrivialPhi(Phi, PhiOps); +    // If we couldn't simplify, we may have to create a phi +    if (Result == Phi) { +      if (!Phi) +        Phi = MSSA->createMemoryPhi(BB); + +      // See if the existing phi operands match what we need. +      // Unlike normal SSA, we only allow one phi node per block, so we can't just +      // create a new one. +      if (Phi->getNumOperands() != 0) { +        // FIXME: Figure out whether this is dead code and if so remove it. +        if (!std::equal(Phi->op_begin(), Phi->op_end(), PhiOps.begin())) { +          // These will have been filled in by the recursive read we did above. +          std::copy(PhiOps.begin(), PhiOps.end(), Phi->op_begin()); +          std::copy(pred_begin(BB), pred_end(BB), Phi->block_begin()); +        } +      } else { +        unsigned i = 0; +        for (auto *Pred : predecessors(BB)) +          Phi->addIncoming(&*PhiOps[i++], Pred); +        InsertedPHIs.push_back(Phi); +      } +      Result = Phi; +    } + +    // Set ourselves up for the next variable by resetting visited state. +    VisitedBlocks.erase(BB); +    CachedPreviousDef.insert({BB, Result}); +    return Result; +  } +  llvm_unreachable("Should have hit one of the three cases above"); +} + +// This starts at the memory access, and goes backwards in the block to find the +// previous definition. If a definition is not found the block of the access, +// it continues globally, creating phi nodes to ensure we have a single +// definition. 
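+// (This is the entry point used by insertUse and insertDef below.)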
+MemoryAccess *MemorySSAUpdater::getPreviousDef(MemoryAccess *MA) { +  if (auto *LocalResult = getPreviousDefInBlock(MA)) +    return LocalResult; +  DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> CachedPreviousDef; +  return getPreviousDefRecursive(MA->getBlock(), CachedPreviousDef); +} + +// This starts at the memory access, and goes backwards in the block to the find +// the previous definition. If the definition is not found in the block of the +// access, it returns nullptr. +MemoryAccess *MemorySSAUpdater::getPreviousDefInBlock(MemoryAccess *MA) { +  auto *Defs = MSSA->getWritableBlockDefs(MA->getBlock()); + +  // It's possible there are no defs, or we got handed the first def to start. +  if (Defs) { +    // If this is a def, we can just use the def iterators. +    if (!isa<MemoryUse>(MA)) { +      auto Iter = MA->getReverseDefsIterator(); +      ++Iter; +      if (Iter != Defs->rend()) +        return &*Iter; +    } else { +      // Otherwise, have to walk the all access iterator. +      auto End = MSSA->getWritableBlockAccesses(MA->getBlock())->rend(); +      for (auto &U : make_range(++MA->getReverseIterator(), End)) +        if (!isa<MemoryUse>(U)) +          return cast<MemoryAccess>(&U); +      // Note that if MA comes before Defs->begin(), we won't hit a def. +      return nullptr; +    } +  } +  return nullptr; +} + +// This starts at the end of block +MemoryAccess *MemorySSAUpdater::getPreviousDefFromEnd( +    BasicBlock *BB, +    DenseMap<BasicBlock *, TrackingVH<MemoryAccess>> &CachedPreviousDef) { +  auto *Defs = MSSA->getWritableBlockDefs(BB); + +  if (Defs) +    return &*Defs->rbegin(); + +  return getPreviousDefRecursive(BB, CachedPreviousDef); +} +// Recurse over a set of phi uses to eliminate the trivial ones +MemoryAccess *MemorySSAUpdater::recursePhi(MemoryAccess *Phi) { +  if (!Phi) +    return nullptr; +  TrackingVH<MemoryAccess> Res(Phi); +  SmallVector<TrackingVH<Value>, 8> Uses; +  std::copy(Phi->user_begin(), Phi->user_end(), std::back_inserter(Uses)); +  for (auto &U : Uses) { +    if (MemoryPhi *UsePhi = dyn_cast<MemoryPhi>(&*U)) { +      auto OperRange = UsePhi->operands(); +      tryRemoveTrivialPhi(UsePhi, OperRange); +    } +  } +  return Res; +} + +// Eliminate trivial phis +// Phis are trivial if they are defined either by themselves, or all the same +// argument. +// IE phi(a, a) or b = phi(a, b) or c = phi(a, a, c) +// We recursively try to remove them. +template <class RangeType> +MemoryAccess *MemorySSAUpdater::tryRemoveTrivialPhi(MemoryPhi *Phi, +                                                    RangeType &Operands) { +  // Bail out on non-opt Phis. +  if (NonOptPhis.count(Phi)) +    return Phi; + +  // Detect equal or self arguments +  MemoryAccess *Same = nullptr; +  for (auto &Op : Operands) { +    // If the same or self, good so far +    if (Op == Phi || Op == Same) +      continue; +    // not the same, return the phi since it's not eliminatable by us +    if (Same) +      return Phi; +    Same = cast<MemoryAccess>(&*Op); +  } +  // Never found a non-self reference, the phi is undef +  if (Same == nullptr) +    return MSSA->getLiveOnEntryDef(); +  if (Phi) { +    Phi->replaceAllUsesWith(Same); +    removeMemoryAccess(Phi); +  } + +  // We should only end up recursing in case we replaced something, in which +  // case, we may have made other Phis trivial. 
+  return recursePhi(Same); +} + +void MemorySSAUpdater::insertUse(MemoryUse *MU) { +  InsertedPHIs.clear(); +  MU->setDefiningAccess(getPreviousDef(MU)); +  // Unlike for defs, there is no extra work to do.  Because uses do not create +  // new may-defs, there are only two cases: +  // +  // 1. There was a def already below us, and therefore, we should not have +  // created a phi node because it was already needed for the def. +  // +  // 2. There is no def below us, and therefore, there is no extra renaming work +  // to do. +} + +// Set every incoming edge {BB, MP->getBlock()} of MemoryPhi MP to NewDef. +static void setMemoryPhiValueForBlock(MemoryPhi *MP, const BasicBlock *BB, +                                      MemoryAccess *NewDef) { +  // Replace any operand with us an incoming block with the new defining +  // access. +  int i = MP->getBasicBlockIndex(BB); +  assert(i != -1 && "Should have found the basic block in the phi"); +  // We can't just compare i against getNumOperands since one is signed and the +  // other not. So use it to index into the block iterator. +  for (auto BBIter = MP->block_begin() + i; BBIter != MP->block_end(); +       ++BBIter) { +    if (*BBIter != BB) +      break; +    MP->setIncomingValue(i, NewDef); +    ++i; +  } +} + +// A brief description of the algorithm: +// First, we compute what should define the new def, using the SSA +// construction algorithm. +// Then, we update the defs below us (and any new phi nodes) in the graph to +// point to the correct new defs, to ensure we only have one variable, and no +// disconnected stores. +void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) { +  InsertedPHIs.clear(); + +  // See if we had a local def, and if not, go hunting. +  MemoryAccess *DefBefore = getPreviousDef(MD); +  bool DefBeforeSameBlock = DefBefore->getBlock() == MD->getBlock(); + +  // There is a def before us, which means we can replace any store/phi uses +  // of that thing with us, since we are in the way of whatever was there +  // before. +  // We now define that def's memorydefs and memoryphis +  if (DefBeforeSameBlock) { +    for (auto UI = DefBefore->use_begin(), UE = DefBefore->use_end(); +         UI != UE;) { +      Use &U = *UI++; +      // Leave the uses alone +      if (isa<MemoryUse>(U.getUser())) +        continue; +      U.set(MD); +    } +  } + +  // and that def is now our defining access. +  // We change them in this order otherwise we will appear in the use list +  // above and reset ourselves. +  MD->setDefiningAccess(DefBefore); + +  SmallVector<WeakVH, 8> FixupList(InsertedPHIs.begin(), InsertedPHIs.end()); +  if (!DefBeforeSameBlock) { +    // If there was a local def before us, we must have the same effect it +    // did. Because every may-def is the same, any phis/etc we would create, it +    // would also have created.  If there was no local def before us, we +    // performed a global update, and have to search all successors and make +    // sure we update the first def in each of them (following all paths until +    // we hit the first def along each path). This may also insert phi nodes. +    // TODO: There are other cases we can skip this work, such as when we have a +    // single successor, and only used a straight line of single pred blocks +    // backwards to find the def.  To make that work, we'd have to track whether +    // getDefRecursive only ever used the single predecessor case.  These types +    // of paths also only exist in between CFG simplifications. 
+    FixupList.push_back(MD);
+  }
+
+  while (!FixupList.empty()) {
+    unsigned StartingPHISize = InsertedPHIs.size();
+    fixupDefs(FixupList);
+    FixupList.clear();
+    // Put any new phis on the fixup list, and process them.
+    FixupList.append(InsertedPHIs.begin() + StartingPHISize, InsertedPHIs.end());
+  }
+  // Now that all fixups are done, rename all uses if we are asked to.
+  if (RenameUses) {
+    SmallPtrSet<BasicBlock *, 16> Visited;
+    BasicBlock *StartBlock = MD->getBlock();
+    // We are guaranteed there is a def in the block, because we just got it
+    // handed to us in this function.
+    MemoryAccess *FirstDef = &*MSSA->getWritableBlockDefs(StartBlock)->begin();
+    // Convert to incoming value if it's a memorydef. A phi *is* already an
+    // incoming value.
+    if (auto *MD = dyn_cast<MemoryDef>(FirstDef))
+      FirstDef = MD->getDefiningAccess();
+
+    MSSA->renamePass(MD->getBlock(), FirstDef, Visited);
+    // We just inserted a phi into this block, so the incoming value will
+    // become the phi anyway; it does not matter what we pass.
+    for (auto &MP : InsertedPHIs) {
+      MemoryPhi *Phi = dyn_cast_or_null<MemoryPhi>(MP);
+      if (Phi)
+        MSSA->renamePass(Phi->getBlock(), nullptr, Visited);
+    }
+  }
+}
+
+void MemorySSAUpdater::fixupDefs(const SmallVectorImpl<WeakVH> &Vars) {
+  SmallPtrSet<const BasicBlock *, 8> Seen;
+  SmallVector<const BasicBlock *, 16> Worklist;
+  for (auto &Var : Vars) {
+    MemoryAccess *NewDef = dyn_cast_or_null<MemoryAccess>(Var);
+    if (!NewDef)
+      continue;
+    // First, see if there is a local def after the operand.
+    auto *Defs = MSSA->getWritableBlockDefs(NewDef->getBlock());
+    auto DefIter = NewDef->getDefsIterator();
+
+    // The temporary Phi is being fixed; unmark it so it can be optimized.
+    if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(NewDef))
+      NonOptPhis.erase(Phi);
+
+    // If there is a local def after us, we only have to rename that.
+    if (++DefIter != Defs->end()) {
+      cast<MemoryDef>(DefIter)->setDefiningAccess(NewDef);
+      continue;
+    }
+
+    // Otherwise, we need to search down through the CFG.
+    // For each of our successors, handle it directly if there is a phi, or
+    // place it on the fixup worklist.
+    for (const auto *S : successors(NewDef->getBlock())) {
+      if (auto *MP = MSSA->getMemoryAccess(S))
+        setMemoryPhiValueForBlock(MP, NewDef->getBlock(), NewDef);
+      else
+        Worklist.push_back(S);
+    }
+
+    while (!Worklist.empty()) {
+      const BasicBlock *FixupBlock = Worklist.back();
+      Worklist.pop_back();
+
+      // Get the first def in the block that isn't a phi node.
+      if (auto *Defs = MSSA->getWritableBlockDefs(FixupBlock)) {
+        auto *FirstDef = &*Defs->begin();
+        // The loops above and below should have taken care of phi nodes.
+        assert(!isa<MemoryPhi>(FirstDef) &&
+               "Should have already handled phi nodes!");
+        // We are now this def's defining access; make sure we actually
+        // dominate it.
+        assert(MSSA->dominates(NewDef, FirstDef) &&
+               "Should have dominated the new access");
+
+        // This may insert new phi nodes, because we are not guaranteed the
+        // block we are processing has a single pred, and depending on where
+        // the store was inserted, it may require phi nodes below it.
+        cast<MemoryDef>(FirstDef)->setDefiningAccess(getPreviousDef(FirstDef)); +        return; +      } +      // We didn't find a def, so we must continue. +      for (const auto *S : successors(FixupBlock)) { +        // If there is a phi node, handle it. +        // Otherwise, put the block on the worklist +        if (auto *MP = MSSA->getMemoryAccess(S)) +          setMemoryPhiValueForBlock(MP, FixupBlock, NewDef); +        else { +          // If we cycle, we should have ended up at a phi node that we already +          // processed.  FIXME: Double check this +          if (!Seen.insert(S).second) +            continue; +          Worklist.push_back(S); +        } +      } +    } +  } +} + +// Move What before Where in the MemorySSA IR. +template <class WhereType> +void MemorySSAUpdater::moveTo(MemoryUseOrDef *What, BasicBlock *BB, +                              WhereType Where) { +  // Mark MemoryPhi users of What not to be optimized. +  for (auto *U : What->users()) +    if (MemoryPhi *PhiUser = dyn_cast<MemoryPhi>(U)) +      NonOptPhis.insert(PhiUser); + +  // Replace all our users with our defining access. +  What->replaceAllUsesWith(What->getDefiningAccess()); + +  // Let MemorySSA take care of moving it around in the lists. +  MSSA->moveTo(What, BB, Where); + +  // Now reinsert it into the IR and do whatever fixups needed. +  if (auto *MD = dyn_cast<MemoryDef>(What)) +    insertDef(MD); +  else +    insertUse(cast<MemoryUse>(What)); + +  // Clear dangling pointers. We added all MemoryPhi users, but not all +  // of them are removed by fixupDefs(). +  NonOptPhis.clear(); +} + +// Move What before Where in the MemorySSA IR. +void MemorySSAUpdater::moveBefore(MemoryUseOrDef *What, MemoryUseOrDef *Where) { +  moveTo(What, Where->getBlock(), Where->getIterator()); +} + +// Move What after Where in the MemorySSA IR. +void MemorySSAUpdater::moveAfter(MemoryUseOrDef *What, MemoryUseOrDef *Where) { +  moveTo(What, Where->getBlock(), ++Where->getIterator()); +} + +void MemorySSAUpdater::moveToPlace(MemoryUseOrDef *What, BasicBlock *BB, +                                   MemorySSA::InsertionPlace Where) { +  return moveTo(What, BB, Where); +} + +// All accesses in To used to be in From. Move to end and update access lists. +void MemorySSAUpdater::moveAllAccesses(BasicBlock *From, BasicBlock *To, +                                       Instruction *Start) { + +  MemorySSA::AccessList *Accs = MSSA->getWritableBlockAccesses(From); +  if (!Accs) +    return; + +  MemoryAccess *FirstInNew = nullptr; +  for (Instruction &I : make_range(Start->getIterator(), To->end())) +    if ((FirstInNew = MSSA->getMemoryAccess(&I))) +      break; +  if (!FirstInNew) +    return; + +  auto *MUD = cast<MemoryUseOrDef>(FirstInNew); +  do { +    auto NextIt = ++MUD->getIterator(); +    MemoryUseOrDef *NextMUD = (!Accs || NextIt == Accs->end()) +                                  ? nullptr +                                  : cast<MemoryUseOrDef>(&*NextIt); +    MSSA->moveTo(MUD, To, MemorySSA::End); +    // Moving MUD from Accs in the moveTo above, may delete Accs, so we need to +    // retrieve it again. 
+    Accs = MSSA->getWritableBlockAccesses(From); +    MUD = NextMUD; +  } while (MUD); +} + +void MemorySSAUpdater::moveAllAfterSpliceBlocks(BasicBlock *From, +                                                BasicBlock *To, +                                                Instruction *Start) { +  assert(MSSA->getBlockAccesses(To) == nullptr && +         "To block is expected to be free of MemoryAccesses."); +  moveAllAccesses(From, To, Start); +  for (BasicBlock *Succ : successors(To)) +    if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Succ)) +      MPhi->setIncomingBlock(MPhi->getBasicBlockIndex(From), To); +} + +void MemorySSAUpdater::moveAllAfterMergeBlocks(BasicBlock *From, BasicBlock *To, +                                               Instruction *Start) { +  assert(From->getSinglePredecessor() == To && +         "From block is expected to have a single predecessor (To)."); +  moveAllAccesses(From, To, Start); +  for (BasicBlock *Succ : successors(From)) +    if (MemoryPhi *MPhi = MSSA->getMemoryAccess(Succ)) +      MPhi->setIncomingBlock(MPhi->getBasicBlockIndex(From), To); +} + +/// If all arguments of a MemoryPHI are defined by the same incoming +/// argument, return that argument. +static MemoryAccess *onlySingleValue(MemoryPhi *MP) { +  MemoryAccess *MA = nullptr; + +  for (auto &Arg : MP->operands()) { +    if (!MA) +      MA = cast<MemoryAccess>(Arg); +    else if (MA != Arg) +      return nullptr; +  } +  return MA; +} + +void MemorySSAUpdater::wireOldPredecessorsToNewImmediatePredecessor( +    BasicBlock *Old, BasicBlock *New, ArrayRef<BasicBlock *> Preds) { +  assert(!MSSA->getWritableBlockAccesses(New) && +         "Access list should be null for a new block."); +  MemoryPhi *Phi = MSSA->getMemoryAccess(Old); +  if (!Phi) +    return; +  if (pred_size(Old) == 1) { +    assert(pred_size(New) == Preds.size() && +           "Should have moved all predecessors."); +    MSSA->moveTo(Phi, New, MemorySSA::Beginning); +  } else { +    assert(!Preds.empty() && "Must be moving at least one predecessor to the " +                             "new immediate predecessor."); +    MemoryPhi *NewPhi = MSSA->createMemoryPhi(New); +    SmallPtrSet<BasicBlock *, 16> PredsSet(Preds.begin(), Preds.end()); +    Phi->unorderedDeleteIncomingIf([&](MemoryAccess *MA, BasicBlock *B) { +      if (PredsSet.count(B)) { +        NewPhi->addIncoming(MA, B); +        return true; +      } +      return false; +    }); +    Phi->addIncoming(NewPhi, New); +    if (onlySingleValue(NewPhi)) +      removeMemoryAccess(NewPhi); +  } +} + +void MemorySSAUpdater::removeMemoryAccess(MemoryAccess *MA) { +  assert(!MSSA->isLiveOnEntryDef(MA) && +         "Trying to remove the live on entry def"); +  // We can only delete phi nodes if they have no uses, or we can replace all +  // uses with a single definition. +  MemoryAccess *NewDefTarget = nullptr; +  if (MemoryPhi *MP = dyn_cast<MemoryPhi>(MA)) { +    // Note that it is sufficient to know that all edges of the phi node have +    // the same argument.  If they do, by the definition of dominance frontiers +    // (which we used to place this phi), that argument must dominate this phi, +    // and thus, must dominate the phi's uses, and so we will not hit the assert +    // below. 
+    NewDefTarget = onlySingleValue(MP); +    assert((NewDefTarget || MP->use_empty()) && +           "We can't delete this memory phi"); +  } else { +    NewDefTarget = cast<MemoryUseOrDef>(MA)->getDefiningAccess(); +  } + +  // Re-point the uses at our defining access +  if (!isa<MemoryUse>(MA) && !MA->use_empty()) { +    // Reset optimized on users of this store, and reset the uses. +    // A few notes: +    // 1. This is a slightly modified version of RAUW to avoid walking the +    // uses twice here. +    // 2. If we wanted to be complete, we would have to reset the optimized +    // flags on users of phi nodes if doing the below makes a phi node have all +    // the same arguments. Instead, we prefer users to removeMemoryAccess those +    // phi nodes, because doing it here would be N^3. +    if (MA->hasValueHandle()) +      ValueHandleBase::ValueIsRAUWd(MA, NewDefTarget); +    // Note: We assume MemorySSA is not used in metadata since it's not really +    // part of the IR. + +    while (!MA->use_empty()) { +      Use &U = *MA->use_begin(); +      if (auto *MUD = dyn_cast<MemoryUseOrDef>(U.getUser())) +        MUD->resetOptimized(); +      U.set(NewDefTarget); +    } +  } + +  // The call below to erase will destroy MA, so we can't change the order we +  // are doing things here +  MSSA->removeFromLookups(MA); +  MSSA->removeFromLists(MA); +} + +void MemorySSAUpdater::removeBlocks( +    const SmallPtrSetImpl<BasicBlock *> &DeadBlocks) { +  // First delete all uses of BB in MemoryPhis. +  for (BasicBlock *BB : DeadBlocks) { +    TerminatorInst *TI = BB->getTerminator(); +    assert(TI && "Basic block expected to have a terminator instruction"); +    for (BasicBlock *Succ : TI->successors()) +      if (!DeadBlocks.count(Succ)) +        if (MemoryPhi *MP = MSSA->getMemoryAccess(Succ)) { +          MP->unorderedDeleteIncomingBlock(BB); +          if (MP->getNumIncomingValues() == 1) +            removeMemoryAccess(MP); +        } +    // Drop all references of all accesses in BB +    if (MemorySSA::AccessList *Acc = MSSA->getWritableBlockAccesses(BB)) +      for (MemoryAccess &MA : *Acc) +        MA.dropAllReferences(); +  } + +  // Next, delete all memory accesses in each block +  for (BasicBlock *BB : DeadBlocks) { +    MemorySSA::AccessList *Acc = MSSA->getWritableBlockAccesses(BB); +    if (!Acc) +      continue; +    for (auto AB = Acc->begin(), AE = Acc->end(); AB != AE;) { +      MemoryAccess *MA = &*AB; +      ++AB; +      MSSA->removeFromLookups(MA); +      MSSA->removeFromLists(MA); +    } +  } +} + +MemoryAccess *MemorySSAUpdater::createMemoryAccessInBB( +    Instruction *I, MemoryAccess *Definition, const BasicBlock *BB, +    MemorySSA::InsertionPlace Point) { +  MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); +  MSSA->insertIntoListsForBlock(NewAccess, BB, Point); +  return NewAccess; +} + +MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessBefore( +    Instruction *I, MemoryAccess *Definition, MemoryUseOrDef *InsertPt) { +  assert(I->getParent() == InsertPt->getBlock() && +         "New and old access must be in the same block"); +  MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); +  MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(), +                              InsertPt->getIterator()); +  return NewAccess; +} + +MemoryUseOrDef *MemorySSAUpdater::createMemoryAccessAfter( +    Instruction *I, MemoryAccess *Definition, MemoryAccess *InsertPt) { +  assert(I->getParent() == InsertPt->getBlock() && +         "New and old 
access must be in the same block"); +  MemoryUseOrDef *NewAccess = MSSA->createDefinedAccess(I, Definition); +  MSSA->insertIntoListsBefore(NewAccess, InsertPt->getBlock(), +                              ++InsertPt->getIterator()); +  return NewAccess; +} diff --git a/contrib/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp b/contrib/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp new file mode 100644 index 000000000000..1e321f17d59f --- /dev/null +++ b/contrib/llvm/lib/Analysis/ModuleDebugInfoPrinter.cpp @@ -0,0 +1,128 @@ +//===-- ModuleDebugInfoPrinter.cpp - Prints module debug info metadata ----===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass decodes the debug info metadata in a module and prints in a +// (sufficiently-prepared-) human-readable form. +// +// For example, run this pass from opt along with the -analyze option, and +// it'll print to standard output. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/Pass.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +namespace { +  class ModuleDebugInfoPrinter : public ModulePass { +    DebugInfoFinder Finder; +  public: +    static char ID; // Pass identification, replacement for typeid +    ModuleDebugInfoPrinter() : ModulePass(ID) { +      initializeModuleDebugInfoPrinterPass(*PassRegistry::getPassRegistry()); +    } + +    bool runOnModule(Module &M) override; + +    void getAnalysisUsage(AnalysisUsage &AU) const override { +      AU.setPreservesAll(); +    } +    void print(raw_ostream &O, const Module *M) const override; +  }; +} + +char ModuleDebugInfoPrinter::ID = 0; +INITIALIZE_PASS(ModuleDebugInfoPrinter, "module-debuginfo", +                "Decodes module-level debug info", false, true) + +ModulePass *llvm::createModuleDebugInfoPrinterPass() { +  return new ModuleDebugInfoPrinter(); +} + +bool ModuleDebugInfoPrinter::runOnModule(Module &M) { +  Finder.processModule(M); +  return false; +} + +static void printFile(raw_ostream &O, StringRef Filename, StringRef Directory, +                      unsigned Line = 0) { +  if (Filename.empty()) +    return; + +  O << " from "; +  if (!Directory.empty()) +    O << Directory << "/"; +  O << Filename; +  if (Line) +    O << ":" << Line; +} + +void ModuleDebugInfoPrinter::print(raw_ostream &O, const Module *M) const { +  // Printing the nodes directly isn't particularly helpful (since they +  // reference other nodes that won't be printed, particularly for the +  // filenames), so just print a few useful things. 
+  for (DICompileUnit *CU : Finder.compile_units()) { +    O << "Compile unit: "; +    auto Lang = dwarf::LanguageString(CU->getSourceLanguage()); +    if (!Lang.empty()) +      O << Lang; +    else +      O << "unknown-language(" << CU->getSourceLanguage() << ")"; +    printFile(O, CU->getFilename(), CU->getDirectory()); +    O << '\n'; +  } + +  for (DISubprogram *S : Finder.subprograms()) { +    O << "Subprogram: " << S->getName(); +    printFile(O, S->getFilename(), S->getDirectory(), S->getLine()); +    if (!S->getLinkageName().empty()) +      O << " ('" << S->getLinkageName() << "')"; +    O << '\n'; +  } + +  for (auto GVU : Finder.global_variables()) { +    const auto *GV = GVU->getVariable(); +    O << "Global variable: " << GV->getName(); +    printFile(O, GV->getFilename(), GV->getDirectory(), GV->getLine()); +    if (!GV->getLinkageName().empty()) +      O << " ('" << GV->getLinkageName() << "')"; +    O << '\n'; +  } + +  for (const DIType *T : Finder.types()) { +    O << "Type:"; +    if (!T->getName().empty()) +      O << ' ' << T->getName(); +    printFile(O, T->getFilename(), T->getDirectory(), T->getLine()); +    if (auto *BT = dyn_cast<DIBasicType>(T)) { +      O << " "; +      auto Encoding = dwarf::AttributeEncodingString(BT->getEncoding()); +      if (!Encoding.empty()) +        O << Encoding; +      else +        O << "unknown-encoding(" << BT->getEncoding() << ')'; +    } else { +      O << ' '; +      auto Tag = dwarf::TagString(T->getTag()); +      if (!Tag.empty()) +        O << Tag; +      else +        O << "unknown-tag(" << T->getTag() << ")"; +    } +    if (auto *CT = dyn_cast<DICompositeType>(T)) { +      if (auto *S = CT->getRawIdentifier()) +        O << " (identifier: '" << S->getString() << "')"; +    } +    O << '\n'; +  } +} diff --git a/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp new file mode 100644 index 000000000000..17dae20ce3a1 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -0,0 +1,630 @@ +//===- ModuleSummaryAnalysis.cpp - Module summary index builder -----------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass builds a ModuleSummaryIndex object for the module, to be written +// to bitcode or LLVM assembly. 
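+//
+// Illustrative sketch (the exact analysis registration depends on the
+// client): with the new pass manager, the index computed here can be
+// retrieved roughly as follows:
+//
+//   ModuleAnalysisManager MAM;
+//   // ... register ProfileSummaryAnalysis, the function analysis manager
+//   // proxy, BlockFrequencyAnalysis, and ModuleSummaryIndexAnalysis ...
+//   ModuleSummaryIndex &Index = MAM.getResult<ModuleSummaryIndexAnalysis>(M);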
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndex.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/Object/ModuleSymbolTable.h" +#include "llvm/Object/SymbolicFile.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "module-summary-analysis" + +// Option to force edges cold which will block importing when the +// -import-cold-multiplier is set to 0. Useful for debugging. +FunctionSummary::ForceSummaryHotnessType ForceSummaryEdgesCold = +    FunctionSummary::FSHT_None; +cl::opt<FunctionSummary::ForceSummaryHotnessType, true> FSEC( +    "force-summary-edges-cold", cl::Hidden, cl::location(ForceSummaryEdgesCold), +    cl::desc("Force all edges in the function summary to cold"), +    cl::values(clEnumValN(FunctionSummary::FSHT_None, "none", "None."), +               clEnumValN(FunctionSummary::FSHT_AllNonCritical, +                          "all-non-critical", "All non-critical edges."), +               clEnumValN(FunctionSummary::FSHT_All, "all", "All edges."))); + +// Walk through the operands of a given User via worklist iteration and populate +// the set of GlobalValue references encountered. Invoked either on an +// Instruction or a GlobalVariable (which walks its initializer). +static void findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, +                         SetVector<ValueInfo> &RefEdges, +                         SmallPtrSet<const User *, 8> &Visited) { +  SmallVector<const User *, 32> Worklist; +  Worklist.push_back(CurUser); + +  while (!Worklist.empty()) { +    const User *U = Worklist.pop_back_val(); + +    if (!Visited.insert(U).second) +      continue; + +    ImmutableCallSite CS(U); + +    for (const auto &OI : U->operands()) { +      const User *Operand = dyn_cast<User>(OI); +      if (!Operand) +        continue; +      if (isa<BlockAddress>(Operand)) +        continue; +      if (auto *GV = dyn_cast<GlobalValue>(Operand)) { +        // We have a reference to a global value. This should be added to +        // the reference set unless it is a callee. Callees are handled +        // specially by WriteFunction and are added to a separate list. 
+        if (!(CS && CS.isCallee(&OI))) +          RefEdges.insert(Index.getOrInsertValueInfo(GV)); +        continue; +      } +      Worklist.push_back(Operand); +    } +  } +} + +static CalleeInfo::HotnessType getHotness(uint64_t ProfileCount, +                                          ProfileSummaryInfo *PSI) { +  if (!PSI) +    return CalleeInfo::HotnessType::Unknown; +  if (PSI->isHotCount(ProfileCount)) +    return CalleeInfo::HotnessType::Hot; +  if (PSI->isColdCount(ProfileCount)) +    return CalleeInfo::HotnessType::Cold; +  return CalleeInfo::HotnessType::None; +} + +static bool isNonRenamableLocal(const GlobalValue &GV) { +  return GV.hasSection() && GV.hasLocalLinkage(); +} + +/// Determine whether this call has all constant integer arguments (excluding +/// "this") and summarize it to VCalls or ConstVCalls as appropriate. +static void addVCallToSet(DevirtCallSite Call, GlobalValue::GUID Guid, +                          SetVector<FunctionSummary::VFuncId> &VCalls, +                          SetVector<FunctionSummary::ConstVCall> &ConstVCalls) { +  std::vector<uint64_t> Args; +  // Start from the second argument to skip the "this" pointer. +  for (auto &Arg : make_range(Call.CS.arg_begin() + 1, Call.CS.arg_end())) { +    auto *CI = dyn_cast<ConstantInt>(Arg); +    if (!CI || CI->getBitWidth() > 64) { +      VCalls.insert({Guid, Call.Offset}); +      return; +    } +    Args.push_back(CI->getZExtValue()); +  } +  ConstVCalls.insert({{Guid, Call.Offset}, std::move(Args)}); +} + +/// If this intrinsic call requires that we add information to the function +/// summary, do so via the non-constant reference arguments. +static void addIntrinsicToSummary( +    const CallInst *CI, SetVector<GlobalValue::GUID> &TypeTests, +    SetVector<FunctionSummary::VFuncId> &TypeTestAssumeVCalls, +    SetVector<FunctionSummary::VFuncId> &TypeCheckedLoadVCalls, +    SetVector<FunctionSummary::ConstVCall> &TypeTestAssumeConstVCalls, +    SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls) { +  switch (CI->getCalledFunction()->getIntrinsicID()) { +  case Intrinsic::type_test: { +    auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1)); +    auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); +    if (!TypeId) +      break; +    GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString()); + +    // Produce a summary from type.test intrinsics. We only summarize type.test +    // intrinsics that are used other than by an llvm.assume intrinsic. +    // Intrinsics that are assumed are relevant only to the devirtualization +    // pass, not the type test lowering pass. 
+    bool HasNonAssumeUses = llvm::any_of(CI->uses(), [](const Use &CIU) { +      auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser()); +      if (!AssumeCI) +        return true; +      Function *F = AssumeCI->getCalledFunction(); +      return !F || F->getIntrinsicID() != Intrinsic::assume; +    }); +    if (HasNonAssumeUses) +      TypeTests.insert(Guid); + +    SmallVector<DevirtCallSite, 4> DevirtCalls; +    SmallVector<CallInst *, 4> Assumes; +    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); +    for (auto &Call : DevirtCalls) +      addVCallToSet(Call, Guid, TypeTestAssumeVCalls, +                    TypeTestAssumeConstVCalls); + +    break; +  } + +  case Intrinsic::type_checked_load: { +    auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(2)); +    auto *TypeId = dyn_cast<MDString>(TypeMDVal->getMetadata()); +    if (!TypeId) +      break; +    GlobalValue::GUID Guid = GlobalValue::getGUID(TypeId->getString()); + +    SmallVector<DevirtCallSite, 4> DevirtCalls; +    SmallVector<Instruction *, 4> LoadedPtrs; +    SmallVector<Instruction *, 4> Preds; +    bool HasNonCallUses = false; +    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, +                                               HasNonCallUses, CI); +    // Any non-call uses of the result of llvm.type.checked.load will +    // prevent us from optimizing away the llvm.type.test. +    if (HasNonCallUses) +      TypeTests.insert(Guid); +    for (auto &Call : DevirtCalls) +      addVCallToSet(Call, Guid, TypeCheckedLoadVCalls, +                    TypeCheckedLoadConstVCalls); + +    break; +  } +  default: +    break; +  } +} + +static void +computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, +                       const Function &F, BlockFrequencyInfo *BFI, +                       ProfileSummaryInfo *PSI, bool HasLocalsInUsedOrAsm, +                       DenseSet<GlobalValue::GUID> &CantBePromoted) { +  // Summary not currently supported for anonymous functions, they should +  // have been named. +  assert(F.hasName()); + +  unsigned NumInsts = 0; +  // Map from callee ValueId to profile count. Used to accumulate profile +  // counts for all static calls to a given callee. +  MapVector<ValueInfo, CalleeInfo> CallGraphEdges; +  SetVector<ValueInfo> RefEdges; +  SetVector<GlobalValue::GUID> TypeTests; +  SetVector<FunctionSummary::VFuncId> TypeTestAssumeVCalls, +      TypeCheckedLoadVCalls; +  SetVector<FunctionSummary::ConstVCall> TypeTestAssumeConstVCalls, +      TypeCheckedLoadConstVCalls; +  ICallPromotionAnalysis ICallAnalysis; +  SmallPtrSet<const User *, 8> Visited; + +  // Add personality function, prefix data and prologue data to function's ref +  // list. +  findRefEdges(Index, &F, RefEdges, Visited); + +  bool HasInlineAsmMaybeReferencingInternal = false; +  for (const BasicBlock &BB : F) +    for (const Instruction &I : BB) { +      if (isa<DbgInfoIntrinsic>(I)) +        continue; +      ++NumInsts; +      findRefEdges(Index, &I, RefEdges, Visited); +      auto CS = ImmutableCallSite(&I); +      if (!CS) +        continue; + +      const auto *CI = dyn_cast<CallInst>(&I); +      // Since we don't know exactly which local values are referenced in inline +      // assembly, conservatively mark the function as possibly referencing +      // a local value from inline assembly to ensure we don't export a +      // reference (which would require renaming and promotion of the +      // referenced value). 
+      if (HasLocalsInUsedOrAsm && CI && CI->isInlineAsm()) +        HasInlineAsmMaybeReferencingInternal = true; + +      auto *CalledValue = CS.getCalledValue(); +      auto *CalledFunction = CS.getCalledFunction(); +      if (CalledValue && !CalledFunction) { +        CalledValue = CalledValue->stripPointerCastsNoFollowAliases(); +        // Stripping pointer casts can reveal a called function. +        CalledFunction = dyn_cast<Function>(CalledValue); +      } +      // Check if this is an alias to a function. If so, get the +      // called aliasee for the checks below. +      if (auto *GA = dyn_cast<GlobalAlias>(CalledValue)) { +        assert(!CalledFunction && "Expected null called function in callsite for alias"); +        CalledFunction = dyn_cast<Function>(GA->getBaseObject()); +      } +      // Check if this is a direct call to a known function or a known +      // intrinsic, or an indirect call with profile data. +      if (CalledFunction) { +        if (CI && CalledFunction->isIntrinsic()) { +          addIntrinsicToSummary( +              CI, TypeTests, TypeTestAssumeVCalls, TypeCheckedLoadVCalls, +              TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls); +          continue; +        } +        // We should have named any anonymous globals +        assert(CalledFunction->hasName()); +        auto ScaledCount = PSI->getProfileCount(&I, BFI); +        auto Hotness = ScaledCount ? getHotness(ScaledCount.getValue(), PSI) +                                   : CalleeInfo::HotnessType::Unknown; +        if (ForceSummaryEdgesCold != FunctionSummary::FSHT_None) +          Hotness = CalleeInfo::HotnessType::Cold; + +        // Use the original CalledValue, in case it was an alias. We want +        // to record the call edge to the alias in that case. Eventually +        // an alias summary will be created to associate the alias and +        // aliasee. +        auto &ValueInfo = CallGraphEdges[Index.getOrInsertValueInfo( +            cast<GlobalValue>(CalledValue))]; +        ValueInfo.updateHotness(Hotness); +        // Add the relative block frequency to CalleeInfo if there is no profile +        // information. +        if (BFI != nullptr && Hotness == CalleeInfo::HotnessType::Unknown) { +          uint64_t BBFreq = BFI->getBlockFreq(&BB).getFrequency(); +          uint64_t EntryFreq = BFI->getEntryFreq(); +          ValueInfo.updateRelBlockFreq(BBFreq, EntryFreq); +        } +      } else { +        // Skip inline assembly calls. +        if (CI && CI->isInlineAsm()) +          continue; +        // Skip direct calls. +        if (!CalledValue || isa<Constant>(CalledValue)) +          continue; + +        // Check if the instruction has a callees metadata. If so, add callees +        // to CallGraphEdges to reflect the references from the metadata, and +        // to enable importing for subsequent indirect call promotion and +        // inlining. 
+        if (auto *MD = I.getMetadata(LLVMContext::MD_callees)) { +          for (auto &Op : MD->operands()) { +            Function *Callee = mdconst::extract_or_null<Function>(Op); +            if (Callee) +              CallGraphEdges[Index.getOrInsertValueInfo(Callee)]; +          } +        } + +        uint32_t NumVals, NumCandidates; +        uint64_t TotalCount; +        auto CandidateProfileData = +            ICallAnalysis.getPromotionCandidatesForInstruction( +                &I, NumVals, TotalCount, NumCandidates); +        for (auto &Candidate : CandidateProfileData) +          CallGraphEdges[Index.getOrInsertValueInfo(Candidate.Value)] +              .updateHotness(getHotness(Candidate.Count, PSI)); +      } +    } + +  // Explicit add hot edges to enforce importing for designated GUIDs for +  // sample PGO, to enable the same inlines as the profiled optimized binary. +  for (auto &I : F.getImportGUIDs()) +    CallGraphEdges[Index.getOrInsertValueInfo(I)].updateHotness( +        ForceSummaryEdgesCold == FunctionSummary::FSHT_All +            ? CalleeInfo::HotnessType::Cold +            : CalleeInfo::HotnessType::Critical); + +  bool NonRenamableLocal = isNonRenamableLocal(F); +  bool NotEligibleForImport = +      NonRenamableLocal || HasInlineAsmMaybeReferencingInternal || +      // Inliner doesn't handle variadic functions. +      // FIXME: refactor this to use the same code that inliner is using. +      F.isVarArg() || +      // Don't try to import functions with noinline attribute. +      F.getAttributes().hasFnAttribute(Attribute::NoInline); +  GlobalValueSummary::GVFlags Flags(F.getLinkage(), NotEligibleForImport, +                                    /* Live = */ false, F.isDSOLocal()); +  FunctionSummary::FFlags FunFlags{ +      F.hasFnAttribute(Attribute::ReadNone), +      F.hasFnAttribute(Attribute::ReadOnly), +      F.hasFnAttribute(Attribute::NoRecurse), +      F.returnDoesNotAlias(), +  }; +  auto FuncSummary = llvm::make_unique<FunctionSummary>( +      Flags, NumInsts, FunFlags, RefEdges.takeVector(), +      CallGraphEdges.takeVector(), TypeTests.takeVector(), +      TypeTestAssumeVCalls.takeVector(), TypeCheckedLoadVCalls.takeVector(), +      TypeTestAssumeConstVCalls.takeVector(), +      TypeCheckedLoadConstVCalls.takeVector()); +  if (NonRenamableLocal) +    CantBePromoted.insert(F.getGUID()); +  Index.addGlobalValueSummary(F, std::move(FuncSummary)); +} + +static void +computeVariableSummary(ModuleSummaryIndex &Index, const GlobalVariable &V, +                       DenseSet<GlobalValue::GUID> &CantBePromoted) { +  SetVector<ValueInfo> RefEdges; +  SmallPtrSet<const User *, 8> Visited; +  findRefEdges(Index, &V, RefEdges, Visited); +  bool NonRenamableLocal = isNonRenamableLocal(V); +  GlobalValueSummary::GVFlags Flags(V.getLinkage(), NonRenamableLocal, +                                    /* Live = */ false, V.isDSOLocal()); +  auto GVarSummary = +      llvm::make_unique<GlobalVarSummary>(Flags, RefEdges.takeVector()); +  if (NonRenamableLocal) +    CantBePromoted.insert(V.getGUID()); +  Index.addGlobalValueSummary(V, std::move(GVarSummary)); +} + +static void +computeAliasSummary(ModuleSummaryIndex &Index, const GlobalAlias &A, +                    DenseSet<GlobalValue::GUID> &CantBePromoted) { +  bool NonRenamableLocal = isNonRenamableLocal(A); +  GlobalValueSummary::GVFlags Flags(A.getLinkage(), NonRenamableLocal, +                                    /* Live = */ false, A.isDSOLocal()); +  auto AS = llvm::make_unique<AliasSummary>(Flags); +  auto *Aliasee 
= A.getBaseObject(); +  auto *AliaseeSummary = Index.getGlobalValueSummary(*Aliasee); +  assert(AliaseeSummary && "Alias expects aliasee summary to be parsed"); +  AS->setAliasee(AliaseeSummary); +  if (NonRenamableLocal) +    CantBePromoted.insert(A.getGUID()); +  Index.addGlobalValueSummary(A, std::move(AS)); +} + +// Set LiveRoot flag on entries matching the given value name. +static void setLiveRoot(ModuleSummaryIndex &Index, StringRef Name) { +  if (ValueInfo VI = Index.getValueInfo(GlobalValue::getGUID(Name))) +    for (auto &Summary : VI.getSummaryList()) +      Summary->setLive(true); +} + +ModuleSummaryIndex llvm::buildModuleSummaryIndex( +    const Module &M, +    std::function<BlockFrequencyInfo *(const Function &F)> GetBFICallback, +    ProfileSummaryInfo *PSI) { +  assert(PSI); +  ModuleSummaryIndex Index(/*HaveGVs=*/true); + +  // Identify the local values in the llvm.used and llvm.compiler.used sets, +  // which should not be exported as they would then require renaming and +  // promotion, but we may have opaque uses e.g. in inline asm. We collect them +  // here because we use this information to mark functions containing inline +  // assembly calls as not importable. +  SmallPtrSet<GlobalValue *, 8> LocalsUsed; +  SmallPtrSet<GlobalValue *, 8> Used; +  // First collect those in the llvm.used set. +  collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ false); +  // Next collect those in the llvm.compiler.used set. +  collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ true); +  DenseSet<GlobalValue::GUID> CantBePromoted; +  for (auto *V : Used) { +    if (V->hasLocalLinkage()) { +      LocalsUsed.insert(V); +      CantBePromoted.insert(V->getGUID()); +    } +  } + +  bool HasLocalInlineAsmSymbol = false; +  if (!M.getModuleInlineAsm().empty()) { +    // Collect the local values defined by module level asm, and set up +    // summaries for these symbols so that they can be marked as NoRename, +    // to prevent export of any use of them in regular IR that would require +    // renaming within the module level asm. Note we don't need to create a +    // summary for weak or global defs, as they don't need to be flagged as +    // NoRename, and defs in module level asm can't be imported anyway. +    // Also, any values used but not defined within module level asm should +    // be listed on the llvm.used or llvm.compiler.used global and marked as +    // referenced from there. +    ModuleSymbolTable::CollectAsmSymbols( +        M, [&](StringRef Name, object::BasicSymbolRef::Flags Flags) { +          // Symbols not marked as Weak or Global are local definitions. +          if (Flags & (object::BasicSymbolRef::SF_Weak | +                       object::BasicSymbolRef::SF_Global)) +            return; +          HasLocalInlineAsmSymbol = true; +          GlobalValue *GV = M.getNamedValue(Name); +          if (!GV) +            return; +          assert(GV->isDeclaration() && "Def in module asm already has definition"); +          GlobalValueSummary::GVFlags GVFlags(GlobalValue::InternalLinkage, +                                              /* NotEligibleToImport = */ true, +                                              /* Live = */ true, +                                              /* Local */ GV->isDSOLocal()); +          CantBePromoted.insert(GV->getGUID()); +          // Create the appropriate summary type. 
+          if (Function *F = dyn_cast<Function>(GV)) { +            std::unique_ptr<FunctionSummary> Summary = +                llvm::make_unique<FunctionSummary>( +                    GVFlags, 0, +                    FunctionSummary::FFlags{ +                        F->hasFnAttribute(Attribute::ReadNone), +                        F->hasFnAttribute(Attribute::ReadOnly), +                        F->hasFnAttribute(Attribute::NoRecurse), +                        F->returnDoesNotAlias()}, +                    ArrayRef<ValueInfo>{}, ArrayRef<FunctionSummary::EdgeTy>{}, +                    ArrayRef<GlobalValue::GUID>{}, +                    ArrayRef<FunctionSummary::VFuncId>{}, +                    ArrayRef<FunctionSummary::VFuncId>{}, +                    ArrayRef<FunctionSummary::ConstVCall>{}, +                    ArrayRef<FunctionSummary::ConstVCall>{}); +            Index.addGlobalValueSummary(*GV, std::move(Summary)); +          } else { +            std::unique_ptr<GlobalVarSummary> Summary = +                llvm::make_unique<GlobalVarSummary>(GVFlags, +                                                    ArrayRef<ValueInfo>{}); +            Index.addGlobalValueSummary(*GV, std::move(Summary)); +          } +        }); +  } + +  // Compute summaries for all functions defined in module, and save in the +  // index. +  for (auto &F : M) { +    if (F.isDeclaration()) +      continue; + +    BlockFrequencyInfo *BFI = nullptr; +    std::unique_ptr<BlockFrequencyInfo> BFIPtr; +    if (GetBFICallback) +      BFI = GetBFICallback(F); +    else if (F.hasProfileData()) { +      LoopInfo LI{DominatorTree(const_cast<Function &>(F))}; +      BranchProbabilityInfo BPI{F, LI}; +      BFIPtr = llvm::make_unique<BlockFrequencyInfo>(F, BPI, LI); +      BFI = BFIPtr.get(); +    } + +    computeFunctionSummary(Index, M, F, BFI, PSI, +                           !LocalsUsed.empty() || HasLocalInlineAsmSymbol, +                           CantBePromoted); +  } + +  // Compute summaries for all variables defined in module, and save in the +  // index. +  for (const GlobalVariable &G : M.globals()) { +    if (G.isDeclaration()) +      continue; +    computeVariableSummary(Index, G, CantBePromoted); +  } + +  // Compute summaries for all aliases defined in module, and save in the +  // index. +  for (const GlobalAlias &A : M.aliases()) +    computeAliasSummary(Index, A, CantBePromoted); + +  for (auto *V : LocalsUsed) { +    auto *Summary = Index.getGlobalValueSummary(*V); +    assert(Summary && "Missing summary for global value"); +    Summary->setNotEligibleToImport(); +  } + +  // The linker doesn't know about these LLVM produced values, so we need +  // to flag them as live in the index to ensure index-based dead value +  // analysis treats them as live roots of the analysis. +  setLiveRoot(Index, "llvm.used"); +  setLiveRoot(Index, "llvm.compiler.used"); +  setLiveRoot(Index, "llvm.global_ctors"); +  setLiveRoot(Index, "llvm.global_dtors"); +  setLiveRoot(Index, "llvm.global.annotations"); + +  bool IsThinLTO = true; +  if (auto *MD = +          mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("ThinLTO"))) +    IsThinLTO = MD->getZExtValue(); + +  for (auto &GlobalList : Index) { +    // Ignore entries for references that are undefined in the current module. 
+    if (GlobalList.second.SummaryList.empty()) +      continue; + +    assert(GlobalList.second.SummaryList.size() == 1 && +           "Expected module's index to have one summary per GUID"); +    auto &Summary = GlobalList.second.SummaryList[0]; +    if (!IsThinLTO) { +      Summary->setNotEligibleToImport(); +      continue; +    } + +    bool AllRefsCanBeExternallyReferenced = +        llvm::all_of(Summary->refs(), [&](const ValueInfo &VI) { +          return !CantBePromoted.count(VI.getGUID()); +        }); +    if (!AllRefsCanBeExternallyReferenced) { +      Summary->setNotEligibleToImport(); +      continue; +    } + +    if (auto *FuncSummary = dyn_cast<FunctionSummary>(Summary.get())) { +      bool AllCallsCanBeExternallyReferenced = llvm::all_of( +          FuncSummary->calls(), [&](const FunctionSummary::EdgeTy &Edge) { +            return !CantBePromoted.count(Edge.first.getGUID()); +          }); +      if (!AllCallsCanBeExternallyReferenced) +        Summary->setNotEligibleToImport(); +    } +  } + +  return Index; +} + +AnalysisKey ModuleSummaryIndexAnalysis::Key; + +ModuleSummaryIndex +ModuleSummaryIndexAnalysis::run(Module &M, ModuleAnalysisManager &AM) { +  ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M); +  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); +  return buildModuleSummaryIndex( +      M, +      [&FAM](const Function &F) { +        return &FAM.getResult<BlockFrequencyAnalysis>( +            *const_cast<Function *>(&F)); +      }, +      &PSI); +} + +char ModuleSummaryIndexWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ModuleSummaryIndexWrapperPass, "module-summary-analysis", +                      "Module Summary Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_END(ModuleSummaryIndexWrapperPass, "module-summary-analysis", +                    "Module Summary Analysis", false, true) + +ModulePass *llvm::createModuleSummaryIndexWrapperPass() { +  return new ModuleSummaryIndexWrapperPass(); +} + +ModuleSummaryIndexWrapperPass::ModuleSummaryIndexWrapperPass() +    : ModulePass(ID) { +  initializeModuleSummaryIndexWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ModuleSummaryIndexWrapperPass::runOnModule(Module &M) { +  auto &PSI = *getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); +  Index.emplace(buildModuleSummaryIndex( +      M, +      [this](const Function &F) { +        return &(this->getAnalysis<BlockFrequencyInfoWrapperPass>( +                         *const_cast<Function *>(&F)) +                     .getBFI()); +      }, +      &PSI)); +  return false; +} + +bool ModuleSummaryIndexWrapperPass::doFinalization(Module &M) { +  Index.reset(); +  return false; +} + +void ModuleSummaryIndexWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<BlockFrequencyInfoWrapperPass>(); +  AU.addRequired<ProfileSummaryInfoWrapperPass>(); +} diff --git a/contrib/llvm/lib/Analysis/MustExecute.cpp b/contrib/llvm/lib/Analysis/MustExecute.cpp new file mode 100644 index 000000000000..8e85366b4618 --- /dev/null +++ b/contrib/llvm/lib/Analysis/MustExecute.cpp @@ -0,0 +1,269 @@ +//===- MustExecute.cpp - Printer for isGuaranteedToExecute ----------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/MustExecute.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/FormattedStream.h"
+#include "llvm/Support/raw_ostream.h"
+using namespace llvm;
+
+/// Computes loop safety information; checks the loop body and header for the
+/// possibility of throwing an exception.
+///
+void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) {
+  assert(CurLoop != nullptr && "CurLoop can't be null");
+  BasicBlock *Header = CurLoop->getHeader();
+  // Setting default safety values.
+  SafetyInfo->MayThrow = false;
+  SafetyInfo->HeaderMayThrow = false;
+  // Iterate over header and compute safety info.
+  SafetyInfo->HeaderMayThrow =
+    !isGuaranteedToTransferExecutionToSuccessor(Header);
+
+  SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow;
+  // Iterate over loop instructions and compute safety info.
+  // Skip header as it has been computed and stored in HeaderMayThrow.
+  // The first block in loopinfo.Blocks is guaranteed to be the header.
+  assert(Header == *CurLoop->getBlocks().begin() &&
+         "First block must be header");
+  for (Loop::block_iterator BB = std::next(CurLoop->block_begin()),
+                            BBE = CurLoop->block_end();
+       (BB != BBE) && !SafetyInfo->MayThrow; ++BB)
+    SafetyInfo->MayThrow |=
+      !isGuaranteedToTransferExecutionToSuccessor(*BB);
+
+  // Compute funclet colors if we might sink/hoist in a function with a funclet
+  // personality routine.
+  Function *Fn = CurLoop->getHeader()->getParent();
+  if (Fn->hasPersonalityFn())
+    if (Constant *PersonalityFn = Fn->getPersonalityFn())
+      if (isScopedEHPersonality(classifyEHPersonality(PersonalityFn)))
+        SafetyInfo->BlockColors = colorEHFunclets(*Fn);
+}
+
+/// Return true if we can prove that the given ExitBlock is not reached on the
+/// first iteration of the given loop.  That is, the backedge of the loop must
+/// be executed before the ExitBlock is executed in any dynamic execution trace.
+static bool CanProveNotTakenFirstIteration(BasicBlock *ExitBlock,
+                                           const DominatorTree *DT,
+                                           const Loop *CurLoop) {
+  auto *CondExitBlock = ExitBlock->getSinglePredecessor();
+  if (!CondExitBlock)
+    // Expect unique exits.
+    return false;
+  assert(CurLoop->contains(CondExitBlock) && "meaning of exit block");
+  auto *BI = dyn_cast<BranchInst>(CondExitBlock->getTerminator());
+  if (!BI || !BI->isConditional())
+    return false;
+  // If the condition is constant and false leads to ExitBlock, then we always
+  // execute the true branch.
+  if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition()))
+    return BI->getSuccessor(Cond->getZExtValue() ? 1 : 0) == ExitBlock;
+  auto *Cond = dyn_cast<CmpInst>(BI->getCondition());
+  if (!Cond)
+    return false;
+  // TODO: This would be a lot more powerful if we used SCEV, but all the
+  // plumbing is currently missing to pass a pointer in from the pass.
+  // Check for cmp (phi [x, preheader] ...), y where (pred x, y) is known.
+  auto *LHS = dyn_cast<PHINode>(Cond->getOperand(0));
+  auto *RHS = Cond->getOperand(1);
+  if (!LHS || LHS->getParent() != CurLoop->getHeader())
+    return false;
+  auto DL = ExitBlock->getModule()->getDataLayout();
+  auto *IVStart = LHS->getIncomingValueForBlock(CurLoop->getLoopPreheader());
+  auto *SimpleValOrNull = SimplifyCmpInst(Cond->getPredicate(),
+                                          IVStart, RHS,
+                                          {DL, /*TLI*/ nullptr,
+                                              DT, /*AC*/ nullptr, BI});
+  auto *SimpleCst = dyn_cast_or_null<Constant>(SimpleValOrNull);
+  if (!SimpleCst)
+    return false;
+  if (ExitBlock == BI->getSuccessor(0))
+    return SimpleCst->isZeroValue();
+  assert(ExitBlock == BI->getSuccessor(1) && "implied by above");
+  return SimpleCst->isAllOnesValue();
+}
+
+/// Returns true if the instruction in a loop is guaranteed to execute at least
+/// once.
+bool llvm::isGuaranteedToExecute(const Instruction &Inst,
+                                 const DominatorTree *DT, const Loop *CurLoop,
+                                 const LoopSafetyInfo *SafetyInfo) {
+  // We have to check to make sure that the instruction dominates all
+  // of the exit blocks.  If it doesn't, then there is a path out of the loop
+  // which does not execute this instruction, so we can't hoist it.
+
+  // If the instruction is in the header block for the loop (which is very
+  // common), it is always guaranteed to dominate the exit blocks.  Since this
+  // is a common case, and can save some work, check it now.
+  if (Inst.getParent() == CurLoop->getHeader())
+    // If there's a throw in the header block, we can't guarantee we'll reach
+    // Inst unless we can prove that Inst comes before the potential implicit
+    // exit.  At the moment, we use a (cheap) hack for the common case where
+    // the instruction of interest is the first one in the block.
+    return !SafetyInfo->HeaderMayThrow ||
+      Inst.getParent()->getFirstNonPHIOrDbg() == &Inst;
+
+  // Somewhere in this loop there is an instruction which may throw and make us
+  // exit the loop.
+  if (SafetyInfo->MayThrow)
+    return false;
+
+  // Note: There are two styles of reasoning intermixed below for
+  // implementation efficiency reasons.  They are:
+  // 1) If we can prove that the instruction dominates all exit blocks, then we
+  // know the instruction must have executed on *some* iteration before we
+  // exit.  We do not prove *which* iteration the instruction must execute on.
+  // 2) If we can prove that the instruction dominates the latch and all exits
+  // which might be taken on the first iteration, we know the instruction must
+  // execute on the first iteration.  This second style allows a conditional
+  // exit before the instruction of interest which is provably not taken on the
+  // first iteration.  This is a quite common case for range-check-like
+  // patterns.  TODO: support loops with multiple latches.
+
+  const bool InstDominatesLatch =
+    CurLoop->getLoopLatch() != nullptr &&
+    DT->dominates(Inst.getParent(), CurLoop->getLoopLatch());
+
+  // Get the exit blocks for the current loop.
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  CurLoop->getExitBlocks(ExitBlocks);
+
+  // Verify that the block dominates each of the exit blocks of the loop.
+  for (BasicBlock *ExitBlock : ExitBlocks)
+    if (!DT->dominates(Inst.getParent(), ExitBlock))
+      if (!InstDominatesLatch ||
+          !CanProveNotTakenFirstIteration(ExitBlock, DT, CurLoop))
+        return false;
+
+  // As a degenerate case, if the loop is statically infinite then we haven't
+  // proven anything since there are no exit blocks.
+  if (ExitBlocks.empty())
+    return false;
+
+  // FIXME: In general, we have to prove that the loop isn't an infinite loop.
+  // See http://llvm.org/PR24078.  (The "ExitBlocks.empty()" check above is
+  // just a special case of this.)
+  return true;
+}
+
+
+namespace {
+  struct MustExecutePrinter : public FunctionPass {
+
+    static char ID; // Pass identification, replacement for typeid
+    MustExecutePrinter() : FunctionPass(ID) {
+      initializeMustExecutePrinterPass(*PassRegistry::getPassRegistry());
+    }
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.setPreservesAll();
+      AU.addRequired<DominatorTreeWrapperPass>();
+      AU.addRequired<LoopInfoWrapperPass>();
+    }
+    bool runOnFunction(Function &F) override;
+  };
+}
+
+char MustExecutePrinter::ID = 0;
+INITIALIZE_PASS_BEGIN(MustExecutePrinter, "print-mustexecute",
+                      "Instructions which execute on loop entry", false, true)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(MustExecutePrinter, "print-mustexecute",
+                    "Instructions which execute on loop entry", false, true)
+
+FunctionPass *llvm::createMustExecutePrinter() {
+  return new MustExecutePrinter();
+}
+
+static bool isMustExecuteIn(const Instruction &I, Loop *L, DominatorTree *DT) {
+  // TODO: Merge these two routines.  For the moment, we display the best
+  // result obtained by *either* implementation.  This is a bit unfair since no
+  // caller actually gets the full power at the moment.
+  LoopSafetyInfo LSI;
+  computeLoopSafetyInfo(&LSI, L);
+  return isGuaranteedToExecute(I, DT, L, &LSI) ||
+    isGuaranteedToExecuteForEveryIteration(&I, L);
+}
+
+namespace {
+/// An assembly annotator class to print must-execute information in
+/// comments.
+class MustExecuteAnnotatedWriter : public AssemblyAnnotationWriter { +  DenseMap<const Value*, SmallVector<Loop*, 4> > MustExec; + +public: +  MustExecuteAnnotatedWriter(const Function &F, +                             DominatorTree &DT, LoopInfo &LI) { +    for (auto &I: instructions(F)) { +      Loop *L = LI.getLoopFor(I.getParent()); +      while (L) { +        if (isMustExecuteIn(I, L, &DT)) { +          MustExec[&I].push_back(L); +        } +        L = L->getParentLoop(); +      }; +    } +  } +  MustExecuteAnnotatedWriter(const Module &M, +                             DominatorTree &DT, LoopInfo &LI) { +    for (auto &F : M) +    for (auto &I: instructions(F)) { +      Loop *L = LI.getLoopFor(I.getParent()); +      while (L) { +        if (isMustExecuteIn(I, L, &DT)) { +          MustExec[&I].push_back(L); +        } +        L = L->getParentLoop(); +      }; +    } +  } + + +  void printInfoComment(const Value &V, formatted_raw_ostream &OS) override { +    if (!MustExec.count(&V)) +      return; + +    const auto &Loops = MustExec.lookup(&V); +    const auto NumLoops = Loops.size(); +    if (NumLoops > 1) +      OS << " ; (mustexec in " << NumLoops << " loops: "; +    else +      OS << " ; (mustexec in: "; + +    bool first = true; +    for (const Loop *L : Loops) { +      if (!first) +        OS << ", "; +      first = false; +      OS << L->getHeader()->getName(); +    } +    OS << ")"; +  } +}; +} // namespace + +bool MustExecutePrinter::runOnFunction(Function &F) { +  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + +  MustExecuteAnnotatedWriter Writer(F, DT, LI); +  F.print(dbgs(), &Writer); + +  return false; +} diff --git a/contrib/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp new file mode 100644 index 000000000000..096ea661ecb6 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ObjCARCAliasAnalysis.cpp @@ -0,0 +1,162 @@ +//===- ObjCARCAliasAnalysis.cpp - ObjC ARC Optimization -------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines a simple ARC-aware AliasAnalysis using special knowledge +/// of Objective C to enhance other optimization passes which rely on the Alias +/// Analysis infrastructure. +/// +/// WARNING: This file knows about certain library functions. It recognizes them +/// by name, and hardwires knowledge of their semantics. +/// +/// WARNING: This file knows about how certain Objective-C library functions are +/// used. Naive LLVM IR transformations which would otherwise be +/// behavior-preserving may break these assumptions. +/// +/// TODO: Theoretically we could check for dependencies between objc_* calls +/// and FMRB_OnlyAccessesArgumentPointees calls or other well-behaved calls. 
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ObjCARCAliasAnalysis.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/PassAnalysisSupport.h" +#include "llvm/PassSupport.h" + +#define DEBUG_TYPE "objc-arc-aa" + +using namespace llvm; +using namespace llvm::objcarc; + +AliasResult ObjCARCAAResult::alias(const MemoryLocation &LocA, +                                   const MemoryLocation &LocB) { +  if (!EnableARCOpts) +    return AAResultBase::alias(LocA, LocB); + +  // First, strip off no-ops, including ObjC-specific no-ops, and try making a +  // precise alias query. +  const Value *SA = GetRCIdentityRoot(LocA.Ptr); +  const Value *SB = GetRCIdentityRoot(LocB.Ptr); +  AliasResult Result = +      AAResultBase::alias(MemoryLocation(SA, LocA.Size, LocA.AATags), +                          MemoryLocation(SB, LocB.Size, LocB.AATags)); +  if (Result != MayAlias) +    return Result; + +  // If that failed, climb to the underlying object, including climbing through +  // ObjC-specific no-ops, and try making an imprecise alias query. +  const Value *UA = GetUnderlyingObjCPtr(SA, DL); +  const Value *UB = GetUnderlyingObjCPtr(SB, DL); +  if (UA != SA || UB != SB) { +    Result = AAResultBase::alias(MemoryLocation(UA), MemoryLocation(UB)); +    // We can't use MustAlias or PartialAlias results here because +    // GetUnderlyingObjCPtr may return an offsetted pointer value. +    if (Result == NoAlias) +      return NoAlias; +  } + +  // If that failed, fail. We don't need to chain here, since that's covered +  // by the earlier precise query. +  return MayAlias; +} + +bool ObjCARCAAResult::pointsToConstantMemory(const MemoryLocation &Loc, +                                             bool OrLocal) { +  if (!EnableARCOpts) +    return AAResultBase::pointsToConstantMemory(Loc, OrLocal); + +  // First, strip off no-ops, including ObjC-specific no-ops, and try making +  // a precise alias query. +  const Value *S = GetRCIdentityRoot(Loc.Ptr); +  if (AAResultBase::pointsToConstantMemory( +          MemoryLocation(S, Loc.Size, Loc.AATags), OrLocal)) +    return true; + +  // If that failed, climb to the underlying object, including climbing through +  // ObjC-specific no-ops, and try making an imprecise alias query. +  const Value *U = GetUnderlyingObjCPtr(S, DL); +  if (U != S) +    return AAResultBase::pointsToConstantMemory(MemoryLocation(U), OrLocal); + +  // If that failed, fail. We don't need to chain here, since that's covered +  // by the earlier precise query. 
+  return false; +} + +FunctionModRefBehavior ObjCARCAAResult::getModRefBehavior(const Function *F) { +  if (!EnableARCOpts) +    return AAResultBase::getModRefBehavior(F); + +  switch (GetFunctionClass(F)) { +  case ARCInstKind::NoopCast: +    return FMRB_DoesNotAccessMemory; +  default: +    break; +  } + +  return AAResultBase::getModRefBehavior(F); +} + +ModRefInfo ObjCARCAAResult::getModRefInfo(ImmutableCallSite CS, +                                          const MemoryLocation &Loc) { +  if (!EnableARCOpts) +    return AAResultBase::getModRefInfo(CS, Loc); + +  switch (GetBasicARCInstKind(CS.getInstruction())) { +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::NoopCast: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +    // These functions don't access any memory visible to the compiler. +    // Note that this doesn't include objc_retainBlock, because it updates +    // pointers when it copies block data. +    return ModRefInfo::NoModRef; +  default: +    break; +  } + +  return AAResultBase::getModRefInfo(CS, Loc); +} + +ObjCARCAAResult ObjCARCAA::run(Function &F, FunctionAnalysisManager &AM) { +  return ObjCARCAAResult(F.getParent()->getDataLayout()); +} + +char ObjCARCAAWrapperPass::ID = 0; +INITIALIZE_PASS(ObjCARCAAWrapperPass, "objc-arc-aa", +                "ObjC-ARC-Based Alias Analysis", false, true) + +ImmutablePass *llvm::createObjCARCAAWrapperPass() { +  return new ObjCARCAAWrapperPass(); +} + +ObjCARCAAWrapperPass::ObjCARCAAWrapperPass() : ImmutablePass(ID) { +  initializeObjCARCAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ObjCARCAAWrapperPass::doInitialization(Module &M) { +  Result.reset(new ObjCARCAAResult(M.getDataLayout())); +  return false; +} + +bool ObjCARCAAWrapperPass::doFinalization(Module &M) { +  Result.reset(); +  return false; +} + +void ObjCARCAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +} diff --git a/contrib/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp b/contrib/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp new file mode 100644 index 000000000000..d6db6386c38b --- /dev/null +++ b/contrib/llvm/lib/Analysis/ObjCARCAnalysisUtils.cpp @@ -0,0 +1,26 @@ +//===- ObjCARCAnalysisUtils.cpp -------------------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements common infrastructure for libLLVMObjCARCOpts.a, which +// implements several scalar transformations over the LLVM intermediate +// representation, including the C bindings for that library. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/Support/CommandLine.h" + +using namespace llvm; +using namespace llvm::objcarc; + +/// A handy option to enable/disable all ARC Optimizations. 
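+///
+/// For example (a usage sketch, assuming the standard opt driver), the
+/// ARC-aware logic can be switched off with:
+///   opt -enable-objc-arc-opts=false -S input.ll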
+bool llvm::objcarc::EnableARCOpts; +static cl::opt<bool, true> EnableARCOptimizations( +    "enable-objc-arc-opts", cl::desc("enable/disable all ARC Optimizations"), +    cl::location(EnableARCOpts), cl::init(true), cl::Hidden); diff --git a/contrib/llvm/lib/Analysis/ObjCARCInstKind.cpp b/contrib/llvm/lib/Analysis/ObjCARCInstKind.cpp new file mode 100644 index 000000000000..f268e2a9abdd --- /dev/null +++ b/contrib/llvm/lib/Analysis/ObjCARCInstKind.cpp @@ -0,0 +1,695 @@ +//===- ARCInstKind.cpp - ObjC ARC Optimization ----------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines several utility functions used by various ARC +/// optimizations which are IMHO too big to be in a header file. +/// +/// WARNING: This file knows about certain library functions. It recognizes them +/// by name, and hardwires knowledge of their semantics. +/// +/// WARNING: This file knows about how certain Objective-C library functions are +/// used. Naive LLVM IR transformations which would otherwise be +/// behavior-preserving may break these assumptions. +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ObjCARCInstKind.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/IR/Intrinsics.h" + +using namespace llvm; +using namespace llvm::objcarc; + +raw_ostream &llvm::objcarc::operator<<(raw_ostream &OS, +                                       const ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::Retain: +    return OS << "ARCInstKind::Retain"; +  case ARCInstKind::RetainRV: +    return OS << "ARCInstKind::RetainRV"; +  case ARCInstKind::ClaimRV: +    return OS << "ARCInstKind::ClaimRV"; +  case ARCInstKind::RetainBlock: +    return OS << "ARCInstKind::RetainBlock"; +  case ARCInstKind::Release: +    return OS << "ARCInstKind::Release"; +  case ARCInstKind::Autorelease: +    return OS << "ARCInstKind::Autorelease"; +  case ARCInstKind::AutoreleaseRV: +    return OS << "ARCInstKind::AutoreleaseRV"; +  case ARCInstKind::AutoreleasepoolPush: +    return OS << "ARCInstKind::AutoreleasepoolPush"; +  case ARCInstKind::AutoreleasepoolPop: +    return OS << "ARCInstKind::AutoreleasepoolPop"; +  case ARCInstKind::NoopCast: +    return OS << "ARCInstKind::NoopCast"; +  case ARCInstKind::FusedRetainAutorelease: +    return OS << "ARCInstKind::FusedRetainAutorelease"; +  case ARCInstKind::FusedRetainAutoreleaseRV: +    return OS << "ARCInstKind::FusedRetainAutoreleaseRV"; +  case ARCInstKind::LoadWeakRetained: +    return OS << "ARCInstKind::LoadWeakRetained"; +  case ARCInstKind::StoreWeak: +    return OS << "ARCInstKind::StoreWeak"; +  case ARCInstKind::InitWeak: +    return OS << "ARCInstKind::InitWeak"; +  case ARCInstKind::LoadWeak: +    return OS << "ARCInstKind::LoadWeak"; +  case ARCInstKind::MoveWeak: +    return OS << "ARCInstKind::MoveWeak"; +  case ARCInstKind::CopyWeak: +    return OS << "ARCInstKind::CopyWeak"; +  case ARCInstKind::DestroyWeak: +    return OS << "ARCInstKind::DestroyWeak"; +  case ARCInstKind::StoreStrong: +    return OS << "ARCInstKind::StoreStrong"; +  case ARCInstKind::CallOrUser: +    return OS << "ARCInstKind::CallOrUser"; +  case ARCInstKind::Call: +    return OS << "ARCInstKind::Call"; +  case 
ARCInstKind::User: +    return OS << "ARCInstKind::User"; +  case ARCInstKind::IntrinsicUser: +    return OS << "ARCInstKind::IntrinsicUser"; +  case ARCInstKind::None: +    return OS << "ARCInstKind::None"; +  } +  llvm_unreachable("Unknown instruction class!"); +} + +ARCInstKind llvm::objcarc::GetFunctionClass(const Function *F) { +  Function::const_arg_iterator AI = F->arg_begin(), AE = F->arg_end(); + +  // No (mandatory) arguments. +  if (AI == AE) +    return StringSwitch<ARCInstKind>(F->getName()) +        .Case("objc_autoreleasePoolPush", ARCInstKind::AutoreleasepoolPush) +        .Case("clang.arc.use", ARCInstKind::IntrinsicUser) +        .Default(ARCInstKind::CallOrUser); + +  // One argument. +  const Argument *A0 = &*AI++; +  if (AI == AE) { +    // Argument is a pointer. +    PointerType *PTy = dyn_cast<PointerType>(A0->getType()); +    if (!PTy) +      return ARCInstKind::CallOrUser; + +    Type *ETy = PTy->getElementType(); +    // Argument is i8*. +    if (ETy->isIntegerTy(8)) +      return StringSwitch<ARCInstKind>(F->getName()) +          .Case("objc_retain", ARCInstKind::Retain) +          .Case("objc_retainAutoreleasedReturnValue", ARCInstKind::RetainRV) +          .Case("objc_unsafeClaimAutoreleasedReturnValue", ARCInstKind::ClaimRV) +          .Case("objc_retainBlock", ARCInstKind::RetainBlock) +          .Case("objc_release", ARCInstKind::Release) +          .Case("objc_autorelease", ARCInstKind::Autorelease) +          .Case("objc_autoreleaseReturnValue", ARCInstKind::AutoreleaseRV) +          .Case("objc_autoreleasePoolPop", ARCInstKind::AutoreleasepoolPop) +          .Case("objc_retainedObject", ARCInstKind::NoopCast) +          .Case("objc_unretainedObject", ARCInstKind::NoopCast) +          .Case("objc_unretainedPointer", ARCInstKind::NoopCast) +          .Case("objc_retain_autorelease", ARCInstKind::FusedRetainAutorelease) +          .Case("objc_retainAutorelease", ARCInstKind::FusedRetainAutorelease) +          .Case("objc_retainAutoreleaseReturnValue", +                ARCInstKind::FusedRetainAutoreleaseRV) +          .Case("objc_sync_enter", ARCInstKind::User) +          .Case("objc_sync_exit", ARCInstKind::User) +          .Default(ARCInstKind::CallOrUser); + +    // Argument is i8** +    if (PointerType *Pte = dyn_cast<PointerType>(ETy)) +      if (Pte->getElementType()->isIntegerTy(8)) +        return StringSwitch<ARCInstKind>(F->getName()) +            .Case("objc_loadWeakRetained", ARCInstKind::LoadWeakRetained) +            .Case("objc_loadWeak", ARCInstKind::LoadWeak) +            .Case("objc_destroyWeak", ARCInstKind::DestroyWeak) +            .Default(ARCInstKind::CallOrUser); + +    // Anything else with one argument. +    return ARCInstKind::CallOrUser; +  } + +  // Two arguments, first is i8**. 
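+  // (For illustration, matching the cases handled below: objc_storeWeak has
+  //  the prototype i8* objc_storeWeak(i8** location, i8* value), while
+  //  objc_moveWeak takes two i8** arguments.)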
+  const Argument *A1 = &*AI++;
+  if (AI == AE)
+    if (PointerType *PTy = dyn_cast<PointerType>(A0->getType()))
+      if (PointerType *Pte = dyn_cast<PointerType>(PTy->getElementType()))
+        if (Pte->getElementType()->isIntegerTy(8))
+          if (PointerType *PTy1 = dyn_cast<PointerType>(A1->getType())) {
+            Type *ETy1 = PTy1->getElementType();
+            // Second argument is i8*
+            if (ETy1->isIntegerTy(8))
+              return StringSwitch<ARCInstKind>(F->getName())
+                  .Case("objc_storeWeak", ARCInstKind::StoreWeak)
+                  .Case("objc_initWeak", ARCInstKind::InitWeak)
+                  .Case("objc_storeStrong", ARCInstKind::StoreStrong)
+                  .Default(ARCInstKind::CallOrUser);
+            // Second argument is i8**.
+            if (PointerType *Pte1 = dyn_cast<PointerType>(ETy1))
+              if (Pte1->getElementType()->isIntegerTy(8))
+                return StringSwitch<ARCInstKind>(F->getName())
+                    .Case("objc_moveWeak", ARCInstKind::MoveWeak)
+                    .Case("objc_copyWeak", ARCInstKind::CopyWeak)
+                    // Ignore annotation calls. This is important to stop the
+                    // optimizer from treating annotations as uses, which would
+                    // make the state of the pointers they are attempting to
+                    // elucidate incorrect.
+                    .Case("llvm.arc.annotation.topdown.bbstart",
+                          ARCInstKind::None)
+                    .Case("llvm.arc.annotation.topdown.bbend",
+                          ARCInstKind::None)
+                    .Case("llvm.arc.annotation.bottomup.bbstart",
+                          ARCInstKind::None)
+                    .Case("llvm.arc.annotation.bottomup.bbend",
+                          ARCInstKind::None)
+                    .Default(ARCInstKind::CallOrUser);
+          }
+
+  // Anything else.
+  return ARCInstKind::CallOrUser;
+}
+
+// A whitelist of intrinsics that we know do not use objc pointers or decrement
+// ref counts.
+static bool isInertIntrinsic(unsigned ID) {
+  // TODO: Make this into a covered switch.
+  switch (ID) {
+  case Intrinsic::returnaddress:
+  case Intrinsic::addressofreturnaddress:
+  case Intrinsic::frameaddress:
+  case Intrinsic::stacksave:
+  case Intrinsic::stackrestore:
+  case Intrinsic::vastart:
+  case Intrinsic::vacopy:
+  case Intrinsic::vaend:
+  case Intrinsic::objectsize:
+  case Intrinsic::prefetch:
+  case Intrinsic::stackprotector:
+  case Intrinsic::eh_return_i32:
+  case Intrinsic::eh_return_i64:
+  case Intrinsic::eh_typeid_for:
+  case Intrinsic::eh_dwarf_cfa:
+  case Intrinsic::eh_sjlj_lsda:
+  case Intrinsic::eh_sjlj_functioncontext:
+  case Intrinsic::init_trampoline:
+  case Intrinsic::adjust_trampoline:
+  case Intrinsic::lifetime_start:
+  case Intrinsic::lifetime_end:
+  case Intrinsic::invariant_start:
+  case Intrinsic::invariant_end:
+  // Don't let dbg info affect our results.
+  case Intrinsic::dbg_declare:
+  case Intrinsic::dbg_value:
+  case Intrinsic::dbg_label:
+    // Short cut: Some intrinsics obviously don't use ObjC pointers.
+    return true;
+  default:
+    return false;
+  }
+}
+
+// A whitelist of intrinsics that we know may use objc pointers but never
+// decrement ref counts.
+static bool isUseOnlyIntrinsic(unsigned ID) {
+  // We are conservative and even though intrinsics are unlikely to touch
+  // reference counts, we whitelist them for safety.
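+  // For instance, llvm.memcpy below reads and writes through its pointer
+  // arguments but never adjusts their retain counts, so it is classified as a
+  // "use" only.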
+  // +  // TODO: Expand this into a covered switch. There is a lot more here. +  switch (ID) { +  case Intrinsic::memcpy: +  case Intrinsic::memmove: +  case Intrinsic::memset: +    return true; +  default: +    return false; +  } +} + +/// Determine what kind of construct V is. +ARCInstKind llvm::objcarc::GetARCInstKind(const Value *V) { +  if (const Instruction *I = dyn_cast<Instruction>(V)) { +    // Any instruction other than bitcast and gep with a pointer operand have a +    // use of an objc pointer. Bitcasts, GEPs, Selects, PHIs transfer a pointer +    // to a subsequent use, rather than using it themselves, in this sense. +    // As a short cut, several other opcodes are known to have no pointer +    // operands of interest. And ret is never followed by a release, so it's +    // not interesting to examine. +    switch (I->getOpcode()) { +    case Instruction::Call: { +      const CallInst *CI = cast<CallInst>(I); +      // See if we have a function that we know something about. +      if (const Function *F = CI->getCalledFunction()) { +        ARCInstKind Class = GetFunctionClass(F); +        if (Class != ARCInstKind::CallOrUser) +          return Class; +        Intrinsic::ID ID = F->getIntrinsicID(); +        if (isInertIntrinsic(ID)) +          return ARCInstKind::None; +        if (isUseOnlyIntrinsic(ID)) +          return ARCInstKind::User; +      } + +      // Otherwise, be conservative. +      return GetCallSiteClass(CI); +    } +    case Instruction::Invoke: +      // Otherwise, be conservative. +      return GetCallSiteClass(cast<InvokeInst>(I)); +    case Instruction::BitCast: +    case Instruction::GetElementPtr: +    case Instruction::Select: +    case Instruction::PHI: +    case Instruction::Ret: +    case Instruction::Br: +    case Instruction::Switch: +    case Instruction::IndirectBr: +    case Instruction::Alloca: +    case Instruction::VAArg: +    case Instruction::Add: +    case Instruction::FAdd: +    case Instruction::Sub: +    case Instruction::FSub: +    case Instruction::Mul: +    case Instruction::FMul: +    case Instruction::SDiv: +    case Instruction::UDiv: +    case Instruction::FDiv: +    case Instruction::SRem: +    case Instruction::URem: +    case Instruction::FRem: +    case Instruction::Shl: +    case Instruction::LShr: +    case Instruction::AShr: +    case Instruction::And: +    case Instruction::Or: +    case Instruction::Xor: +    case Instruction::SExt: +    case Instruction::ZExt: +    case Instruction::Trunc: +    case Instruction::IntToPtr: +    case Instruction::FCmp: +    case Instruction::FPTrunc: +    case Instruction::FPExt: +    case Instruction::FPToUI: +    case Instruction::FPToSI: +    case Instruction::UIToFP: +    case Instruction::SIToFP: +    case Instruction::InsertElement: +    case Instruction::ExtractElement: +    case Instruction::ShuffleVector: +    case Instruction::ExtractValue: +      break; +    case Instruction::ICmp: +      // Comparing a pointer with null, or any other constant, isn't an +      // interesting use, because we don't care what the pointer points to, or +      // about the values of any other dynamic reference-counted pointers. +      if (IsPotentialRetainableObjPtr(I->getOperand(1))) +        return ARCInstKind::User; +      break; +    default: +      // For anything else, check all the operands. 
+      // Note that this includes both operands of a Store: while the first +      // operand isn't actually being dereferenced, it is being stored to +      // memory where we can no longer track who might read it and dereference +      // it, so we have to consider it potentially used. +      for (User::const_op_iterator OI = I->op_begin(), OE = I->op_end(); +           OI != OE; ++OI) +        if (IsPotentialRetainableObjPtr(*OI)) +          return ARCInstKind::User; +    } +  } + +  // Otherwise, it's totally inert for ARC purposes. +  return ARCInstKind::None; +} + +/// Test if the given class is a kind of user. +bool llvm::objcarc::IsUser(ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::User: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::IntrinsicUser: +    return true; +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::RetainBlock: +  case ARCInstKind::Release: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::NoopCast: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::Call: +  case ARCInstKind::None: +  case ARCInstKind::ClaimRV: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class is objc_retain or equivalent. +bool llvm::objcarc::IsRetain(ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +    return true; +  // I believe we treat retain block as not a retain since it can copy its +  // block. +  case ARCInstKind::RetainBlock: +  case ARCInstKind::Release: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::NoopCast: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::User: +  case ARCInstKind::None: +  case ARCInstKind::ClaimRV: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class is objc_autorelease or equivalent. 
+bool llvm::objcarc::IsAutorelease(ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +    return true; +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::ClaimRV: +  case ARCInstKind::RetainBlock: +  case ARCInstKind::Release: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::NoopCast: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::User: +  case ARCInstKind::None: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which return their +/// argument verbatim. +bool llvm::objcarc::IsForwarding(ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::ClaimRV: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::NoopCast: +    return true; +  case ARCInstKind::RetainBlock: +  case ARCInstKind::Release: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::User: +  case ARCInstKind::None: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which do nothing if +/// passed a null pointer. +bool llvm::objcarc::IsNoopOnNull(ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::ClaimRV: +  case ARCInstKind::Release: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::RetainBlock: +    return true; +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::User: +  case ARCInstKind::None: +  case ARCInstKind::NoopCast: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +/// Test if the given class represents instructions which are always safe +/// to mark with the "tail" keyword. +bool llvm::objcarc::IsAlwaysTail(ARCInstKind Class) { +  // ARCInstKind::RetainBlock may be given a stack argument. 
+  switch (Class) {
+  case ARCInstKind::Retain:
+  case ARCInstKind::RetainRV:
+  case ARCInstKind::ClaimRV:
+  case ARCInstKind::AutoreleaseRV:
+    return true;
+  case ARCInstKind::Release:
+  case ARCInstKind::Autorelease:
+  case ARCInstKind::RetainBlock:
+  case ARCInstKind::AutoreleasepoolPush:
+  case ARCInstKind::AutoreleasepoolPop:
+  case ARCInstKind::FusedRetainAutorelease:
+  case ARCInstKind::FusedRetainAutoreleaseRV:
+  case ARCInstKind::LoadWeakRetained:
+  case ARCInstKind::StoreWeak:
+  case ARCInstKind::InitWeak:
+  case ARCInstKind::LoadWeak:
+  case ARCInstKind::MoveWeak:
+  case ARCInstKind::CopyWeak:
+  case ARCInstKind::DestroyWeak:
+  case ARCInstKind::StoreStrong:
+  case ARCInstKind::IntrinsicUser:
+  case ARCInstKind::CallOrUser:
+  case ARCInstKind::Call:
+  case ARCInstKind::User:
+  case ARCInstKind::None:
+  case ARCInstKind::NoopCast:
+    return false;
+  }
+  llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are never safe
+/// to mark with the "tail" keyword.
+bool llvm::objcarc::IsNeverTail(ARCInstKind Class) {
+  /// It is never safe to tail call objc_autorelease: tail calling it enables
+  /// fast autoreleasing, which can cause the object to be reclaimed from the
+  /// autorelease pool early, violating the semantics of __autoreleasing
+  /// types in ARC.
+  switch (Class) {
+  case ARCInstKind::Autorelease:
+    return true;
+  case ARCInstKind::Retain:
+  case ARCInstKind::RetainRV:
+  case ARCInstKind::ClaimRV:
+  case ARCInstKind::AutoreleaseRV:
+  case ARCInstKind::Release:
+  case ARCInstKind::RetainBlock:
+  case ARCInstKind::AutoreleasepoolPush:
+  case ARCInstKind::AutoreleasepoolPop:
+  case ARCInstKind::FusedRetainAutorelease:
+  case ARCInstKind::FusedRetainAutoreleaseRV:
+  case ARCInstKind::LoadWeakRetained:
+  case ARCInstKind::StoreWeak:
+  case ARCInstKind::InitWeak:
+  case ARCInstKind::LoadWeak:
+  case ARCInstKind::MoveWeak:
+  case ARCInstKind::CopyWeak:
+  case ARCInstKind::DestroyWeak:
+  case ARCInstKind::StoreStrong:
+  case ARCInstKind::IntrinsicUser:
+  case ARCInstKind::CallOrUser:
+  case ARCInstKind::Call:
+  case ARCInstKind::User:
+  case ARCInstKind::None:
+  case ARCInstKind::NoopCast:
+    return false;
+  }
+  llvm_unreachable("covered switch isn't covered?");
+}
+
+/// Test if the given class represents instructions which are always safe
+/// to mark with the nounwind attribute.
+bool llvm::objcarc::IsNoThrow(ARCInstKind Class) {
+  // objc_retainBlock is not nounwind because it calls user copy constructors
+  // which could theoretically throw.
+  switch (Class) { +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::ClaimRV: +  case ARCInstKind::Release: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +    return true; +  case ARCInstKind::RetainBlock: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::User: +  case ARCInstKind::None: +  case ARCInstKind::NoopCast: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +/// Test whether the given instruction can autorelease any pointer or cause an +/// autoreleasepool pop. +/// +/// This means that it *could* interrupt the RV optimization. +bool llvm::objcarc::CanInterruptRV(ARCInstKind Class) { +  switch (Class) { +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +    return true; +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::ClaimRV: +  case ARCInstKind::Release: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::RetainBlock: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::User: +  case ARCInstKind::None: +  case ARCInstKind::NoopCast: +    return false; +  } +  llvm_unreachable("covered switch isn't covered?"); +} + +bool llvm::objcarc::CanDecrementRefCount(ARCInstKind Kind) { +  switch (Kind) { +  case ARCInstKind::Retain: +  case ARCInstKind::RetainRV: +  case ARCInstKind::Autorelease: +  case ARCInstKind::AutoreleaseRV: +  case ARCInstKind::NoopCast: +  case ARCInstKind::FusedRetainAutorelease: +  case ARCInstKind::FusedRetainAutoreleaseRV: +  case ARCInstKind::IntrinsicUser: +  case ARCInstKind::User: +  case ARCInstKind::None: +    return false; + +  // The cases below are conservative. + +  // RetainBlock can result in user defined copy constructors being called +  // implying releases may occur. 
+  case ARCInstKind::RetainBlock: +  case ARCInstKind::Release: +  case ARCInstKind::AutoreleasepoolPush: +  case ARCInstKind::AutoreleasepoolPop: +  case ARCInstKind::LoadWeakRetained: +  case ARCInstKind::StoreWeak: +  case ARCInstKind::InitWeak: +  case ARCInstKind::LoadWeak: +  case ARCInstKind::MoveWeak: +  case ARCInstKind::CopyWeak: +  case ARCInstKind::DestroyWeak: +  case ARCInstKind::StoreStrong: +  case ARCInstKind::CallOrUser: +  case ARCInstKind::Call: +  case ARCInstKind::ClaimRV: +    return true; +  } + +  llvm_unreachable("covered switch isn't covered?"); +} diff --git a/contrib/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp b/contrib/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp new file mode 100644 index 000000000000..8ece0a2a3ed3 --- /dev/null +++ b/contrib/llvm/lib/Analysis/OptimizationRemarkEmitter.cpp @@ -0,0 +1,134 @@ +//===- OptimizationRemarkEmitter.cpp - Optimization Diagnostic --*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Optimization diagnostic interfaces.  It's packaged as an analysis pass so +// that by using this service passes become dependent on BFI as well.  BFI is +// used to compute the "hotness" of the diagnostic message. +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/LLVMContext.h" + +using namespace llvm; + +OptimizationRemarkEmitter::OptimizationRemarkEmitter(const Function *F) +    : F(F), BFI(nullptr) { +  if (!F->getContext().getDiagnosticsHotnessRequested()) +    return; + +  // First create a dominator tree. +  DominatorTree DT; +  DT.recalculate(*const_cast<Function *>(F)); + +  // Generate LoopInfo from it. +  LoopInfo LI; +  LI.analyze(DT); + +  // Then compute BranchProbabilityInfo. +  BranchProbabilityInfo BPI; +  BPI.calculate(*F, LI); + +  // Finally compute BFI. +  OwnedBFI = llvm::make_unique<BlockFrequencyInfo>(*F, BPI, LI); +  BFI = OwnedBFI.get(); +} + +bool OptimizationRemarkEmitter::invalidate( +    Function &F, const PreservedAnalyses &PA, +    FunctionAnalysisManager::Invalidator &Inv) { +  // This analysis has no state and so can be trivially preserved but it needs +  // a fresh view of BFI if it was constructed with one. +  if (BFI && Inv.invalidate<BlockFrequencyAnalysis>(F, PA)) +    return true; + +  // Otherwise this analysis result remains valid. +  return false; +} + +Optional<uint64_t> OptimizationRemarkEmitter::computeHotness(const Value *V) { +  if (!BFI) +    return None; + +  return BFI->getBlockProfileCount(cast<BasicBlock>(V)); +} + +void OptimizationRemarkEmitter::computeHotness( +    DiagnosticInfoIROptimization &OptDiag) { +  const Value *V = OptDiag.getCodeRegion(); +  if (V) +    OptDiag.setHotness(computeHotness(V)); +} + +void OptimizationRemarkEmitter::emit( +    DiagnosticInfoOptimizationBase &OptDiagBase) { +  auto &OptDiag = cast<DiagnosticInfoIROptimization>(OptDiagBase); +  computeHotness(OptDiag); + +  // Only emit it if its hotness meets the threshold. 
+  if (OptDiag.getHotness().getValueOr(0) < +      F->getContext().getDiagnosticsHotnessThreshold()) { +    return; +  } + +  F->getContext().diagnose(OptDiag); +} + +OptimizationRemarkEmitterWrapperPass::OptimizationRemarkEmitterWrapperPass() +    : FunctionPass(ID) { +  initializeOptimizationRemarkEmitterWrapperPassPass( +      *PassRegistry::getPassRegistry()); +} + +bool OptimizationRemarkEmitterWrapperPass::runOnFunction(Function &Fn) { +  BlockFrequencyInfo *BFI; + +  if (Fn.getContext().getDiagnosticsHotnessRequested()) +    BFI = &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI(); +  else +    BFI = nullptr; + +  ORE = llvm::make_unique<OptimizationRemarkEmitter>(&Fn, BFI); +  return false; +} + +void OptimizationRemarkEmitterWrapperPass::getAnalysisUsage( +    AnalysisUsage &AU) const { +  LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); +  AU.setPreservesAll(); +} + +AnalysisKey OptimizationRemarkEmitterAnalysis::Key; + +OptimizationRemarkEmitter +OptimizationRemarkEmitterAnalysis::run(Function &F, +                                       FunctionAnalysisManager &AM) { +  BlockFrequencyInfo *BFI; + +  if (F.getContext().getDiagnosticsHotnessRequested()) +    BFI = &AM.getResult<BlockFrequencyAnalysis>(F); +  else +    BFI = nullptr; + +  return OptimizationRemarkEmitter(&F, BFI); +} + +char OptimizationRemarkEmitterWrapperPass::ID = 0; +static const char ore_name[] = "Optimization Remark Emitter"; +#define ORE_NAME "opt-remark-emitter" + +INITIALIZE_PASS_BEGIN(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name, +                      false, true) +INITIALIZE_PASS_DEPENDENCY(LazyBFIPass) +INITIALIZE_PASS_END(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name, +                    false, true) diff --git a/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp b/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp new file mode 100644 index 000000000000..6c47651eae9e --- /dev/null +++ b/contrib/llvm/lib/Analysis/OrderedBasicBlock.cpp @@ -0,0 +1,85 @@ +//===- OrderedBasicBlock.cpp --------------------------------- -*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the OrderedBasicBlock class. OrderedBasicBlock +// maintains an interface where clients can query if one instruction comes +// before another in a BasicBlock. Since BasicBlock currently lacks a reliable +// way to query relative position between instructions one can use +// OrderedBasicBlock to do such queries. OrderedBasicBlock is lazily built on a +// source BasicBlock and maintains an internal Instruction -> Position map. A +// OrderedBasicBlock instance should be discarded whenever the source +// BasicBlock changes. +// +// It's currently used by the CaptureTracker in order to find relative +// positions of a pair of instructions inside a BasicBlock. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/OrderedBasicBlock.h" +#include "llvm/IR/Instruction.h" +using namespace llvm; + +OrderedBasicBlock::OrderedBasicBlock(const BasicBlock *BasicB) +    : NextInstPos(0), BB(BasicB) { +  LastInstFound = BB->end(); +} + +/// Given no cached results, find if \p A comes before \p B in \p BB. +/// Cache and number out instruction while walking \p BB. 
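+///
+/// A minimal usage sketch (hypothetical caller, not from this file):
+///   OrderedBasicBlock OBB(&BB);
+///   bool AFirst = OBB.dominates(InstA, InstB); // true if InstA precedes InstB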
+bool OrderedBasicBlock::comesBefore(const Instruction *A, +                                    const Instruction *B) { +  const Instruction *Inst = nullptr; +  assert(!(LastInstFound == BB->end() && NextInstPos != 0) && +         "Instruction supposed to be in NumberedInsts"); + +  // Start the search with the instruction found in the last lookup round. +  auto II = BB->begin(); +  auto IE = BB->end(); +  if (LastInstFound != IE) +    II = std::next(LastInstFound); + +  // Number all instructions up to the point where we find 'A' or 'B'. +  for (; II != IE; ++II) { +    Inst = cast<Instruction>(II); +    NumberedInsts[Inst] = NextInstPos++; +    if (Inst == A || Inst == B) +      break; +  } + +  assert(II != IE && "Instruction not found?"); +  assert((Inst == A || Inst == B) && "Should find A or B"); +  LastInstFound = II; +  return Inst != B; +} + +/// Find out whether \p A dominates \p B, meaning whether \p A +/// comes before \p B in \p BB. This is a simplification that considers +/// cached instruction positions and ignores other basic blocks, being +/// only relevant to compare relative instructions positions inside \p BB. +bool OrderedBasicBlock::dominates(const Instruction *A, const Instruction *B) { +  assert(A->getParent() == B->getParent() && +         "Instructions must be in the same basic block!"); + +  // First we lookup the instructions. If they don't exist, lookup will give us +  // back ::end(). If they both exist, we compare the numbers. Otherwise, if NA +  // exists and NB doesn't, it means NA must come before NB because we would +  // have numbered NB as well if it didn't. The same is true for NB. If it +  // exists, but NA does not, NA must come after it. If neither exist, we need +  // to number the block and cache the results (by calling comesBefore). +  auto NAI = NumberedInsts.find(A); +  auto NBI = NumberedInsts.find(B); +  if (NAI != NumberedInsts.end() && NBI != NumberedInsts.end()) +    return NAI->second < NBI->second; +  if (NAI != NumberedInsts.end()) +    return true; +  if (NBI != NumberedInsts.end()) +    return false; + +  return comesBefore(A, B); +} diff --git a/contrib/llvm/lib/Analysis/PHITransAddr.cpp b/contrib/llvm/lib/Analysis/PHITransAddr.cpp new file mode 100644 index 000000000000..858f08f6537a --- /dev/null +++ b/contrib/llvm/lib/Analysis/PHITransAddr.cpp @@ -0,0 +1,440 @@ +//===- PHITransAddr.cpp - PHI Translation for Addresses -------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the PHITransAddr class. 
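+// As a rough illustration (the IR below is made up): given
+//   %phi  = phi i8* [ %a, %pred1 ], [ %b, %pred2 ]
+//   %addr = getelementptr i8, i8* %phi, i64 4
+// PHI-translating %addr into %pred1 produces the address
+// "getelementptr i8, i8* %a, i64 4", either by locating an existing
+// instruction that computes it or, with insertion, by materializing one at
+// the end of %pred1.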
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +static bool CanPHITrans(Instruction *Inst) { +  if (isa<PHINode>(Inst) || +      isa<GetElementPtrInst>(Inst)) +    return true; + +  if (isa<CastInst>(Inst) && +      isSafeToSpeculativelyExecute(Inst)) +    return true; + +  if (Inst->getOpcode() == Instruction::Add && +      isa<ConstantInt>(Inst->getOperand(1))) +    return true; + +  //   cerr << "MEMDEP: Could not PHI translate: " << *Pointer; +  //   if (isa<BitCastInst>(PtrInst) || isa<GetElementPtrInst>(PtrInst)) +  //     cerr << "OP:\t\t\t\t" << *PtrInst->getOperand(0); +  return false; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void PHITransAddr::dump() const { +  if (!Addr) { +    dbgs() << "PHITransAddr: null\n"; +    return; +  } +  dbgs() << "PHITransAddr: " << *Addr << "\n"; +  for (unsigned i = 0, e = InstInputs.size(); i != e; ++i) +    dbgs() << "  Input #" << i << " is " << *InstInputs[i] << "\n"; +} +#endif + + +static bool VerifySubExpr(Value *Expr, +                          SmallVectorImpl<Instruction*> &InstInputs) { +  // If this is a non-instruction value, there is nothing to do. +  Instruction *I = dyn_cast<Instruction>(Expr); +  if (!I) return true; + +  // If it's an instruction, it is either in Tmp or its operands recursively +  // are. +  SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I); +  if (Entry != InstInputs.end()) { +    InstInputs.erase(Entry); +    return true; +  } + +  // If it isn't in the InstInputs list it is a subexpr incorporated into the +  // address.  Sanity check that it is phi translatable. +  if (!CanPHITrans(I)) { +    errs() << "Instruction in PHITransAddr is not phi-translatable:\n"; +    errs() << *I << '\n'; +    llvm_unreachable("Either something is missing from InstInputs or " +                     "CanPHITrans is wrong."); +  } + +  // Validate the operands of the instruction. +  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) +    if (!VerifySubExpr(I->getOperand(i), InstInputs)) +      return false; + +  return true; +} + +/// Verify - Check internal consistency of this data structure.  If the +/// structure is valid, it returns true.  If invalid, it prints errors and +/// returns false. +bool PHITransAddr::Verify() const { +  if (!Addr) return true; + +  SmallVector<Instruction*, 8> Tmp(InstInputs.begin(), InstInputs.end()); + +  if (!VerifySubExpr(Addr, Tmp)) +    return false; + +  if (!Tmp.empty()) { +    errs() << "PHITransAddr contains extra instructions:\n"; +    for (unsigned i = 0, e = InstInputs.size(); i != e; ++i) +      errs() << "  InstInput #" << i << " is " << *InstInputs[i] << "\n"; +    llvm_unreachable("This is unexpected."); +  } + +  // a-ok. +  return true; +} + + +/// IsPotentiallyPHITranslatable - If this needs PHI translation, return true +/// if we have some hope of doing it.  This should be used as a filter to +/// avoid calling PHITranslateValue in hopeless situations. 
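+/// (For instance, a plain Argument or a global's address never needs
+/// translation, while a PHI, a GEP, or an add of a constant defined in the
+/// current block might; see CanPHITrans above.)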
+bool PHITransAddr::IsPotentiallyPHITranslatable() const {
+  // If the input value is not an instruction, or if it is not defined in CurBB,
+  // then we don't need to phi translate it.
+  Instruction *Inst = dyn_cast<Instruction>(Addr);
+  return !Inst || CanPHITrans(Inst);
+}
+
+
+static void RemoveInstInputs(Value *V,
+                             SmallVectorImpl<Instruction*> &InstInputs) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return;
+
+  // If the instruction is in the InstInputs list, remove it.
+  SmallVectorImpl<Instruction *>::iterator Entry = find(InstInputs, I);
+  if (Entry != InstInputs.end()) {
+    InstInputs.erase(Entry);
+    return;
+  }
+
+  assert(!isa<PHINode>(I) && "Error, removing something that isn't an input");
+
+  // Otherwise, it must have instruction inputs itself.  Zap them recursively.
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+    if (Instruction *Op = dyn_cast<Instruction>(I->getOperand(i)))
+      RemoveInstInputs(Op, InstInputs);
+  }
+}
+
+Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB,
+                                         BasicBlock *PredBB,
+                                         const DominatorTree *DT) {
+  // If this is a non-instruction value, it can't require PHI translation.
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst) return V;
+
+  // Determine whether 'Inst' is an input to our PHI translatable expression.
+  bool isInput = is_contained(InstInputs, Inst);
+
+  // Handle input instructions if needed.
+  if (isInput) {
+    if (Inst->getParent() != CurBB) {
+      // If it is an input defined in a different block, then it remains an
+      // input.
+      return Inst;
+    }
+
+    // If 'Inst' is defined in this block and is an input that needs to be phi
+    // translated, we need to incorporate the value into the expression or fail.
+
+    // In either case, the instruction itself isn't an input any longer.
+    InstInputs.erase(find(InstInputs, Inst));
+
+    // If this is a PHI, go ahead and translate it.
+    if (PHINode *PN = dyn_cast<PHINode>(Inst))
+      return AddAsInput(PN->getIncomingValueForBlock(PredBB));
+
+    // If this is a non-phi value, and it is analyzable, we can incorporate it
+    // into the expression by making all instruction operands be inputs.
+    if (!CanPHITrans(Inst))
+      return nullptr;
+
+    // All instruction operands are now inputs (and of course, they may also be
+    // defined in this block, so they may need to be phi translated themselves).
+    for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i)
+      if (Instruction *Op = dyn_cast<Instruction>(Inst->getOperand(i)))
+        InstInputs.push_back(Op);
+  }
+
+  // Ok, it must be an intermediate result (either because it started that way
+  // or because we just incorporated it into the expression).  See if its
+  // operands need to be phi translated, and if so, reconstruct it.
+
+  if (CastInst *Cast = dyn_cast<CastInst>(Inst)) {
+    if (!isSafeToSpeculativelyExecute(Cast)) return nullptr;
+    Value *PHIIn = PHITranslateSubExpr(Cast->getOperand(0), CurBB, PredBB, DT);
+    if (!PHIIn) return nullptr;
+    if (PHIIn == Cast->getOperand(0))
+      return Cast;
+
+    // Find an available version of this cast.
+
+    // Constants are trivial to find.
+    if (Constant *C = dyn_cast<Constant>(PHIIn)) +      return AddAsInput(ConstantExpr::getCast(Cast->getOpcode(), +                                              C, Cast->getType())); + +    // Otherwise we have to see if a casted version of the incoming pointer +    // is available.  If so, we can use it, otherwise we have to fail. +    for (User *U : PHIIn->users()) { +      if (CastInst *CastI = dyn_cast<CastInst>(U)) +        if (CastI->getOpcode() == Cast->getOpcode() && +            CastI->getType() == Cast->getType() && +            (!DT || DT->dominates(CastI->getParent(), PredBB))) +          return CastI; +    } +    return nullptr; +  } + +  // Handle getelementptr with at least one PHI translatable operand. +  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { +    SmallVector<Value*, 8> GEPOps; +    bool AnyChanged = false; +    for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) { +      Value *GEPOp = PHITranslateSubExpr(GEP->getOperand(i), CurBB, PredBB, DT); +      if (!GEPOp) return nullptr; + +      AnyChanged |= GEPOp != GEP->getOperand(i); +      GEPOps.push_back(GEPOp); +    } + +    if (!AnyChanged) +      return GEP; + +    // Simplify the GEP to handle 'gep x, 0' -> x etc. +    if (Value *V = SimplifyGEPInst(GEP->getSourceElementType(), +                                   GEPOps, {DL, TLI, DT, AC})) { +      for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) +        RemoveInstInputs(GEPOps[i], InstInputs); + +      return AddAsInput(V); +    } + +    // Scan to see if we have this GEP available. +    Value *APHIOp = GEPOps[0]; +    for (User *U : APHIOp->users()) { +      if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(U)) +        if (GEPI->getType() == GEP->getType() && +            GEPI->getNumOperands() == GEPOps.size() && +            GEPI->getParent()->getParent() == CurBB->getParent() && +            (!DT || DT->dominates(GEPI->getParent(), PredBB))) { +          if (std::equal(GEPOps.begin(), GEPOps.end(), GEPI->op_begin())) +            return GEPI; +        } +    } +    return nullptr; +  } + +  // Handle add with a constant RHS. +  if (Inst->getOpcode() == Instruction::Add && +      isa<ConstantInt>(Inst->getOperand(1))) { +    // PHI translate the LHS. +    Constant *RHS = cast<ConstantInt>(Inst->getOperand(1)); +    bool isNSW = cast<BinaryOperator>(Inst)->hasNoSignedWrap(); +    bool isNUW = cast<BinaryOperator>(Inst)->hasNoUnsignedWrap(); + +    Value *LHS = PHITranslateSubExpr(Inst->getOperand(0), CurBB, PredBB, DT); +    if (!LHS) return nullptr; + +    // If the PHI translated LHS is an add of a constant, fold the immediates. +    if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(LHS)) +      if (BOp->getOpcode() == Instruction::Add) +        if (ConstantInt *CI = dyn_cast<ConstantInt>(BOp->getOperand(1))) { +          LHS = BOp->getOperand(0); +          RHS = ConstantExpr::getAdd(RHS, CI); +          isNSW = isNUW = false; + +          // If the old 'LHS' was an input, add the new 'LHS' as an input. +          if (is_contained(InstInputs, BOp)) { +            RemoveInstInputs(BOp, InstInputs); +            AddAsInput(LHS); +          } +        } + +    // See if the add simplifies away. +    if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, {DL, TLI, DT, AC})) { +      // If we simplified the operands, the LHS is no longer an input, but Res +      // is. +      RemoveInstInputs(LHS, InstInputs); +      return AddAsInput(Res); +    } + +    // If we didn't modify the add, just return it. 
+    if (LHS == Inst->getOperand(0) && RHS == Inst->getOperand(1)) +      return Inst; + +    // Otherwise, see if we have this add available somewhere. +    for (User *U : LHS->users()) { +      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U)) +        if (BO->getOpcode() == Instruction::Add && +            BO->getOperand(0) == LHS && BO->getOperand(1) == RHS && +            BO->getParent()->getParent() == CurBB->getParent() && +            (!DT || DT->dominates(BO->getParent(), PredBB))) +          return BO; +    } + +    return nullptr; +  } + +  // Otherwise, we failed. +  return nullptr; +} + + +/// PHITranslateValue - PHI translate the current address up the CFG from +/// CurBB to Pred, updating our state to reflect any needed changes.  If +/// 'MustDominate' is true, the translated value must dominate +/// PredBB.  This returns true on failure and sets Addr to null. +bool PHITransAddr::PHITranslateValue(BasicBlock *CurBB, BasicBlock *PredBB, +                                     const DominatorTree *DT, +                                     bool MustDominate) { +  assert(DT || !MustDominate); +  assert(Verify() && "Invalid PHITransAddr!"); +  if (DT && DT->isReachableFromEntry(PredBB)) +    Addr = +        PHITranslateSubExpr(Addr, CurBB, PredBB, MustDominate ? DT : nullptr); +  else +    Addr = nullptr; +  assert(Verify() && "Invalid PHITransAddr!"); + +  if (MustDominate) +    // Make sure the value is live in the predecessor. +    if (Instruction *Inst = dyn_cast_or_null<Instruction>(Addr)) +      if (!DT->dominates(Inst->getParent(), PredBB)) +        Addr = nullptr; + +  return Addr == nullptr; +} + +/// PHITranslateWithInsertion - PHI translate this value into the specified +/// predecessor block, inserting a computation of the value if it is +/// unavailable. +/// +/// All newly created instructions are added to the NewInsts list.  This +/// returns null on failure. +/// +Value *PHITransAddr:: +PHITranslateWithInsertion(BasicBlock *CurBB, BasicBlock *PredBB, +                          const DominatorTree &DT, +                          SmallVectorImpl<Instruction*> &NewInsts) { +  unsigned NISize = NewInsts.size(); + +  // Attempt to PHI translate with insertion. +  Addr = InsertPHITranslatedSubExpr(Addr, CurBB, PredBB, DT, NewInsts); + +  // If successful, return the new value. +  if (Addr) return Addr; + +  // If not, destroy any intermediate instructions inserted. +  while (NewInsts.size() != NISize) +    NewInsts.pop_back_val()->eraseFromParent(); +  return nullptr; +} + + +/// InsertPHITranslatedPointer - Insert a computation of the PHI translated +/// version of 'V' for the edge PredBB->CurBB into the end of the PredBB +/// block.  All newly created instructions are added to the NewInsts list. +/// This returns null on failure. +/// +Value *PHITransAddr:: +InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, +                           BasicBlock *PredBB, const DominatorTree &DT, +                           SmallVectorImpl<Instruction*> &NewInsts) { +  // See if we have a version of this value already available and dominating +  // PredBB.  If so, there is no need to insert a new instance of it. +  PHITransAddr Tmp(InVal, DL, AC); +  if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT, /*MustDominate=*/true)) +    return Tmp.getAddr(); + +  // We don't need to PHI translate values which aren't instructions. +  auto *Inst = dyn_cast<Instruction>(InVal); +  if (!Inst) +    return nullptr; + +  // Handle cast of PHI translatable value. 
+  if (CastInst *Cast = dyn_cast<CastInst>(Inst)) { +    if (!isSafeToSpeculativelyExecute(Cast)) return nullptr; +    Value *OpVal = InsertPHITranslatedSubExpr(Cast->getOperand(0), +                                              CurBB, PredBB, DT, NewInsts); +    if (!OpVal) return nullptr; + +    // Otherwise insert a cast at the end of PredBB. +    CastInst *New = CastInst::Create(Cast->getOpcode(), OpVal, InVal->getType(), +                                     InVal->getName() + ".phi.trans.insert", +                                     PredBB->getTerminator()); +    New->setDebugLoc(Inst->getDebugLoc()); +    NewInsts.push_back(New); +    return New; +  } + +  // Handle getelementptr with at least one PHI operand. +  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Inst)) { +    SmallVector<Value*, 8> GEPOps; +    BasicBlock *CurBB = GEP->getParent(); +    for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i) { +      Value *OpVal = InsertPHITranslatedSubExpr(GEP->getOperand(i), +                                                CurBB, PredBB, DT, NewInsts); +      if (!OpVal) return nullptr; +      GEPOps.push_back(OpVal); +    } + +    GetElementPtrInst *Result = GetElementPtrInst::Create( +        GEP->getSourceElementType(), GEPOps[0], makeArrayRef(GEPOps).slice(1), +        InVal->getName() + ".phi.trans.insert", PredBB->getTerminator()); +    Result->setDebugLoc(Inst->getDebugLoc()); +    Result->setIsInBounds(GEP->isInBounds()); +    NewInsts.push_back(Result); +    return Result; +  } + +#if 0 +  // FIXME: This code works, but it is unclear that we actually want to insert +  // a big chain of computation in order to make a value available in a block. +  // This needs to be evaluated carefully to consider its cost trade offs. + +  // Handle add with a constant RHS. +  if (Inst->getOpcode() == Instruction::Add && +      isa<ConstantInt>(Inst->getOperand(1))) { +    // PHI translate the LHS. +    Value *OpVal = InsertPHITranslatedSubExpr(Inst->getOperand(0), +                                              CurBB, PredBB, DT, NewInsts); +    if (OpVal == 0) return 0; + +    BinaryOperator *Res = BinaryOperator::CreateAdd(OpVal, Inst->getOperand(1), +                                           InVal->getName()+".phi.trans.insert", +                                                    PredBB->getTerminator()); +    Res->setHasNoSignedWrap(cast<BinaryOperator>(Inst)->hasNoSignedWrap()); +    Res->setHasNoUnsignedWrap(cast<BinaryOperator>(Inst)->hasNoUnsignedWrap()); +    NewInsts.push_back(Res); +    return Res; +  } +#endif + +  return nullptr; +} diff --git a/contrib/llvm/lib/Analysis/PhiValues.cpp b/contrib/llvm/lib/Analysis/PhiValues.cpp new file mode 100644 index 000000000000..ef121815d2cf --- /dev/null +++ b/contrib/llvm/lib/Analysis/PhiValues.cpp @@ -0,0 +1,196 @@ +//===- PhiValues.cpp - Phi Value Analysis ---------------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PhiValues.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +bool PhiValues::invalidate(Function &, const PreservedAnalyses &PA, +                           FunctionAnalysisManager::Invalidator &) { +  // PhiValues is invalidated if it isn't preserved. 
+  auto PAC = PA.getChecker<PhiValuesAnalysis>();
+  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>());
+}
+
+// The goal here is to find all of the non-phi values reachable from this phi,
+// and to do the same for all of the phis reachable from this phi, as doing so
+// is necessary anyway in order to get the values for this phi. We do this using
+// Tarjan's algorithm with Nuutila's improvements to find the strongly connected
+// components of the phi graph rooted in this phi:
+//  * All phis in a strongly connected component will have the same reachable
+//    non-phi values. The SCC may not be the maximal subgraph for that set of
+//    reachable values, but finding out that isn't really necessary (it would
+//    only reduce the amount of memory needed to store the values).
+//  * Tarjan's algorithm completes components in a bottom-up manner, i.e. it
+//    never completes a component before the components reachable from it have
+//    been completed. This means that when we complete a component we have
+//    everything we need to collect the values reachable from that component.
+//  * We collect both the non-phi values reachable from each SCC, as that's what
+//    we're ultimately interested in, and all of the reachable values, i.e.
+//    including phis, as that makes invalidateValue easier.
+void PhiValues::processPhi(const PHINode *Phi,
+                           SmallVector<const PHINode *, 8> &Stack) {
+  // Initialize the phi with the next depth number.
+  assert(DepthMap.lookup(Phi) == 0);
+  assert(NextDepthNumber != UINT_MAX);
+  unsigned int DepthNumber = ++NextDepthNumber;
+  DepthMap[Phi] = DepthNumber;
+
+  // Recursively process the incoming phis of this phi.
+  for (Value *PhiOp : Phi->incoming_values()) {
+    if (PHINode *PhiPhiOp = dyn_cast<PHINode>(PhiOp)) {
+      // Recurse if the phi has not yet been visited.
+      if (DepthMap.lookup(PhiPhiOp) == 0)
+        processPhi(PhiPhiOp, Stack);
+      assert(DepthMap.lookup(PhiPhiOp) != 0);
+      // If the phi did not become part of a component then this phi and that
+      // phi are part of the same component, so adjust the depth number.
+      if (!ReachableMap.count(DepthMap[PhiPhiOp]))
+        DepthMap[Phi] = std::min(DepthMap[Phi], DepthMap[PhiPhiOp]);
+    }
+  }
+
+  // Now that incoming phis have been handled, push this phi to the stack.
+  Stack.push_back(Phi);
+
+  // If the depth number has not changed then we've finished collecting the phis
+  // of a strongly connected component.
+  if (DepthMap[Phi] == DepthNumber) {
+    // Collect the reachable values for this component. The phis of this
+    // component will be those on top of the depth stack with the same or
+    // greater depth number.
+    ConstValueSet Reachable;
+    while (!Stack.empty() && DepthMap[Stack.back()] >= DepthNumber) {
+      const PHINode *ComponentPhi = Stack.pop_back_val();
+      Reachable.insert(ComponentPhi);
+      DepthMap[ComponentPhi] = DepthNumber;
+      for (Value *Op : ComponentPhi->incoming_values()) {
+        if (PHINode *PhiOp = dyn_cast<PHINode>(Op)) {
+          // If this phi is not part of the same component then that component
+          // is guaranteed to have been completed before this one. Therefore we
+          // can just add its reachable values to the reachable values of this
+          // component.
+          auto It = ReachableMap.find(DepthMap[PhiOp]); +          if (It != ReachableMap.end()) +            Reachable.insert(It->second.begin(), It->second.end()); +        } else { +          Reachable.insert(Op); +        } +      } +    } +    ReachableMap.insert({DepthNumber,Reachable}); + +    // Filter out phis to get the non-phi reachable values. +    ValueSet NonPhi; +    for (const Value *V : Reachable) +      if (!isa<PHINode>(V)) +        NonPhi.insert(const_cast<Value*>(V)); +    NonPhiReachableMap.insert({DepthNumber,NonPhi}); +  } +} + +const PhiValues::ValueSet &PhiValues::getValuesForPhi(const PHINode *PN) { +  if (DepthMap.count(PN) == 0) { +    SmallVector<const PHINode *, 8> Stack; +    processPhi(PN, Stack); +    assert(Stack.empty()); +  } +  assert(DepthMap.lookup(PN) != 0); +  return NonPhiReachableMap[DepthMap[PN]]; +} + +void PhiValues::invalidateValue(const Value *V) { +  // Components that can reach V are invalid. +  SmallVector<unsigned int, 8> InvalidComponents; +  for (auto &Pair : ReachableMap) +    if (Pair.second.count(V)) +      InvalidComponents.push_back(Pair.first); + +  for (unsigned int N : InvalidComponents) { +    for (const Value *V : ReachableMap[N]) +      if (const PHINode *PN = dyn_cast<PHINode>(V)) +        DepthMap.erase(PN); +    NonPhiReachableMap.erase(N); +    ReachableMap.erase(N); +  } +} + +void PhiValues::releaseMemory() { +  DepthMap.clear(); +  NonPhiReachableMap.clear(); +  ReachableMap.clear(); +} + +void PhiValues::print(raw_ostream &OS) const { +  // Iterate through the phi nodes of the function rather than iterating through +  // DepthMap in order to get predictable ordering. +  for (const BasicBlock &BB : F) { +    for (const PHINode &PN : BB.phis()) { +      OS << "PHI "; +      PN.printAsOperand(OS, false); +      OS << " has values:\n"; +      unsigned int N = DepthMap.lookup(&PN); +      auto It = NonPhiReachableMap.find(N); +      if (It == NonPhiReachableMap.end()) +        OS << "  UNKNOWN\n"; +      else if (It->second.empty()) +        OS << "  NONE\n"; +      else +        for (Value *V : It->second) +          // Printing of an instruction prints two spaces at the start, so +          // handle instructions and everything else slightly differently in +          // order to get consistent indenting. 
+          if (Instruction *I = dyn_cast<Instruction>(V)) +            OS << *I << "\n"; +          else +            OS << "  " << *V << "\n"; +    } +  } +} + +AnalysisKey PhiValuesAnalysis::Key; +PhiValues PhiValuesAnalysis::run(Function &F, FunctionAnalysisManager &) { +  return PhiValues(F); +} + +PreservedAnalyses PhiValuesPrinterPass::run(Function &F, +                                            FunctionAnalysisManager &AM) { +  OS << "PHI Values for function: " << F.getName() << "\n"; +  PhiValues &PI = AM.getResult<PhiValuesAnalysis>(F); +  for (const BasicBlock &BB : F) +    for (const PHINode &PN : BB.phis()) +      PI.getValuesForPhi(&PN); +  PI.print(OS); +  return PreservedAnalyses::all(); +} + +PhiValuesWrapperPass::PhiValuesWrapperPass() : FunctionPass(ID) { +  initializePhiValuesWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool PhiValuesWrapperPass::runOnFunction(Function &F) { +  Result.reset(new PhiValues(F)); +  return false; +} + +void PhiValuesWrapperPass::releaseMemory() { +  Result->releaseMemory(); +} + +void PhiValuesWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +} + +char PhiValuesWrapperPass::ID = 0; + +INITIALIZE_PASS(PhiValuesWrapperPass, "phi-values", "Phi Values Analysis", false, +                true) diff --git a/contrib/llvm/lib/Analysis/PostDominators.cpp b/contrib/llvm/lib/Analysis/PostDominators.cpp new file mode 100644 index 000000000000..e6b660fe26d7 --- /dev/null +++ b/contrib/llvm/lib/Analysis/PostDominators.cpp @@ -0,0 +1,85 @@ +//===- PostDominators.cpp - Post-Dominator Calculation --------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the post-dominator construction algorithms. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PostDominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "postdomtree" + +#ifdef EXPENSIVE_CHECKS +static constexpr bool ExpensiveChecksEnabled = true; +#else +static constexpr bool ExpensiveChecksEnabled = false; +#endif + +//===----------------------------------------------------------------------===// +//  PostDominatorTree Implementation +//===----------------------------------------------------------------------===// + +char PostDominatorTreeWrapperPass::ID = 0; + +INITIALIZE_PASS(PostDominatorTreeWrapperPass, "postdomtree", +                "Post-Dominator Tree Construction", true, true) + +bool PostDominatorTree::invalidate(Function &F, const PreservedAnalyses &PA, +                                   FunctionAnalysisManager::Invalidator &) { +  // Check whether the analysis, all analyses on functions, or the function's +  // CFG have been preserved. 
+  auto PAC = PA.getChecker<PostDominatorTreeAnalysis>(); +  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || +           PAC.preservedSet<CFGAnalyses>()); +} + +bool PostDominatorTreeWrapperPass::runOnFunction(Function &F) { +  DT.recalculate(F); +  return false; +} + +void PostDominatorTreeWrapperPass::verifyAnalysis() const { +  if (VerifyDomInfo) +    assert(DT.verify(PostDominatorTree::VerificationLevel::Full)); +  else if (ExpensiveChecksEnabled) +    assert(DT.verify(PostDominatorTree::VerificationLevel::Basic)); +} + +void PostDominatorTreeWrapperPass::print(raw_ostream &OS, const Module *) const { +  DT.print(OS); +} + +FunctionPass* llvm::createPostDomTree() { +  return new PostDominatorTreeWrapperPass(); +} + +AnalysisKey PostDominatorTreeAnalysis::Key; + +PostDominatorTree PostDominatorTreeAnalysis::run(Function &F, +                                                 FunctionAnalysisManager &) { +  PostDominatorTree PDT(F); +  return PDT; +} + +PostDominatorTreePrinterPass::PostDominatorTreePrinterPass(raw_ostream &OS) +  : OS(OS) {} + +PreservedAnalyses +PostDominatorTreePrinterPass::run(Function &F, FunctionAnalysisManager &AM) { +  OS << "PostDominatorTree for function: " << F.getName() << "\n"; +  AM.getResult<PostDominatorTreeAnalysis>(F).print(OS); + +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp b/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp new file mode 100644 index 000000000000..fb591f5d6a69 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ProfileSummaryInfo.cpp @@ -0,0 +1,310 @@ +//===- ProfileSummaryInfo.cpp - Global profile summary information --------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that provides access to the global profile summary +// information. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ProfileSummary.h" +using namespace llvm; + +// The following two parameters determine the threshold for a count to be +// considered hot/cold. These two parameters are percentile values (multiplied +// by 10000). If the counts are sorted in descending order, the minimum count to +// reach ProfileSummaryCutoffHot gives the threshold to determine a hot count. +// Similarly, the minimum count to reach ProfileSummaryCutoffCold gives the +// threshold for determining cold count (everything <= this threshold is +// considered cold). 
+
+static cl::opt<int> ProfileSummaryCutoffHot(
+    "profile-summary-cutoff-hot", cl::Hidden, cl::init(990000), cl::ZeroOrMore,
+    cl::desc("A count is hot if it exceeds the minimum count to"
+             " reach this percentile of total counts."));
+
+static cl::opt<int> ProfileSummaryCutoffCold(
+    "profile-summary-cutoff-cold", cl::Hidden, cl::init(999999), cl::ZeroOrMore,
+    cl::desc("A count is cold if it is below the minimum count"
+             " to reach this percentile of total counts."));
+
+static cl::opt<bool> ProfileSampleAccurate(
+    "profile-sample-accurate", cl::Hidden, cl::init(false),
+    cl::desc("If the sample profile is accurate, we will mark all un-sampled "
+             "callsites as cold. Otherwise, treat un-sampled callsites as if "
+             "we have no profile."));
+static cl::opt<unsigned> ProfileSummaryHugeWorkingSetSizeThreshold(
+    "profile-summary-huge-working-set-size-threshold", cl::Hidden,
+    cl::init(15000), cl::ZeroOrMore,
+    cl::desc("The code working set size is considered huge if the number of"
+             " blocks required to reach the -profile-summary-cutoff-hot"
+             " percentile exceeds this count."));
+
+// Find the summary entry for a desired percentile of counts.
+static const ProfileSummaryEntry &getEntryForPercentile(SummaryEntryVector &DS,
+                                                        uint64_t Percentile) {
+  auto Compare = [](const ProfileSummaryEntry &Entry, uint64_t Percentile) {
+    return Entry.Cutoff < Percentile;
+  };
+  auto It = std::lower_bound(DS.begin(), DS.end(), Percentile, Compare);
+  // The required percentile has to be <= one of the percentiles in the
+  // detailed summary.
+  if (It == DS.end())
+    report_fatal_error("Desired percentile exceeds the maximum cutoff");
+  return *It;
+}
+
+// The profile summary metadata may be attached either by the frontend or by
+// any backend passes (IR level instrumentation, for example). This method
+// checks if the Summary is null and if so checks if the summary metadata is now
+// available in the module and parses it to get the Summary object. Returns true
+// if a valid Summary is available.
+bool ProfileSummaryInfo::computeSummary() {
+  if (Summary)
+    return true;
+  auto *SummaryMD = M.getProfileSummary();
+  if (!SummaryMD)
+    return false;
+  Summary.reset(ProfileSummary::getFromMD(SummaryMD));
+  return true;
+}
+
+Optional<uint64_t>
+ProfileSummaryInfo::getProfileCount(const Instruction *Inst,
+                                    BlockFrequencyInfo *BFI) {
+  if (!Inst)
+    return None;
+  assert((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) &&
+         "We can only get profile count for call/invoke instruction.");
+  if (hasSampleProfile()) {
+    // In sample PGO mode, check if there is profile metadata on the
+    // instruction. If it is present, determine hotness solely based on that,
+    // since the sampled entry count may not be accurate. If there is no
+    // annotation on the instruction, return None.
+    uint64_t TotalCount;
+    if (Inst->extractProfTotalWeight(TotalCount))
+      return TotalCount;
+    return None;
+  }
+  if (BFI)
+    return BFI->getBlockProfileCount(Inst->getParent());
+  return None;
+}
+
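The hotness and coldness queries that follow are what client passes actually consume. As a rough illustration (hypothetical caller code, not part of this file), a transformation that wants to spend extra effort only on profitable functions might combine the entry-count query with the per-block query like this:

    #include "llvm/Analysis/BlockFrequencyInfo.h"
    #include "llvm/Analysis/ProfileSummaryInfo.h"
    #include "llvm/IR/Function.h"

    // Sketch: returns true if F looks hot according to the profile summary,
    // either through its entry count or through any individual hot block.
    static bool looksHot(const llvm::Function &F, llvm::ProfileSummaryInfo &PSI,
                         llvm::BlockFrequencyInfo &BFI) {
      if (PSI.isFunctionEntryHot(&F))
        return true;
      for (const llvm::BasicBlock &BB : F)
        if (PSI.isHotBB(&BB, &BFI))
          return true;
      return false;
    }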
+/// Returns true if the function's entry is hot. If it returns false, it
+/// either means it is not hot or it is unknown whether it is hot or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryHot(const Function *F) {
+  if (!F || !computeSummary())
+    return false;
+  auto FunctionCount = F->getEntryCount();
+  // FIXME: The heuristic used below for determining hotness is based on
+  // preliminary SPEC tuning for inliner. This will eventually be a
+  // convenience method that calls isHotCount.
+  return FunctionCount && isHotCount(FunctionCount.getCount());
+}
+
+/// Returns true if the function contains hot code. This can include a hot
+/// function entry count, hot basic block, or (in the case of Sample PGO)
+/// hot total call edge count.
+/// If it returns false, it either means it is not hot or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionHotInCallGraph(const Function *F,
+                                                  BlockFrequencyInfo &BFI) {
+  if (!F || !computeSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (isHotCount(FunctionCount.getCount()))
+      return true;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(&I, nullptr))
+            TotalCallCount += CallCount.getValue();
+    if (isHotCount(TotalCallCount))
+      return true;
+  }
+  for (const auto &BB : *F)
+    if (isHotBB(&BB, &BFI))
+      return true;
+  return false;
+}
+
+/// Returns true if the function only contains cold code. This means that
+/// the function entry and blocks are all cold, and (in the case of Sample PGO)
+/// the total call edge count is cold.
+/// If it returns false, it either means it is not cold or it is unknown
+/// (for example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionColdInCallGraph(const Function *F,
+                                                   BlockFrequencyInfo &BFI) {
+  if (!F || !computeSummary())
+    return false;
+  if (auto FunctionCount = F->getEntryCount())
+    if (!isColdCount(FunctionCount.getCount()))
+      return false;
+
+  if (hasSampleProfile()) {
+    uint64_t TotalCallCount = 0;
+    for (const auto &BB : *F)
+      for (const auto &I : BB)
+        if (isa<CallInst>(I) || isa<InvokeInst>(I))
+          if (auto CallCount = getProfileCount(&I, nullptr))
+            TotalCallCount += CallCount.getValue();
+    if (!isColdCount(TotalCallCount))
+      return false;
+  }
+  for (const auto &BB : *F)
+    if (!isColdBB(&BB, &BFI))
+      return false;
+  return true;
+}
+
+/// Returns true if the function's entry is cold. If it returns false, it
+/// either means it is not cold or it is unknown whether it is cold or not (for
+/// example, no profile data is available).
+bool ProfileSummaryInfo::isFunctionEntryCold(const Function *F) {
+  if (!F)
+    return false;
+  if (F->hasFnAttribute(Attribute::Cold))
+    return true;
+  if (!computeSummary())
+    return false;
+  auto FunctionCount = F->getEntryCount();
+  // FIXME: The heuristic used below for determining coldness is based on
+  // preliminary SPEC tuning for inliner. This will eventually be a
+  // convenience method that calls isColdCount.
+  return FunctionCount && isColdCount(FunctionCount.getCount());
+}
+
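Before the threshold computation that follows, a rough worked example of what the cutoff options mean (all numbers here are illustrative, not taken from a real profile). Suppose the detailed summary contains the entries

    Cutoff = 900000   MinCount = 5000   NumCounts = 120
    Cutoff = 990000   MinCount =  250   NumCounts = 1800
    Cutoff = 999999   MinCount =    2   NumCounts = 150000

With the default -profile-summary-cutoff-hot=990000, getEntryForPercentile() returns the second entry (the first whose Cutoff is at least 990000), so HotCountThreshold becomes 250: any count of 250 or more is hot, because those counts together cover 99% of all counted executions. The cold threshold is taken from the 999999 entry in the same way, so counts of at most 2 are cold, and HasHugeWorkingSetSize compares the hot entry's NumCounts (1800 here) against -profile-summary-huge-working-set-size-threshold (15000 by default), so this working set would not be considered huge.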
+/// Compute the hot and cold thresholds.
+void ProfileSummaryInfo::computeThresholds() {
+  if (!computeSummary())
+    return;
+  auto &DetailedSummary = Summary->getDetailedSummary();
+  auto &HotEntry =
+      getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffHot);
+  HotCountThreshold = HotEntry.MinCount;
+  auto &ColdEntry =
+      getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffCold);
+  ColdCountThreshold = ColdEntry.MinCount;
+  HasHugeWorkingSetSize =
+      HotEntry.NumCounts > ProfileSummaryHugeWorkingSetSizeThreshold;
+}
+
+bool ProfileSummaryInfo::hasHugeWorkingSetSize() {
+  if (!HasHugeWorkingSetSize)
+    computeThresholds();
+  return HasHugeWorkingSetSize && HasHugeWorkingSetSize.getValue();
+}
+
+bool ProfileSummaryInfo::isHotCount(uint64_t C) {
+  if (!HotCountThreshold)
+    computeThresholds();
+  return HotCountThreshold && C >= HotCountThreshold.getValue();
+}
+
+bool ProfileSummaryInfo::isColdCount(uint64_t C) {
+  if (!ColdCountThreshold)
+    computeThresholds();
+  return ColdCountThreshold && C <= ColdCountThreshold.getValue();
+}
+
+uint64_t ProfileSummaryInfo::getOrCompHotCountThreshold() {
+  if (!HotCountThreshold)
+    computeThresholds();
+  // If no summary is available there is no threshold; be conservative and
+  // report a threshold that classifies nothing as hot.
+  return HotCountThreshold ? HotCountThreshold.getValue() : UINT64_MAX;
+}
+
+uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() {
+  if (!ColdCountThreshold)
+    computeThresholds();
+  // If no summary is available there is no threshold; be conservative and
+  // report a threshold that classifies nothing as cold.
+  return ColdCountThreshold ? ColdCountThreshold.getValue() : 0;
+}
+
+bool ProfileSummaryInfo::isHotBB(const BasicBlock *B, BlockFrequencyInfo *BFI) {
+  auto Count = BFI->getBlockProfileCount(B);
+  return Count && isHotCount(*Count);
+}
+
+bool ProfileSummaryInfo::isColdBB(const BasicBlock *B,
+                                  BlockFrequencyInfo *BFI) {
+  auto Count = BFI->getBlockProfileCount(B);
+  return Count && isColdCount(*Count);
+}
+
+bool ProfileSummaryInfo::isHotCallSite(const CallSite &CS,
+                                       BlockFrequencyInfo *BFI) {
+  auto C = getProfileCount(CS.getInstruction(), BFI);
+  return C && isHotCount(*C);
+}
+
+bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS,
+                                        BlockFrequencyInfo *BFI) {
+  auto C = getProfileCount(CS.getInstruction(), BFI);
+  if (C)
+    return isColdCount(*C);
+
+  // In SamplePGO, if the caller has been sampled, and there is no profile
+  // annotated on the callsite, we consider the callsite as cold.
+  // If there is no profile for the caller, and we know the profile is
+  // accurate, we consider the callsite as cold.
+  return (hasSampleProfile() && +          (CS.getCaller()->hasProfileData() || ProfileSampleAccurate || +           CS.getCaller()->hasFnAttribute("profile-sample-accurate"))); +} + +INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info", +                "Profile summary info", false, true) + +ProfileSummaryInfoWrapperPass::ProfileSummaryInfoWrapperPass() +    : ImmutablePass(ID) { +  initializeProfileSummaryInfoWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ProfileSummaryInfoWrapperPass::doInitialization(Module &M) { +  PSI.reset(new ProfileSummaryInfo(M)); +  return false; +} + +bool ProfileSummaryInfoWrapperPass::doFinalization(Module &M) { +  PSI.reset(); +  return false; +} + +AnalysisKey ProfileSummaryAnalysis::Key; +ProfileSummaryInfo ProfileSummaryAnalysis::run(Module &M, +                                               ModuleAnalysisManager &) { +  return ProfileSummaryInfo(M); +} + +PreservedAnalyses ProfileSummaryPrinterPass::run(Module &M, +                                                 ModuleAnalysisManager &AM) { +  ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M); + +  OS << "Functions in " << M.getName() << " with hot/cold annotations: \n"; +  for (auto &F : M) { +    OS << F.getName(); +    if (PSI.isFunctionEntryHot(&F)) +      OS << " :hot entry "; +    else if (PSI.isFunctionEntryCold(&F)) +      OS << " :cold entry "; +    OS << "\n"; +  } +  return PreservedAnalyses::all(); +} + +char ProfileSummaryInfoWrapperPass::ID = 0; diff --git a/contrib/llvm/lib/Analysis/PtrUseVisitor.cpp b/contrib/llvm/lib/Analysis/PtrUseVisitor.cpp new file mode 100644 index 000000000000..1fdaf4d55b59 --- /dev/null +++ b/contrib/llvm/lib/Analysis/PtrUseVisitor.cpp @@ -0,0 +1,39 @@ +//===- PtrUseVisitor.cpp - InstVisitors over a pointers uses --------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file +/// Implementation of the pointer use visitors. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/PtrUseVisitor.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include <algorithm> + +using namespace llvm; + +void detail::PtrUseVisitorBase::enqueueUsers(Instruction &I) { +  for (Use &U : I.uses()) { +    if (VisitedUses.insert(&U).second) { +      UseToVisit NewU = { +        UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown), +        Offset +      }; +      Worklist.push_back(std::move(NewU)); +    } +  } +} + +bool detail::PtrUseVisitorBase::adjustOffsetForGEP(GetElementPtrInst &GEPI) { +  if (!IsOffsetKnown) +    return false; + +  return GEPI.accumulateConstantOffset(DL, Offset); +} diff --git a/contrib/llvm/lib/Analysis/RegionInfo.cpp b/contrib/llvm/lib/Analysis/RegionInfo.cpp new file mode 100644 index 000000000000..2bd611350f46 --- /dev/null +++ b/contrib/llvm/lib/Analysis/RegionInfo.cpp @@ -0,0 +1,216 @@ +//===- RegionInfo.cpp - SESE region detection analysis --------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// Detects single entry single exit regions in the control flow graph. +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/ADT/Statistic.h" +#ifndef NDEBUG +#include "llvm/Analysis/RegionPrinter.h" +#endif +#include "llvm/Analysis/RegionInfoImpl.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "region" + +namespace llvm { + +template class RegionBase<RegionTraits<Function>>; +template class RegionNodeBase<RegionTraits<Function>>; +template class RegionInfoBase<RegionTraits<Function>>; + +} // end namespace llvm + +STATISTIC(numRegions,       "The # of regions"); +STATISTIC(numSimpleRegions, "The # of simple regions"); + +// Always verify if expensive checking is enabled. + +static cl::opt<bool,true> +VerifyRegionInfoX( +  "verify-region-info", +  cl::location(RegionInfoBase<RegionTraits<Function>>::VerifyRegionInfo), +  cl::desc("Verify region info (time consuming)")); + +static cl::opt<Region::PrintStyle, true> printStyleX("print-region-style", +  cl::location(RegionInfo::printStyle), +  cl::Hidden, +  cl::desc("style of printing regions"), +  cl::values( +    clEnumValN(Region::PrintNone, "none",  "print no details"), +    clEnumValN(Region::PrintBB, "bb", +               "print regions in detail with block_iterator"), +    clEnumValN(Region::PrintRN, "rn", +               "print regions in detail with element_iterator"))); + +//===----------------------------------------------------------------------===// +// Region implementation +// + +Region::Region(BasicBlock *Entry, BasicBlock *Exit, +               RegionInfo* RI, +               DominatorTree *DT, Region *Parent) : +  RegionBase<RegionTraits<Function>>(Entry, Exit, RI, DT, Parent) { + +} + +Region::~Region() = default; + +//===----------------------------------------------------------------------===// +// RegionInfo implementation +// + +RegionInfo::RegionInfo() = default; + +RegionInfo::~RegionInfo() = default; + +bool RegionInfo::invalidate(Function &F, const PreservedAnalyses &PA, +                            FunctionAnalysisManager::Invalidator &) { +  // Check whether the analysis, all analyses on functions, or the function's +  // CFG has been preserved. +  auto PAC = PA.getChecker<RegionInfoAnalysis>(); +  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>() || +           PAC.preservedSet<CFGAnalyses>()); +} + +void RegionInfo::updateStatistics(Region *R) { +  ++numRegions; + +  // TODO: Slow. Should only be enabled if -stats is used. 
+  if (R->isSimple()) +    ++numSimpleRegions; +} + +void RegionInfo::recalculate(Function &F, DominatorTree *DT_, +                             PostDominatorTree *PDT_, DominanceFrontier *DF_) { +  DT = DT_; +  PDT = PDT_; +  DF = DF_; + +  TopLevelRegion = new Region(&F.getEntryBlock(), nullptr, +                              this, DT, nullptr); +  updateStatistics(TopLevelRegion); +  calculate(F); +} + +#ifndef NDEBUG +void RegionInfo::view() { viewRegion(this); } + +void RegionInfo::viewOnly() { viewRegionOnly(this); } +#endif + +//===----------------------------------------------------------------------===// +// RegionInfoPass implementation +// + +RegionInfoPass::RegionInfoPass() : FunctionPass(ID) { +  initializeRegionInfoPassPass(*PassRegistry::getPassRegistry()); +} + +RegionInfoPass::~RegionInfoPass() = default; + +bool RegionInfoPass::runOnFunction(Function &F) { +  releaseMemory(); + +  auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +  auto PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); +  auto DF = &getAnalysis<DominanceFrontierWrapperPass>().getDominanceFrontier(); + +  RI.recalculate(F, DT, PDT, DF); +  return false; +} + +void RegionInfoPass::releaseMemory() { +  RI.releaseMemory(); +} + +void RegionInfoPass::verifyAnalysis() const { +    RI.verifyAnalysis(); +} + +void RegionInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequiredTransitive<DominatorTreeWrapperPass>(); +  AU.addRequired<PostDominatorTreeWrapperPass>(); +  AU.addRequired<DominanceFrontierWrapperPass>(); +} + +void RegionInfoPass::print(raw_ostream &OS, const Module *) const { +  RI.print(OS); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void RegionInfoPass::dump() const { +  RI.dump(); +} +#endif + +char RegionInfoPass::ID = 0; + +INITIALIZE_PASS_BEGIN(RegionInfoPass, "regions", +                "Detect single entry single exit regions", true, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominanceFrontierWrapperPass) +INITIALIZE_PASS_END(RegionInfoPass, "regions", +                "Detect single entry single exit regions", true, true) + +// Create methods available outside of this file, to use them +// "include/llvm/LinkAllPasses.h". Otherwise the pass would be deleted by +// the link time optimization. 
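As a minimal sketch of what the comment above is getting at (hypothetical out-of-tree driver code, not part of this file), a tool can schedule the analysis through the create function defined just below without naming the pass class at all, mirroring the legacy pass-manager pattern used elsewhere in this library:

    #include "llvm/IR/LegacyPassManager.h"
    #include "llvm/IR/Module.h"

    namespace llvm {
    class FunctionPass;
    FunctionPass *createRegionInfoPass(); // normally declared by an LLVM header
    } // namespace llvm

    // Run region detection over a single function via the legacy pass manager.
    void computeRegionsFor(llvm::Function &F) {
      llvm::legacy::FunctionPassManager FPM(F.getParent());
      FPM.add(llvm::createRegionInfoPass());
      FPM.doInitialization();
      FPM.run(F);
      FPM.doFinalization();
    }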
+ +namespace llvm { + +  FunctionPass *createRegionInfoPass() { +    return new RegionInfoPass(); +  } + +} // end namespace llvm + +//===----------------------------------------------------------------------===// +// RegionInfoAnalysis implementation +// + +AnalysisKey RegionInfoAnalysis::Key; + +RegionInfo RegionInfoAnalysis::run(Function &F, FunctionAnalysisManager &AM) { +  RegionInfo RI; +  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); +  auto *PDT = &AM.getResult<PostDominatorTreeAnalysis>(F); +  auto *DF = &AM.getResult<DominanceFrontierAnalysis>(F); + +  RI.recalculate(F, DT, PDT, DF); +  return RI; +} + +RegionInfoPrinterPass::RegionInfoPrinterPass(raw_ostream &OS) +  : OS(OS) {} + +PreservedAnalyses RegionInfoPrinterPass::run(Function &F, +                                             FunctionAnalysisManager &AM) { +  OS << "Region Tree for function: " << F.getName() << "\n"; +  AM.getResult<RegionInfoAnalysis>(F).print(OS); + +  return PreservedAnalyses::all(); +} + +PreservedAnalyses RegionInfoVerifierPass::run(Function &F, +                                              FunctionAnalysisManager &AM) { +  AM.getResult<RegionInfoAnalysis>(F).verifyAnalysis(); + +  return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Analysis/RegionPass.cpp b/contrib/llvm/lib/Analysis/RegionPass.cpp new file mode 100644 index 000000000000..ed17df2e7e93 --- /dev/null +++ b/contrib/llvm/lib/Analysis/RegionPass.cpp @@ -0,0 +1,294 @@ +//===- RegionPass.cpp - Region Pass and Region Pass Manager ---------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements RegionPass and RGPassManager. All region optimization +// and transformation passes are derived from RegionPass. RGPassManager is +// responsible for managing RegionPasses. +// Most of this code has been COPIED from LoopPass.cpp +// +//===----------------------------------------------------------------------===// +#include "llvm/Analysis/RegionPass.h" +#include "llvm/IR/OptBisect.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" +#include "llvm/Support/raw_ostream.h" +using namespace llvm; + +#define DEBUG_TYPE "regionpassmgr" + +//===----------------------------------------------------------------------===// +// RGPassManager +// + +char RGPassManager::ID = 0; + +RGPassManager::RGPassManager() +  : FunctionPass(ID), PMDataManager() { +  skipThisRegion = false; +  redoThisRegion = false; +  RI = nullptr; +  CurrentRegion = nullptr; +} + +// Recurse through all subregions and all regions  into RQ. +static void addRegionIntoQueue(Region &R, std::deque<Region *> &RQ) { +  RQ.push_back(&R); +  for (const auto &E : R) +    addRegionIntoQueue(*E, RQ); +} + +/// Pass Manager itself does not invalidate any analysis info. +void RGPassManager::getAnalysisUsage(AnalysisUsage &Info) const { +  Info.addRequired<RegionInfoPass>(); +  Info.setPreservesAll(); +} + +/// run - Execute all of the passes scheduled for execution.  Keep track of +/// whether any of the passes modifies the function, and if so, return true. +bool RGPassManager::runOnFunction(Function &F) { +  RI = &getAnalysis<RegionInfoPass>().getRegionInfo(); +  bool Changed = false; + +  // Collect inherited analysis from Module level pass manager. 
+  populateInheritedAnalysis(TPM->activeStack); + +  addRegionIntoQueue(*RI->getTopLevelRegion(), RQ); + +  if (RQ.empty()) // No regions, skip calling finalizers +    return false; + +  // Initialization +  for (Region *R : RQ) { +    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +      RegionPass *RP = (RegionPass *)getContainedPass(Index); +      Changed |= RP->doInitialization(R, *this); +    } +  } + +  // Walk Regions +  while (!RQ.empty()) { + +    CurrentRegion  = RQ.back(); +    skipThisRegion = false; +    redoThisRegion = false; + +    // Run all passes on the current Region. +    for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +      RegionPass *P = (RegionPass*)getContainedPass(Index); + +      if (isPassDebuggingExecutionsOrMore()) { +        dumpPassInfo(P, EXECUTION_MSG, ON_REGION_MSG, +                     CurrentRegion->getNameStr()); +        dumpRequiredSet(P); +      } + +      initializeAnalysisImpl(P); + +      { +        PassManagerPrettyStackEntry X(P, *CurrentRegion->getEntry()); + +        TimeRegion PassTimer(getPassTimer(P)); +        Changed |= P->runOnRegion(CurrentRegion, *this); +      } + +      if (isPassDebuggingExecutionsOrMore()) { +        if (Changed) +          dumpPassInfo(P, MODIFICATION_MSG, ON_REGION_MSG, +                       skipThisRegion ? "<deleted>" : +                                      CurrentRegion->getNameStr()); +        dumpPreservedSet(P); +      } + +      if (!skipThisRegion) { +        // Manually check that this region is still healthy. This is done +        // instead of relying on RegionInfo::verifyRegion since RegionInfo +        // is a function pass and it's really expensive to verify every +        // Region in the function every time. That level of checking can be +        // enabled with the -verify-region-info option. +        { +          TimeRegion PassTimer(getPassTimer(P)); +          CurrentRegion->verifyRegion(); +        } + +        // Then call the regular verifyAnalysis functions. +        verifyPreservedAnalysis(P); +      } + +      removeNotPreservedAnalysis(P); +      recordAvailableAnalysis(P); +      removeDeadPasses(P, +                       (!isPassDebuggingExecutionsOrMore() || skipThisRegion) ? +                       "<deleted>" :  CurrentRegion->getNameStr(), +                       ON_REGION_MSG); + +      if (skipThisRegion) +        // Do not run other passes on this region. +        break; +    } + +    // If the region was deleted, release all the region passes. This frees up +    // some memory, and avoids trouble with the pass manager trying to call +    // verifyAnalysis on them. +    if (skipThisRegion) +      for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +        Pass *P = getContainedPass(Index); +        freePass(P, "<deleted>", ON_REGION_MSG); +      } + +    // Pop the region from queue after running all passes. +    RQ.pop_back(); + +    if (redoThisRegion) +      RQ.push_back(CurrentRegion); + +    // Free all region nodes created in region passes. +    RI->clearNodeCache(); +  } + +  // Finalization +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    RegionPass *P = (RegionPass*)getContainedPass(Index); +    Changed |= P->doFinalization(); +  } + +  // Print the region tree after all pass. 
+  LLVM_DEBUG(dbgs() << "\nRegion tree of function " << F.getName() +                    << " after all region Pass:\n"; +             RI->dump(); dbgs() << "\n";); + +  return Changed; +} + +/// Print passes managed by this manager +void RGPassManager::dumpPassStructure(unsigned Offset) { +  errs().indent(Offset*2) << "Region Pass Manager\n"; +  for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { +    Pass *P = getContainedPass(Index); +    P->dumpPassStructure(Offset + 1); +    dumpLastUses(P, Offset+1); +  } +} + +namespace { +//===----------------------------------------------------------------------===// +// PrintRegionPass +class PrintRegionPass : public RegionPass { +private: +  std::string Banner; +  raw_ostream &Out;       // raw_ostream to print on. + +public: +  static char ID; +  PrintRegionPass(const std::string &B, raw_ostream &o) +      : RegionPass(ID), Banner(B), Out(o) {} + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.setPreservesAll(); +  } + +  bool runOnRegion(Region *R, RGPassManager &RGM) override { +    Out << Banner; +    for (const auto *BB : R->blocks()) { +      if (BB) +        BB->print(Out); +      else +        Out << "Printing <null> Block"; +    } + +    return false; +  } + +  StringRef getPassName() const override { return "Print Region IR"; } +}; + +char PrintRegionPass::ID = 0; +}  //end anonymous namespace + +//===----------------------------------------------------------------------===// +// RegionPass + +// Check if this pass is suitable for the current RGPassManager, if +// available. This pass P is not suitable for a RGPassManager if P +// is not preserving higher level analysis info used by other +// RGPassManager passes. In such case, pop RGPassManager from the +// stack. This will force assignPassManager() to create new +// LPPassManger as expected. +void RegionPass::preparePassManager(PMStack &PMS) { + +  // Find RGPassManager +  while (!PMS.empty() && +         PMS.top()->getPassManagerType() > PMT_RegionPassManager) +    PMS.pop(); + + +  // If this pass is destroying high level information that is used +  // by other passes that are managed by LPM then do not insert +  // this pass in current LPM. Use new RGPassManager. +  if (PMS.top()->getPassManagerType() == PMT_RegionPassManager && +    !PMS.top()->preserveHigherLevelAnalysis(this)) +    PMS.pop(); +} + +/// Assign pass manager to manage this pass. +void RegionPass::assignPassManager(PMStack &PMS, +                                 PassManagerType PreferredType) { +  // Find RGPassManager +  while (!PMS.empty() && +         PMS.top()->getPassManagerType() > PMT_RegionPassManager) +    PMS.pop(); + +  RGPassManager *RGPM; + +  // Create new Region Pass Manager if it does not exist. +  if (PMS.top()->getPassManagerType() == PMT_RegionPassManager) +    RGPM = (RGPassManager*)PMS.top(); +  else { + +    assert (!PMS.empty() && "Unable to create Region Pass Manager"); +    PMDataManager *PMD = PMS.top(); + +    // [1] Create new Region Pass Manager +    RGPM = new RGPassManager(); +    RGPM->populateInheritedAnalysis(PMS); + +    // [2] Set up new manager's top level manager +    PMTopLevelManager *TPM = PMD->getTopLevelManager(); +    TPM->addIndirectPassManager(RGPM); + +    // [3] Assign manager to manage this new manager. 
This may create +    // and push new managers into PMS +    TPM->schedulePass(RGPM); + +    // [4] Push new manager into PMS +    PMS.push(RGPM); +  } + +  RGPM->add(this); +} + +/// Get the printer pass +Pass *RegionPass::createPrinterPass(raw_ostream &O, +                                  const std::string &Banner) const { +  return new PrintRegionPass(Banner, O); +} + +bool RegionPass::skipRegion(Region &R) const { +  Function &F = *R.getEntry()->getParent(); +  if (!F.getContext().getOptPassGate().shouldRunPass(this, R)) +    return true; + +  if (F.hasFnAttribute(Attribute::OptimizeNone)) { +    // Report this only once per function. +    if (R.getEntry() == &F.getEntryBlock()) +      LLVM_DEBUG(dbgs() << "Skipping pass '" << getPassName() +                        << "' on function " << F.getName() << "\n"); +    return true; +  } +  return false; +} diff --git a/contrib/llvm/lib/Analysis/RegionPrinter.cpp b/contrib/llvm/lib/Analysis/RegionPrinter.cpp new file mode 100644 index 000000000000..5986b8c4e0c3 --- /dev/null +++ b/contrib/llvm/lib/Analysis/RegionPrinter.cpp @@ -0,0 +1,267 @@ +//===- RegionPrinter.cpp - Print regions tree pass ------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// Print out the region tree of a function using dotty/graphviz. +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/RegionPrinter.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DOTGraphTraitsPass.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/Analysis/RegionIterator.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#ifndef NDEBUG +#include "llvm/IR/LegacyPassManager.h" +#endif + +using namespace llvm; + +//===----------------------------------------------------------------------===// +/// onlySimpleRegion - Show only the simple regions in the RegionViewer. 
+static cl::opt<bool> +onlySimpleRegions("only-simple-regions", +                  cl::desc("Show only simple regions in the graphviz viewer"), +                  cl::Hidden, +                  cl::init(false)); + +namespace llvm { +template<> +struct DOTGraphTraits<RegionNode*> : public DefaultDOTGraphTraits { + +  DOTGraphTraits (bool isSimple=false) +    : DefaultDOTGraphTraits(isSimple) {} + +  std::string getNodeLabel(RegionNode *Node, RegionNode *Graph) { + +    if (!Node->isSubRegion()) { +      BasicBlock *BB = Node->getNodeAs<BasicBlock>(); + +      if (isSimple()) +        return DOTGraphTraits<const Function*> +          ::getSimpleNodeLabel(BB, BB->getParent()); +      else +        return DOTGraphTraits<const Function*> +          ::getCompleteNodeLabel(BB, BB->getParent()); +    } + +    return "Not implemented"; +  } +}; + +template <> +struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> { + +  DOTGraphTraits (bool isSimple = false) +    : DOTGraphTraits<RegionNode*>(isSimple) {} + +  static std::string getGraphName(const RegionInfo *) { return "Region Graph"; } + +  std::string getNodeLabel(RegionNode *Node, RegionInfo *G) { +    return DOTGraphTraits<RegionNode *>::getNodeLabel( +        Node, reinterpret_cast<RegionNode *>(G->getTopLevelRegion())); +  } + +  std::string getEdgeAttributes(RegionNode *srcNode, +                                GraphTraits<RegionInfo *>::ChildIteratorType CI, +                                RegionInfo *G) { +    RegionNode *destNode = *CI; + +    if (srcNode->isSubRegion() || destNode->isSubRegion()) +      return ""; + +    // In case of a backedge, do not use it to define the layout of the nodes. +    BasicBlock *srcBB = srcNode->getNodeAs<BasicBlock>(); +    BasicBlock *destBB = destNode->getNodeAs<BasicBlock>(); + +    Region *R = G->getRegionFor(destBB); + +    while (R && R->getParent()) +      if (R->getParent()->getEntry() == destBB) +        R = R->getParent(); +      else +        break; + +    if (R && R->getEntry() == destBB && R->contains(srcBB)) +      return "constraint=false"; + +    return ""; +  } + +  // Print the cluster of the subregions. This groups the single basic blocks +  // and adds a different background color for each group. 
+  static void printRegionCluster(const Region &R, GraphWriter<RegionInfo *> &GW, +                                 unsigned depth = 0) { +    raw_ostream &O = GW.getOStream(); +    O.indent(2 * depth) << "subgraph cluster_" << static_cast<const void*>(&R) +      << " {\n"; +    O.indent(2 * (depth + 1)) << "label = \"\";\n"; + +    if (!onlySimpleRegions || R.isSimple()) { +      O.indent(2 * (depth + 1)) << "style = filled;\n"; +      O.indent(2 * (depth + 1)) << "color = " +        << ((R.getDepth() * 2 % 12) + 1) << "\n"; + +    } else { +      O.indent(2 * (depth + 1)) << "style = solid;\n"; +      O.indent(2 * (depth + 1)) << "color = " +        << ((R.getDepth() * 2 % 12) + 2) << "\n"; +    } + +    for (const auto &RI : R) +      printRegionCluster(*RI, GW, depth + 1); + +    const RegionInfo &RI = *static_cast<const RegionInfo*>(R.getRegionInfo()); + +    for (auto *BB : R.blocks()) +      if (RI.getRegionFor(BB) == &R) +        O.indent(2 * (depth + 1)) << "Node" +          << static_cast<const void*>(RI.getTopLevelRegion()->getBBNode(BB)) +          << ";\n"; + +    O.indent(2 * depth) << "}\n"; +  } + +  static void addCustomGraphFeatures(const RegionInfo *G, +                                     GraphWriter<RegionInfo *> &GW) { +    raw_ostream &O = GW.getOStream(); +    O << "\tcolorscheme = \"paired12\"\n"; +    printRegionCluster(*G->getTopLevelRegion(), GW, 4); +  } +}; +} //end namespace llvm + +namespace { + +struct RegionInfoPassGraphTraits { +  static RegionInfo *getGraph(RegionInfoPass *RIP) { +    return &RIP->getRegionInfo(); +  } +}; + +struct RegionPrinter +    : public DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *, +                                   RegionInfoPassGraphTraits> { +  static char ID; +  RegionPrinter() +      : DOTGraphTraitsPrinter<RegionInfoPass, false, RegionInfo *, +                              RegionInfoPassGraphTraits>("reg", ID) { +    initializeRegionPrinterPass(*PassRegistry::getPassRegistry()); +  } +}; +char RegionPrinter::ID = 0; + +struct RegionOnlyPrinter +    : public DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *, +                                   RegionInfoPassGraphTraits> { +  static char ID; +  RegionOnlyPrinter() +      : DOTGraphTraitsPrinter<RegionInfoPass, true, RegionInfo *, +                              RegionInfoPassGraphTraits>("reg", ID) { +    initializeRegionOnlyPrinterPass(*PassRegistry::getPassRegistry()); +  } +}; +char RegionOnlyPrinter::ID = 0; + +struct RegionViewer +    : public DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *, +                                  RegionInfoPassGraphTraits> { +  static char ID; +  RegionViewer() +      : DOTGraphTraitsViewer<RegionInfoPass, false, RegionInfo *, +                             RegionInfoPassGraphTraits>("reg", ID) { +    initializeRegionViewerPass(*PassRegistry::getPassRegistry()); +  } +}; +char RegionViewer::ID = 0; + +struct RegionOnlyViewer +    : public DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *, +                                  RegionInfoPassGraphTraits> { +  static char ID; +  RegionOnlyViewer() +      : DOTGraphTraitsViewer<RegionInfoPass, true, RegionInfo *, +                             RegionInfoPassGraphTraits>("regonly", ID) { +    initializeRegionOnlyViewerPass(*PassRegistry::getPassRegistry()); +  } +}; +char RegionOnlyViewer::ID = 0; + +} //end anonymous namespace + +INITIALIZE_PASS(RegionPrinter, "dot-regions", +                "Print regions of function to 'dot' file", true, true) + +INITIALIZE_PASS( + 
   RegionOnlyPrinter, "dot-regions-only", +    "Print regions of function to 'dot' file (with no function bodies)", true, +    true) + +INITIALIZE_PASS(RegionViewer, "view-regions", "View regions of function", +                true, true) + +INITIALIZE_PASS(RegionOnlyViewer, "view-regions-only", +                "View regions of function (with no function bodies)", +                true, true) + +FunctionPass *llvm::createRegionPrinterPass() { return new RegionPrinter(); } + +FunctionPass *llvm::createRegionOnlyPrinterPass() { +  return new RegionOnlyPrinter(); +} + +FunctionPass* llvm::createRegionViewerPass() { +  return new RegionViewer(); +} + +FunctionPass* llvm::createRegionOnlyViewerPass() { +  return new RegionOnlyViewer(); +} + +#ifndef NDEBUG +static void viewRegionInfo(RegionInfo *RI, bool ShortNames) { +  assert(RI && "Argument must be non-null"); + +  llvm::Function *F = RI->getTopLevelRegion()->getEntry()->getParent(); +  std::string GraphName = DOTGraphTraits<RegionInfo *>::getGraphName(RI); + +  llvm::ViewGraph(RI, "reg", ShortNames, +                  Twine(GraphName) + " for '" + F->getName() + "' function"); +} + +static void invokeFunctionPass(const Function *F, FunctionPass *ViewerPass) { +  assert(F && "Argument must be non-null"); +  assert(!F->isDeclaration() && "Function must have an implementation"); + +  // The viewer and analysis passes do not modify anything, so we can safely +  // remove the const qualifier +  auto NonConstF = const_cast<Function *>(F); + +  llvm::legacy::FunctionPassManager FPM(NonConstF->getParent()); +  FPM.add(ViewerPass); +  FPM.doInitialization(); +  FPM.run(*NonConstF); +  FPM.doFinalization(); +} + +void llvm::viewRegion(RegionInfo *RI) { viewRegionInfo(RI, false); } + +void llvm::viewRegion(const Function *F) { +  invokeFunctionPass(F, createRegionViewerPass()); +} + +void llvm::viewRegionOnly(RegionInfo *RI) { viewRegionInfo(RI, true); } + +void llvm::viewRegionOnly(const Function *F) { +  invokeFunctionPass(F, createRegionOnlyViewerPass()); +} +#endif diff --git a/contrib/llvm/lib/Analysis/ScalarEvolution.cpp b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp new file mode 100644 index 000000000000..0e715b8814ff --- /dev/null +++ b/contrib/llvm/lib/Analysis/ScalarEvolution.cpp @@ -0,0 +1,12293 @@ +//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the scalar evolution analysis +// engine, which is used primarily to analyze expressions involving induction +// variables in loops. +// +// There are several aspects to this library.  First is the representation of +// scalar expressions, which are represented as subclasses of the SCEV class. +// These classes are used to represent certain types of subexpressions that we +// can handle. We only create one SCEV of a particular shape, so +// pointer-comparisons for equality are legal. +// +// One important aspect of the SCEV objects is that they are never cyclic, even +// if there is a cycle in the dataflow for an expression (ie, a PHI node).  If +// the PHI node is one of the idioms that we can represent (e.g., a polynomial +// recurrence) then we represent it directly as a recurrence node, otherwise we +// represent it as a SCEVUnknown node. 
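To make the preceding paragraphs concrete, a small illustration (hypothetical client code, not part of this file): for a canonical counted loop such as for (i = 0; i != n; ++i), ScalarEvolution models the induction variable as the add recurrence printed as {0,+,1}<%loop>, meaning "starts at 0 and steps by 1 on each iteration of %loop", rather than as an opaque PHI. Because expressions of the same shape are uniqued, clients may compare SCEVs by pointer:

    // SE is a ScalarEvolution instance and IndVar is the loop's induction PHI.
    const llvm::SCEV *S1 = SE.getSCEV(IndVar);
    const llvm::SCEV *S2 = SE.getSCEV(IndVar);
    assert(S1 == S2 && "structurally identical SCEVs are one uniqued object");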
+// +// In addition to being able to represent expressions of various types, we also +// have folders that are used to build the *canonical* representation for a +// particular expression.  These folders are capable of using a variety of +// rewrite rules to simplify the expressions. +// +// Once the folders are defined, we can implement the more interesting +// higher-level code, such as the code that recognizes PHI nodes of various +// types, computes the execution count of a loop, etc. +// +// TODO: We should use these routines and value representations to implement +// dependence analysis! +// +//===----------------------------------------------------------------------===// +// +// There are several good references for the techniques used in this analysis. +// +//  Chains of recurrences -- a method to expedite the evaluation +//  of closed-form functions +//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima +// +//  On computational properties of chains of recurrences +//  Eugene V. Zima +// +//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization +//  Robert A. van Engelen +// +//  Efficient Symbolic Analysis for Optimizing Compilers +//  Robert A. van Engelen +// +//  Using the chains of recurrences algebra for data dependence testing and +//  induction variable substitution +//  MS Thesis, Johnie Birch +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include 
"llvm/Support/SaveAndRestore.h" +#include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <climits> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <map> +#include <memory> +#include <tuple> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "scalar-evolution" + +STATISTIC(NumArrayLenItCounts, +          "Number of trip counts computed with array length"); +STATISTIC(NumTripCountsComputed, +          "Number of loops with predictable loop counts"); +STATISTIC(NumTripCountsNotComputed, +          "Number of loops without predictable loop counts"); +STATISTIC(NumBruteForceTripCountsComputed, +          "Number of loops with trip counts computed by force"); + +static cl::opt<unsigned> +MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, +                        cl::desc("Maximum number of iterations SCEV will " +                                 "symbolically execute a constant " +                                 "derived loop"), +                        cl::init(100)); + +// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean. +static cl::opt<bool> VerifySCEV( +    "verify-scev", cl::Hidden, +    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); +static cl::opt<bool> +    VerifySCEVMap("verify-scev-maps", cl::Hidden, +                  cl::desc("Verify no dangling value in ScalarEvolution's " +                           "ExprValueMap (slow)")); + +static cl::opt<unsigned> MulOpsInlineThreshold( +    "scev-mulops-inline-threshold", cl::Hidden, +    cl::desc("Threshold for inlining multiplication operands into a SCEV"), +    cl::init(32)); + +static cl::opt<unsigned> AddOpsInlineThreshold( +    "scev-addops-inline-threshold", cl::Hidden, +    cl::desc("Threshold for inlining addition operands into a SCEV"), +    cl::init(500)); + +static cl::opt<unsigned> MaxSCEVCompareDepth( +    "scalar-evolution-max-scev-compare-depth", cl::Hidden, +    cl::desc("Maximum depth of recursive SCEV complexity comparisons"), +    cl::init(32)); + +static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth( +    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, +    cl::desc("Maximum depth of recursive SCEV operations implication analysis"), +    cl::init(2)); + +static cl::opt<unsigned> MaxValueCompareDepth( +    "scalar-evolution-max-value-compare-depth", cl::Hidden, +    cl::desc("Maximum depth of recursive value complexity comparisons"), +    cl::init(2)); + +static cl::opt<unsigned> +    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, +                  cl::desc("Maximum depth of recursive arithmetics"), +                  cl::init(32)); + +static cl::opt<unsigned> MaxConstantEvolvingDepth( +    "scalar-evolution-max-constant-evolving-depth", cl::Hidden, +    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32)); + +static cl::opt<unsigned> +    MaxExtDepth("scalar-evolution-max-ext-depth", cl::Hidden, +                cl::desc("Maximum depth of recursive SExt/ZExt"), +                cl::init(8)); + +static cl::opt<unsigned> +    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, +                  cl::desc("Max coefficients in AddRec during evolving"), +                  cl::init(16)); + +//===----------------------------------------------------------------------===// +//                           SCEV class definitions 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Implementation of the SCEV class. +// + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void SCEV::dump() const { +  print(dbgs()); +  dbgs() << '\n'; +} +#endif + +void SCEV::print(raw_ostream &OS) const { +  switch (static_cast<SCEVTypes>(getSCEVType())) { +  case scConstant: +    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false); +    return; +  case scTruncate: { +    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this); +    const SCEV *Op = Trunc->getOperand(); +    OS << "(trunc " << *Op->getType() << " " << *Op << " to " +       << *Trunc->getType() << ")"; +    return; +  } +  case scZeroExtend: { +    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this); +    const SCEV *Op = ZExt->getOperand(); +    OS << "(zext " << *Op->getType() << " " << *Op << " to " +       << *ZExt->getType() << ")"; +    return; +  } +  case scSignExtend: { +    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this); +    const SCEV *Op = SExt->getOperand(); +    OS << "(sext " << *Op->getType() << " " << *Op << " to " +       << *SExt->getType() << ")"; +    return; +  } +  case scAddRecExpr: { +    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this); +    OS << "{" << *AR->getOperand(0); +    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) +      OS << ",+," << *AR->getOperand(i); +    OS << "}<"; +    if (AR->hasNoUnsignedWrap()) +      OS << "nuw><"; +    if (AR->hasNoSignedWrap()) +      OS << "nsw><"; +    if (AR->hasNoSelfWrap() && +        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) +      OS << "nw><"; +    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false); +    OS << ">"; +    return; +  } +  case scAddExpr: +  case scMulExpr: +  case scUMaxExpr: +  case scSMaxExpr: { +    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this); +    const char *OpStr = nullptr; +    switch (NAry->getSCEVType()) { +    case scAddExpr: OpStr = " + "; break; +    case scMulExpr: OpStr = " * "; break; +    case scUMaxExpr: OpStr = " umax "; break; +    case scSMaxExpr: OpStr = " smax "; break; +    } +    OS << "("; +    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); +         I != E; ++I) { +      OS << **I; +      if (std::next(I) != E) +        OS << OpStr; +    } +    OS << ")"; +    switch (NAry->getSCEVType()) { +    case scAddExpr: +    case scMulExpr: +      if (NAry->hasNoUnsignedWrap()) +        OS << "<nuw>"; +      if (NAry->hasNoSignedWrap()) +        OS << "<nsw>"; +    } +    return; +  } +  case scUDivExpr: { +    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this); +    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; +    return; +  } +  case scUnknown: { +    const SCEVUnknown *U = cast<SCEVUnknown>(this); +    Type *AllocTy; +    if (U->isSizeOf(AllocTy)) { +      OS << "sizeof(" << *AllocTy << ")"; +      return; +    } +    if (U->isAlignOf(AllocTy)) { +      OS << "alignof(" << *AllocTy << ")"; +      return; +    } + +    Type *CTy; +    Constant *FieldNo; +    if (U->isOffsetOf(CTy, FieldNo)) { +      OS << "offsetof(" << *CTy << ", "; +      FieldNo->printAsOperand(OS, false); +      OS << ")"; +      return; +    } + +    // Otherwise just print it normally. 
+    U->getValue()->printAsOperand(OS, false); +    return; +  } +  case scCouldNotCompute: +    OS << "***COULDNOTCOMPUTE***"; +    return; +  } +  llvm_unreachable("Unknown SCEV kind!"); +} + +Type *SCEV::getType() const { +  switch (static_cast<SCEVTypes>(getSCEVType())) { +  case scConstant: +    return cast<SCEVConstant>(this)->getType(); +  case scTruncate: +  case scZeroExtend: +  case scSignExtend: +    return cast<SCEVCastExpr>(this)->getType(); +  case scAddRecExpr: +  case scMulExpr: +  case scUMaxExpr: +  case scSMaxExpr: +    return cast<SCEVNAryExpr>(this)->getType(); +  case scAddExpr: +    return cast<SCEVAddExpr>(this)->getType(); +  case scUDivExpr: +    return cast<SCEVUDivExpr>(this)->getType(); +  case scUnknown: +    return cast<SCEVUnknown>(this)->getType(); +  case scCouldNotCompute: +    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); +  } +  llvm_unreachable("Unknown SCEV kind!"); +} + +bool SCEV::isZero() const { +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) +    return SC->getValue()->isZero(); +  return false; +} + +bool SCEV::isOne() const { +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) +    return SC->getValue()->isOne(); +  return false; +} + +bool SCEV::isAllOnesValue() const { +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this)) +    return SC->getValue()->isMinusOne(); +  return false; +} + +bool SCEV::isNonConstantNegative() const { +  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this); +  if (!Mul) return false; + +  // If there is a constant factor, it will be first. +  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0)); +  if (!SC) return false; + +  // Return true if the value is negative, this matches things like (-42 * V). +  return SC->getAPInt().isNegative(); +} + +SCEVCouldNotCompute::SCEVCouldNotCompute() : +  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {} + +bool SCEVCouldNotCompute::classof(const SCEV *S) { +  return S->getSCEVType() == scCouldNotCompute; +} + +const SCEV *ScalarEvolution::getConstant(ConstantInt *V) { +  FoldingSetNodeID ID; +  ID.AddInteger(scConstant); +  ID.AddPointer(V); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V); +  UniqueSCEVs.InsertNode(S, IP); +  return S; +} + +const SCEV *ScalarEvolution::getConstant(const APInt &Val) { +  return getConstant(ConstantInt::get(getContext(), Val)); +} + +const SCEV * +ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { +  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty)); +  return getConstant(ConstantInt::get(ITy, V, isSigned)); +} + +SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, +                           unsigned SCEVTy, const SCEV *op, Type *ty) +  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {} + +SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, +                                   const SCEV *op, Type *ty) +  : SCEVCastExpr(ID, scTruncate, op, ty) { +  assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot truncate non-integer value!"); +} + +SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, +                                       const SCEV *op, Type *ty) +  : SCEVCastExpr(ID, scZeroExtend, op, ty) { +  assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot zero extend non-integer value!"); +} + +SCEVSignExtendExpr::SCEVSignExtendExpr(const 
FoldingSetNodeIDRef ID, +                                       const SCEV *op, Type *ty) +  : SCEVCastExpr(ID, scSignExtend, op, ty) { +  assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot sign extend non-integer value!"); +} + +void SCEVUnknown::deleted() { +  // Clear this SCEVUnknown from various maps. +  SE->forgetMemoizedResults(this); + +  // Remove this SCEVUnknown from the uniquing map. +  SE->UniqueSCEVs.RemoveNode(this); + +  // Release the value. +  setValPtr(nullptr); +} + +void SCEVUnknown::allUsesReplacedWith(Value *New) { +  // Remove this SCEVUnknown from the uniquing map. +  SE->UniqueSCEVs.RemoveNode(this); + +  // Update this SCEVUnknown to point to the new value. This is needed +  // because there may still be outstanding SCEVs which still point to +  // this SCEVUnknown. +  setValPtr(New); +} + +bool SCEVUnknown::isSizeOf(Type *&AllocTy) const { +  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) +    if (VCE->getOpcode() == Instruction::PtrToInt) +      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) +        if (CE->getOpcode() == Instruction::GetElementPtr && +            CE->getOperand(0)->isNullValue() && +            CE->getNumOperands() == 2) +          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1))) +            if (CI->isOne()) { +              AllocTy = cast<PointerType>(CE->getOperand(0)->getType()) +                                 ->getElementType(); +              return true; +            } + +  return false; +} + +bool SCEVUnknown::isAlignOf(Type *&AllocTy) const { +  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) +    if (VCE->getOpcode() == Instruction::PtrToInt) +      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) +        if (CE->getOpcode() == Instruction::GetElementPtr && +            CE->getOperand(0)->isNullValue()) { +          Type *Ty = +            cast<PointerType>(CE->getOperand(0)->getType())->getElementType(); +          if (StructType *STy = dyn_cast<StructType>(Ty)) +            if (!STy->isPacked() && +                CE->getNumOperands() == 3 && +                CE->getOperand(1)->isNullValue()) { +              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2))) +                if (CI->isOne() && +                    STy->getNumElements() == 2 && +                    STy->getElementType(0)->isIntegerTy(1)) { +                  AllocTy = STy->getElementType(1); +                  return true; +                } +            } +        } + +  return false; +} + +bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { +  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue())) +    if (VCE->getOpcode() == Instruction::PtrToInt) +      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0))) +        if (CE->getOpcode() == Instruction::GetElementPtr && +            CE->getNumOperands() == 3 && +            CE->getOperand(0)->isNullValue() && +            CE->getOperand(1)->isNullValue()) { +          Type *Ty = +            cast<PointerType>(CE->getOperand(0)->getType())->getElementType(); +          // Ignore vector types here so that ScalarEvolutionExpander doesn't +          // emit getelementptrs that index into vectors. 
+          if (Ty->isStructTy() || Ty->isArrayTy()) { +            CTy = Ty; +            FieldNo = CE->getOperand(2); +            return true; +          } +        } + +  return false; +} + +//===----------------------------------------------------------------------===// +//                               SCEV Utilities +//===----------------------------------------------------------------------===// + +/// Compare the two values \p LV and \p RV in terms of their "complexity" where +/// "complexity" is a partial (and somewhat ad-hoc) relation used to order +/// operands in SCEV expressions.  \p EqCache is a set of pairs of values that +/// have been previously deemed to be "equally complex" by this routine.  It is +/// intended to avoid exponential time complexity in cases like: +/// +///   %a = f(%x, %y) +///   %b = f(%a, %a) +///   %c = f(%b, %b) +/// +///   %d = f(%x, %y) +///   %e = f(%d, %d) +///   %f = f(%e, %e) +/// +///   CompareValueComplexity(%f, %c) +/// +/// Since we do not continue running this routine on expression trees once we +/// have seen unequal values, there is no need to track them in the cache. +static int +CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue, +                       const LoopInfo *const LI, Value *LV, Value *RV, +                       unsigned Depth) { +  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV)) +    return 0; + +  // Order pointer values after integer values. This helps SCEVExpander form +  // GEPs. +  bool LIsPointer = LV->getType()->isPointerTy(), +       RIsPointer = RV->getType()->isPointerTy(); +  if (LIsPointer != RIsPointer) +    return (int)LIsPointer - (int)RIsPointer; + +  // Compare getValueID values. +  unsigned LID = LV->getValueID(), RID = RV->getValueID(); +  if (LID != RID) +    return (int)LID - (int)RID; + +  // Sort arguments by their position. +  if (const auto *LA = dyn_cast<Argument>(LV)) { +    const auto *RA = cast<Argument>(RV); +    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); +    return (int)LArgNo - (int)RArgNo; +  } + +  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) { +    const auto *RGV = cast<GlobalValue>(RV); + +    const auto IsGVNameSemantic = [&](const GlobalValue *GV) { +      auto LT = GV->getLinkage(); +      return !(GlobalValue::isPrivateLinkage(LT) || +               GlobalValue::isInternalLinkage(LT)); +    }; + +    // Use the names to distinguish the two values, but only if the +    // names are semantically important. +    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV)) +      return LGV->getName().compare(RGV->getName()); +  } + +  // For instructions, compare their loop depth, and their operand count.  This +  // is pretty loose. +  if (const auto *LInst = dyn_cast<Instruction>(LV)) { +    const auto *RInst = cast<Instruction>(RV); + +    // Compare loop depths. +    const BasicBlock *LParent = LInst->getParent(), +                     *RParent = RInst->getParent(); +    if (LParent != RParent) { +      unsigned LDepth = LI->getLoopDepth(LParent), +               RDepth = LI->getLoopDepth(RParent); +      if (LDepth != RDepth) +        return (int)LDepth - (int)RDepth; +    } + +    // Compare the number of operands. 
+    unsigned LNumOps = LInst->getNumOperands(), +             RNumOps = RInst->getNumOperands(); +    if (LNumOps != RNumOps) +      return (int)LNumOps - (int)RNumOps; + +    for (unsigned Idx : seq(0u, LNumOps)) { +      int Result = +          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx), +                                 RInst->getOperand(Idx), Depth + 1); +      if (Result != 0) +        return Result; +    } +  } + +  EqCacheValue.unionSets(LV, RV); +  return 0; +} + +// Return negative, zero, or positive, if LHS is less than, equal to, or greater +// than RHS, respectively. A three-way result allows recursive comparisons to be +// more efficient. +static int CompareSCEVComplexity( +    EquivalenceClasses<const SCEV *> &EqCacheSCEV, +    EquivalenceClasses<const Value *> &EqCacheValue, +    const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, +    DominatorTree &DT, unsigned Depth = 0) { +  // Fast-path: SCEVs are uniqued so we can do a quick equality check. +  if (LHS == RHS) +    return 0; + +  // Primarily, sort the SCEVs by their getSCEVType(). +  unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); +  if (LType != RType) +    return (int)LType - (int)RType; + +  if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS)) +    return 0; +  // Aside from the getSCEVType() ordering, the particular ordering +  // isn't very important except that it's beneficial to be consistent, +  // so that (a + b) and (b + a) don't end up as different expressions. +  switch (static_cast<SCEVTypes>(LType)) { +  case scUnknown: { +    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); +    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); + +    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(), +                                   RU->getValue(), Depth + 1); +    if (X == 0) +      EqCacheSCEV.unionSets(LHS, RHS); +    return X; +  } + +  case scConstant: { +    const SCEVConstant *LC = cast<SCEVConstant>(LHS); +    const SCEVConstant *RC = cast<SCEVConstant>(RHS); + +    // Compare constant values. +    const APInt &LA = LC->getAPInt(); +    const APInt &RA = RC->getAPInt(); +    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); +    if (LBitWidth != RBitWidth) +      return (int)LBitWidth - (int)RBitWidth; +    return LA.ult(RA) ? -1 : 1; +  } + +  case scAddRecExpr: { +    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); +    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); + +    // There is always a dominance between two recs that are used by one SCEV, +    // so we can safely sort recs by loop header dominance. We require such +    // order in getAddExpr. +    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); +    if (LLoop != RLoop) { +      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader(); +      assert(LHead != RHead && "Two loops share the same header?"); +      if (DT.dominates(LHead, RHead)) +        return 1; +      else +        assert(DT.dominates(RHead, LHead) && +               "No dominance between recurrences used by one SCEV?"); +      return -1; +    } + +    // Addrec complexity grows with operand count. +    unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); +    if (LNumOps != RNumOps) +      return (int)LNumOps - (int)RNumOps; + +    // Compare NoWrap flags. +    if (LA->getNoWrapFlags() != RA->getNoWrapFlags()) +      return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags(); + +    // Lexicographically compare. 
+    for (unsigned i = 0; i != LNumOps; ++i) { +      int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, +                                    LA->getOperand(i), RA->getOperand(i), DT, +                                    Depth + 1); +      if (X != 0) +        return X; +    } +    EqCacheSCEV.unionSets(LHS, RHS); +    return 0; +  } + +  case scAddExpr: +  case scMulExpr: +  case scSMaxExpr: +  case scUMaxExpr: { +    const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS); +    const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS); + +    // Lexicographically compare n-ary expressions. +    unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); +    if (LNumOps != RNumOps) +      return (int)LNumOps - (int)RNumOps; + +    // Compare NoWrap flags. +    if (LC->getNoWrapFlags() != RC->getNoWrapFlags()) +      return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags(); + +    for (unsigned i = 0; i != LNumOps; ++i) { +      int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, +                                    LC->getOperand(i), RC->getOperand(i), DT, +                                    Depth + 1); +      if (X != 0) +        return X; +    } +    EqCacheSCEV.unionSets(LHS, RHS); +    return 0; +  } + +  case scUDivExpr: { +    const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS); +    const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); + +    // Lexicographically compare udiv expressions. +    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(), +                                  RC->getLHS(), DT, Depth + 1); +    if (X != 0) +      return X; +    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(), +                              RC->getRHS(), DT, Depth + 1); +    if (X == 0) +      EqCacheSCEV.unionSets(LHS, RHS); +    return X; +  } + +  case scTruncate: +  case scZeroExtend: +  case scSignExtend: { +    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS); +    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS); + +    // Compare cast expressions by operand. +    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, +                                  LC->getOperand(), RC->getOperand(), DT, +                                  Depth + 1); +    if (X == 0) +      EqCacheSCEV.unionSets(LHS, RHS); +    return X; +  } + +  case scCouldNotCompute: +    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); +  } +  llvm_unreachable("Unknown SCEV kind!"); +} + +/// Given a list of SCEV objects, order them by their complexity, and group +/// objects of the same complexity together by value.  When this routine is +/// finished, we know that any duplicates in the vector are consecutive and that +/// complexity is monotonically increasing. +/// +/// Note that we go take special precautions to ensure that we get deterministic +/// results from this routine.  In other words, we don't want the results of +/// this to depend on where the addresses of various SCEV objects happened to +/// land in memory. +static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, +                              LoopInfo *LI, DominatorTree &DT) { +  if (Ops.size() < 2) return;  // Noop + +  EquivalenceClasses<const SCEV *> EqCacheSCEV; +  EquivalenceClasses<const Value *> EqCacheValue; +  if (Ops.size() == 2) { +    // This is the common case, which also happens to be trivially simple. +    // Special case it. 
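The grouping contract documented above can be exercised in isolation. A minimal sketch, assuming only the standard library; Expr and complexityKey are hypothetical stand-ins for SCEV and getSCEVType(), and the two phases mirror GroupByComplexity's stable sort plus duplicate sweep:

#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical expression node and its address-independent "complexity" key.
struct Expr { unsigned Kind; int Payload; };
static unsigned complexityKey(const Expr *E) { return E->Kind; }

static void groupByComplexity(std::vector<const Expr *> &Ops) {
  if (Ops.size() < 2)
    return;
  // Phase 1: stable sort on the key only, so the result never depends on
  // where the nodes happen to be allocated.
  std::stable_sort(Ops.begin(), Ops.end(),
                   [](const Expr *L, const Expr *R) {
                     return complexityKey(L) < complexityKey(R);
                   });
  // Phase 2: within each run of equal keys, pull identical pointers together
  // so duplicates end up consecutive.
  for (unsigned i = 0, e = Ops.size(); i + 1 < e; ++i) {
    const Expr *S = Ops[i];
    unsigned Key = complexityKey(S);
    for (unsigned j = i + 1; j != e && complexityKey(Ops[j]) == Key; ++j)
      if (Ops[j] == S) {
        std::swap(Ops[i + 1], Ops[j]);
        ++i; // the duplicate now sits right after its twin; skip it
      }
  }
}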
+    const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; +    if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0) +      std::swap(LHS, RHS); +    return; +  } + +  // Do the rough sort by complexity. +  std::stable_sort(Ops.begin(), Ops.end(), +                   [&](const SCEV *LHS, const SCEV *RHS) { +                     return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, +                                                  LHS, RHS, DT) < 0; +                   }); + +  // Now that we are sorted by complexity, group elements of the same +  // complexity.  Note that this is, at worst, N^2, but the vector is likely to +  // be extremely short in practice.  Note that we take this approach because we +  // do not want to depend on the addresses of the objects we are grouping. +  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) { +    const SCEV *S = Ops[i]; +    unsigned Complexity = S->getSCEVType(); + +    // If there are any objects of the same complexity and same value as this +    // one, group them. +    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) { +      if (Ops[j] == S) { // Found a duplicate. +        // Move it to immediately after i'th element. +        std::swap(Ops[i+1], Ops[j]); +        ++i;   // no need to rescan it. +        if (i == e-2) return;  // Done! +      } +    } +  } +} + +// Returns the size of the SCEV S. +static inline int sizeOfSCEV(const SCEV *S) { +  struct FindSCEVSize { +    int Size = 0; + +    FindSCEVSize() = default; + +    bool follow(const SCEV *S) { +      ++Size; +      // Keep looking at all operands of S. +      return true; +    } + +    bool isDone() const { +      return false; +    } +  }; + +  FindSCEVSize F; +  SCEVTraversal<FindSCEVSize> ST(F); +  ST.visitAll(S); +  return F.Size; +} + +namespace { + +struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> { +public: +  // Computes the Quotient and Remainder of the division of Numerator by +  // Denominator. +  static void divide(ScalarEvolution &SE, const SCEV *Numerator, +                     const SCEV *Denominator, const SCEV **Quotient, +                     const SCEV **Remainder) { +    assert(Numerator && Denominator && "Uninitialized SCEV"); + +    SCEVDivision D(SE, Numerator, Denominator); + +    // Check for the trivial case here to avoid having to check for it in the +    // rest of the code. +    if (Numerator == Denominator) { +      *Quotient = D.One; +      *Remainder = D.Zero; +      return; +    } + +    if (Numerator->isZero()) { +      *Quotient = D.Zero; +      *Remainder = D.Zero; +      return; +    } + +    // A simple case when N/1. The quotient is N. +    if (Denominator->isOne()) { +      *Quotient = Numerator; +      *Remainder = D.Zero; +      return; +    } + +    // Split the Denominator when it is a product. +    if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { +      const SCEV *Q, *R; +      *Quotient = Numerator; +      for (const SCEV *Op : T->operands()) { +        divide(SE, *Quotient, Op, &Q, &R); +        *Quotient = Q; + +        // Bail out when the Numerator is not divisible by one of the terms of +        // the Denominator. 
+        if (!R->isZero()) { +          *Quotient = D.Zero; +          *Remainder = Numerator; +          return; +        } +      } +      *Remainder = D.Zero; +      return; +    } + +    D.visit(Numerator); +    *Quotient = D.Quotient; +    *Remainder = D.Remainder; +  } + +  // Except in the trivial case described above, we do not know how to divide +  // Expr by Denominator for the following functions with empty implementation. +  void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} +  void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} +  void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} +  void visitUDivExpr(const SCEVUDivExpr *Numerator) {} +  void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} +  void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} +  void visitUnknown(const SCEVUnknown *Numerator) {} +  void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + +  void visitConstant(const SCEVConstant *Numerator) { +    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { +      APInt NumeratorVal = Numerator->getAPInt(); +      APInt DenominatorVal = D->getAPInt(); +      uint32_t NumeratorBW = NumeratorVal.getBitWidth(); +      uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + +      if (NumeratorBW > DenominatorBW) +        DenominatorVal = DenominatorVal.sext(NumeratorBW); +      else if (NumeratorBW < DenominatorBW) +        NumeratorVal = NumeratorVal.sext(DenominatorBW); + +      APInt QuotientVal(NumeratorVal.getBitWidth(), 0); +      APInt RemainderVal(NumeratorVal.getBitWidth(), 0); +      APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); +      Quotient = SE.getConstant(QuotientVal); +      Remainder = SE.getConstant(RemainderVal); +      return; +    } +  } + +  void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { +    const SCEV *StartQ, *StartR, *StepQ, *StepR; +    if (!Numerator->isAffine()) +      return cannotDivide(Numerator); +    divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); +    divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); +    // Bail out if the types do not match. +    Type *Ty = Denominator->getType(); +    if (Ty != StartQ->getType() || Ty != StartR->getType() || +        Ty != StepQ->getType() || Ty != StepR->getType()) +      return cannotDivide(Numerator); +    Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), +                                Numerator->getNoWrapFlags()); +    Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), +                                 Numerator->getNoWrapFlags()); +  } + +  void visitAddExpr(const SCEVAddExpr *Numerator) { +    SmallVector<const SCEV *, 2> Qs, Rs; +    Type *Ty = Denominator->getType(); + +    for (const SCEV *Op : Numerator->operands()) { +      const SCEV *Q, *R; +      divide(SE, Op, Denominator, &Q, &R); + +      // Bail out if types do not match. 
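As an aside on the constant case handled by visitConstant above, the width unification followed by APInt::sdivrem can be checked on its own. A minimal sketch assuming only llvm/ADT/APInt.h; the operand values and widths are arbitrary illustrations:

#include "llvm/ADT/APInt.h"
#include <cassert>
#include <cstdint>

// Divide a 64-bit -7 by a 32-bit 3 the way visitConstant does: sign-extend
// the narrower operand to the common width, then take the signed quotient
// and remainder in one call.
static void divideConstantsLikeSCEVDivision() {
  llvm::APInt Num(64, static_cast<uint64_t>(-7), /*isSigned=*/true);
  llvm::APInt Den(32, 3, /*isSigned=*/true);
  if (Num.getBitWidth() > Den.getBitWidth())
    Den = Den.sext(Num.getBitWidth());
  else if (Num.getBitWidth() < Den.getBitWidth())
    Num = Num.sext(Den.getBitWidth());
  llvm::APInt Quot(Num.getBitWidth(), 0), Rem(Num.getBitWidth(), 0);
  llvm::APInt::sdivrem(Num, Den, Quot, Rem);
  // Signed division truncates toward zero: -7 == 3 * (-2) + (-1).
  assert(Quot.getSExtValue() == -2 && Rem.getSExtValue() == -1);
}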
+      if (Ty != Q->getType() || Ty != R->getType()) +        return cannotDivide(Numerator); + +      Qs.push_back(Q); +      Rs.push_back(R); +    } + +    if (Qs.size() == 1) { +      Quotient = Qs[0]; +      Remainder = Rs[0]; +      return; +    } + +    Quotient = SE.getAddExpr(Qs); +    Remainder = SE.getAddExpr(Rs); +  } + +  void visitMulExpr(const SCEVMulExpr *Numerator) { +    SmallVector<const SCEV *, 2> Qs; +    Type *Ty = Denominator->getType(); + +    bool FoundDenominatorTerm = false; +    for (const SCEV *Op : Numerator->operands()) { +      // Bail out if types do not match. +      if (Ty != Op->getType()) +        return cannotDivide(Numerator); + +      if (FoundDenominatorTerm) { +        Qs.push_back(Op); +        continue; +      } + +      // Check whether Denominator divides one of the product operands. +      const SCEV *Q, *R; +      divide(SE, Op, Denominator, &Q, &R); +      if (!R->isZero()) { +        Qs.push_back(Op); +        continue; +      } + +      // Bail out if types do not match. +      if (Ty != Q->getType()) +        return cannotDivide(Numerator); + +      FoundDenominatorTerm = true; +      Qs.push_back(Q); +    } + +    if (FoundDenominatorTerm) { +      Remainder = Zero; +      if (Qs.size() == 1) +        Quotient = Qs[0]; +      else +        Quotient = SE.getMulExpr(Qs); +      return; +    } + +    if (!isa<SCEVUnknown>(Denominator)) +      return cannotDivide(Numerator); + +    // The Remainder is obtained by replacing Denominator by 0 in Numerator. +    ValueToValueMap RewriteMap; +    RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = +        cast<SCEVConstant>(Zero)->getValue(); +    Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + +    if (Remainder->isZero()) { +      // The Quotient is obtained by replacing Denominator by 1 in Numerator. +      RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = +          cast<SCEVConstant>(One)->getValue(); +      Quotient = +          SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); +      return; +    } + +    // Quotient is (Numerator - Remainder) divided by Denominator. +    const SCEV *Q, *R; +    const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); +    // This SCEV does not seem to simplify: fail the division here. +    if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) +      return cannotDivide(Numerator); +    divide(SE, Diff, Denominator, &Q, &R); +    if (R != Zero) +      return cannotDivide(Numerator); +    Quotient = Q; +  } + +private: +  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, +               const SCEV *Denominator) +      : SE(S), Denominator(Denominator) { +    Zero = SE.getZero(Denominator->getType()); +    One = SE.getOne(Denominator->getType()); + +    // We generally do not know how to divide Expr by Denominator. We +    // initialize the division to a "cannot divide" state to simplify the rest +    // of the code. +    cannotDivide(Numerator); +  } + +  // Convenience function for giving up on the division. We set the quotient to +  // be equal to zero and the remainder to be equal to the numerator. 
+  void cannotDivide(const SCEV *Numerator) { +    Quotient = Zero; +    Remainder = Numerator; +  } + +  ScalarEvolution &SE; +  const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +//                      Simple SCEV method implementations +//===----------------------------------------------------------------------===// + +/// Compute BC(It, K).  The result has width W.  Assume K > 0. +static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, +                                       ScalarEvolution &SE, +                                       Type *ResultTy) { +  // Handle the simplest case efficiently. +  if (K == 1) +    return SE.getTruncateOrZeroExtend(It, ResultTy); + +  // We are using the following formula for BC(It, K): +  // +  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K! +  // +  // Suppose W is the bitwidth of the return value.  We must be prepared for +  // overflow.  Hence, we must ensure that the result of our computation is +  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't +  // safe in modular arithmetic. +  // +  // However, this code doesn't use exactly that formula; the formula it uses +  // is something like the following, where T is the number of factors of 2 in +  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is +  // exponentiation: +  // +  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T) +  // +  // This formula is trivially equivalent to the previous formula.  However, +  // this formula can be implemented much more efficiently.  The trick is that +  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular +  // arithmetic.  To do exact division in modular arithmetic, all we have +  // to do is multiply by the inverse.  Therefore, this step can be done at +  // width W. +  // +  // The next issue is how to safely do the division by 2^T.  The way this +  // is done is by doing the multiplication step at a width of at least W + T +  // bits.  This way, the bottom W+T bits of the product are accurate. Then, +  // when we perform the division by 2^T (which is equivalent to a right shift +  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get +  // truncated out after the division by 2^T. +  // +  // In comparison to just directly using the first formula, this technique +  // is much more efficient; using the first formula requires W * K bits, +  // but this formula requires less than W + K bits. Also, the first formula +  // requires a division step, whereas this formula only requires multiplies +  // and shifts. +  // +  // It doesn't matter whether the subtraction step is done in the calculation +  // width or the input iteration count's width; if the subtraction overflows, +  // the result must be zero anyway.  We prefer here to do it in the width of +  // the induction variable because it helps a lot for certain cases; CodeGen +  // isn't smart enough to ignore the overflow, which leads to much less +  // efficient code if the width of the subtraction is wider than the native +  // register width. +  // +  // (It's possible to not widen at all by pulling out factors of 2 before +  // the multiplication; for example, K=2 can be calculated as +  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires +  // extra arithmetic, so it's not an obvious win, and it gets +  // much more complicated for K > 3.)
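To make the modular-arithmetic argument above concrete, here is a minimal standalone sketch at a fixed width W = 64 that uses plain unsigned integers instead of SCEVs. The helper names are hypothetical; the 128-bit product assumes T < 64 (small K, in the spirit of the conservative K guard below), and the last function mirrors the evaluateAtIteration() formula that appears later in this file:

#include <cstdint>
#include <vector>

// Multiplicative inverse of an odd value modulo 2^64 by Newton-Raphson;
// each step doubles the number of correct low bits (3 -> 6 -> ... -> 96).
static uint64_t inverseMod2_64(uint64_t Odd) {
  uint64_t X = Odd; // correct to 3 bits for any odd input
  for (int i = 0; i < 5; ++i)
    X *= 2 - Odd * X;
  return X;
}

// BC(It, K) modulo 2^64 via the K! = 2^T * OddFactorial decomposition:
// compute the numerator product at width W + T, shift right by T, then do
// the exact odd division as a multiplication by the inverse at width W.
static uint64_t binomialMod2_64(uint64_t It, unsigned K) {
  uint64_t OddFactorial = 1;
  unsigned T = 0;
  for (unsigned i = 2; i <= K; ++i) {
    uint64_t Mult = i;
    while ((Mult & 1) == 0) { Mult >>= 1; ++T; }
    OddFactorial *= Mult;
  }
  unsigned __int128 Dividend = 1; // 128 bits >= W + T while T < 64
  for (unsigned i = 0; i < K; ++i)
    Dividend *= static_cast<unsigned __int128>(It - i);
  uint64_t Shifted = static_cast<uint64_t>(Dividend >> T); // divide by 2^T
  return Shifted * inverseMod2_64(OddFactorial);           // / (K! / 2^T)
}

// Evaluate a chain of recurrences {Ops[0],+,Ops[1],+,...} at iteration It,
// i.e. the sum of Ops[i] * BC(It, i), all modulo 2^64.
static uint64_t evaluateChrecAt(const std::vector<uint64_t> &Ops, uint64_t It) {
  uint64_t Result = Ops.empty() ? 0 : Ops[0];
  for (unsigned i = 1; i < Ops.size(); ++i)
    Result += Ops[i] * binomialMod2_64(It, i);
  return Result;
}

For example, evaluateChrecAt({3, 5, 4}, 6) yields 3 + 5*6 + 4*15 = 93, matching six hand-iterations of {3,+,5,+,4}.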
+ +  // Protection from insane SCEVs; this bound is conservative, +  // but it probably doesn't matter. +  if (K > 1000) +    return SE.getCouldNotCompute(); + +  unsigned W = SE.getTypeSizeInBits(ResultTy); + +  // Calculate K! / 2^T and T; we divide out the factors of two before +  // multiplying for calculating K! / 2^T to avoid overflow. +  // Other overflow doesn't matter because we only care about the bottom +  // W bits of the result. +  APInt OddFactorial(W, 1); +  unsigned T = 1; +  for (unsigned i = 3; i <= K; ++i) { +    APInt Mult(W, i); +    unsigned TwoFactors = Mult.countTrailingZeros(); +    T += TwoFactors; +    Mult.lshrInPlace(TwoFactors); +    OddFactorial *= Mult; +  } + +  // We need at least W + T bits for the multiplication step +  unsigned CalculationBits = W + T; + +  // Calculate 2^T, at width T+W. +  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T); + +  // Calculate the multiplicative inverse of K! / 2^T; +  // this multiplication factor will perform the exact division by +  // K! / 2^T. +  APInt Mod = APInt::getSignedMinValue(W+1); +  APInt MultiplyFactor = OddFactorial.zext(W+1); +  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod); +  MultiplyFactor = MultiplyFactor.trunc(W); + +  // Calculate the product, at width T+W +  IntegerType *CalculationTy = IntegerType::get(SE.getContext(), +                                                      CalculationBits); +  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy); +  for (unsigned i = 1; i != K; ++i) { +    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i)); +    Dividend = SE.getMulExpr(Dividend, +                             SE.getTruncateOrZeroExtend(S, CalculationTy)); +  } + +  // Divide by 2^T +  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor)); + +  // Truncate the result, and divide by K! / 2^T. + +  return SE.getMulExpr(SE.getConstant(MultiplyFactor), +                       SE.getTruncateOrZeroExtend(DivResult, ResultTy)); +} + +/// Return the value of this chain of recurrences at the specified iteration +/// number.  We can evaluate this recurrence by multiplying each element in the +/// chain by the binomial coefficient corresponding to it.  In other words, we +/// can evaluate {A,+,B,+,C,+,D} as: +/// +///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) +/// +/// where BC(It, k) stands for binomial coefficient. +const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, +                                                ScalarEvolution &SE) const { +  const SCEV *Result = getStart(); +  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) { +    // The computation is correct in the face of overflow provided that the +    // multiplication is performed _after_ the evaluation of the binomial +    // coefficient. 
+    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType()); +    if (isa<SCEVCouldNotCompute>(Coeff)) +      return Coeff; + +    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff)); +  } +  return Result; +} + +//===----------------------------------------------------------------------===// +//                    SCEV Expression folder implementations +//===----------------------------------------------------------------------===// + +const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, +                                             Type *Ty) { +  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && +         "This is not a truncating conversion!"); +  assert(isSCEVable(Ty) && +         "This is not a conversion to a SCEVable type!"); +  Ty = getEffectiveSCEVType(Ty); + +  FoldingSetNodeID ID; +  ID.AddInteger(scTruncate); +  ID.AddPointer(Op); +  ID.AddPointer(Ty); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; + +  // Fold if the operand is constant. +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) +    return getConstant( +      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty))); + +  // trunc(trunc(x)) --> trunc(x) +  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) +    return getTruncateExpr(ST->getOperand(), Ty); + +  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing +  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) +    return getTruncateOrSignExtend(SS->getOperand(), Ty); + +  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing +  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) +    return getTruncateOrZeroExtend(SZ->getOperand(), Ty); + +  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and +  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN), +  // if after transforming we have at most one truncate, not counting truncates +  // that replace other casts. +  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) { +    auto *CommOp = cast<SCEVCommutativeExpr>(Op); +    SmallVector<const SCEV *, 4> Operands; +    unsigned numTruncs = 0; +    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2; +         ++i) { +      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty); +      if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S)) +        numTruncs++; +      Operands.push_back(S); +    } +    if (numTruncs < 2) { +      if (isa<SCEVAddExpr>(Op)) +        return getAddExpr(Operands); +      else if (isa<SCEVMulExpr>(Op)) +        return getMulExpr(Operands); +      else +        llvm_unreachable("Unexpected SCEV type for Op."); +    } +    // Although we checked in the beginning that ID is not in the cache, it is +    // possible that during recursion and different modification ID was inserted +    // into the cache. So if we find it, just return it. +    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) +      return S; +  } + +  // If the input value is a chrec scev, truncate the chrec's operands. +  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { +    SmallVector<const SCEV *, 4> Operands; +    for (const SCEV *Op : AddRec->operands()) +      Operands.push_back(getTruncateExpr(Op, Ty)); +    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); +  } + +  // The cast wasn't folded; create an explicit cast node. 
We can reuse +  // the existing insert position since if we get here, we won't have +  // made any changes which would invalidate it. +  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), +                                                 Op, Ty); +  UniqueSCEVs.InsertNode(S, IP); +  addToLoopUseLists(S); +  return S; +} + +// Get the limit of a recurrence such that incrementing by Step cannot cause +// signed overflow as long as the value of the recurrence within the +// loop does not exceed this limit before incrementing. +static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step, +                                                 ICmpInst::Predicate *Pred, +                                                 ScalarEvolution *SE) { +  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); +  if (SE->isKnownPositive(Step)) { +    *Pred = ICmpInst::ICMP_SLT; +    return SE->getConstant(APInt::getSignedMinValue(BitWidth) - +                           SE->getSignedRangeMax(Step)); +  } +  if (SE->isKnownNegative(Step)) { +    *Pred = ICmpInst::ICMP_SGT; +    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) - +                           SE->getSignedRangeMin(Step)); +  } +  return nullptr; +} + +// Get the limit of a recurrence such that incrementing by Step cannot cause +// unsigned overflow as long as the value of the recurrence within the loop does +// not exceed this limit before incrementing. +static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step, +                                                   ICmpInst::Predicate *Pred, +                                                   ScalarEvolution *SE) { +  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType()); +  *Pred = ICmpInst::ICMP_ULT; + +  return SE->getConstant(APInt::getMinValue(BitWidth) - +                         SE->getUnsignedRangeMax(Step)); +} + +namespace { + +struct ExtendOpTraitsBase { +  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *, +                                                          unsigned); +}; + +// Used to make code generic over signed and unsigned overflow. 
+template <typename ExtendOp> struct ExtendOpTraits { +  // Members present: +  // +  // static const SCEV::NoWrapFlags WrapType; +  // +  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr; +  // +  // static const SCEV *getOverflowLimitForStep(const SCEV *Step, +  //                                           ICmpInst::Predicate *Pred, +  //                                           ScalarEvolution *SE); +}; + +template <> +struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase { +  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW; + +  static const GetExtendExprTy GetExtendExpr; + +  static const SCEV *getOverflowLimitForStep(const SCEV *Step, +                                             ICmpInst::Predicate *Pred, +                                             ScalarEvolution *SE) { +    return getSignedOverflowLimitForStep(Step, Pred, SE); +  } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< +    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr; + +template <> +struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase { +  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW; + +  static const GetExtendExprTy GetExtendExpr; + +  static const SCEV *getOverflowLimitForStep(const SCEV *Step, +                                             ICmpInst::Predicate *Pred, +                                             ScalarEvolution *SE) { +    return getUnsignedOverflowLimitForStep(Step, Pred, SE); +  } +}; + +const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< +    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; + +} // end anonymous namespace + +// The recurrence AR has been shown to have no signed/unsigned wrap or something +// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as +// easily prove NSW/NUW for its preincrement or postincrement sibling. This +// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step + +// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the +// expression "Step + sext/zext(PreIncAR)" is congruent with +// "sext/zext(PostIncAR)" +template <typename ExtendOpTy> +static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, +                                        ScalarEvolution *SE, unsigned Depth) { +  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; +  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; + +  const Loop *L = AR->getLoop(); +  const SCEV *Start = AR->getStart(); +  const SCEV *Step = AR->getStepRecurrence(*SE); + +  // Check for a simple looking step prior to loop entry. +  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start); +  if (!SA) +    return nullptr; + +  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV +  // subtraction is expensive. For this purpose, perform a quick and dirty +  // difference, by checking for Step in the operand list. +  SmallVector<const SCEV *, 4> DiffOps; +  for (const SCEV *Op : SA->operands()) +    if (Op != Step) +      DiffOps.push_back(Op); + +  if (DiffOps.size() == SA->getNumOperands()) +    return nullptr; + +  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` + +  // `Step`: + +  // 1. NSW/NUW flags on the step increment. 
+  auto PreStartFlags = +    ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); +  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); +  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>( +      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); + +  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies +  // "S+X does not sign/unsign-overflow". +  // + +  const SCEV *BECount = SE->getBackedgeTakenCount(L); +  if (PreAR && PreAR->getNoWrapFlags(WrapType) && +      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount)) +    return PreStart; + +  // 2. Direct overflow check on the step operation's expression. +  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType()); +  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2); +  const SCEV *OperandExtendedStart = +      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth), +                     (SE->*GetExtendExpr)(Step, WideTy, Depth)); +  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) { +    if (PreAR && AR->getNoWrapFlags(WrapType)) { +      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW +      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then +      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact. +      const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType); +    } +    return PreStart; +  } + +  // 3. Loop precondition. +  ICmpInst::Predicate Pred; +  const SCEV *OverflowLimit = +      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE); + +  if (OverflowLimit && +      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) +    return PreStart; + +  return nullptr; +} + +// Get the normalized zero or sign extended expression for this AddRec's Start. +template <typename ExtendOpTy> +static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, +                                        ScalarEvolution *SE, +                                        unsigned Depth) { +  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr; + +  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth); +  if (!PreStart) +    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth); + +  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty, +                                             Depth), +                        (SE->*GetExtendExpr)(PreStart, Ty, Depth)); +} + +// Try to prove away overflow by looking at "nearby" add recurrences.  A +// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it +// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`. +// +// Formally: +// +//     {S,+,X} == {S-T,+,X} + T +//  => Ext({S,+,X}) == Ext({S-T,+,X} + T) +// +// If ({S-T,+,X} + T) does not overflow  ... (1) +// +//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T) +// +// If {S-T,+,X} does not overflow  ... (2) +// +//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T) +//      == {Ext(S-T)+Ext(T),+,Ext(X)} +// +// If (S-T)+T does not overflow  ... (3) +// +//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)} +//      == {Ext(S),+,Ext(X)} == LHS +// +// Thus, if (1), (2) and (3) are true for some T, then +//   Ext({S,+,X}) == {Ext(S),+,Ext(X)} +// +// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T) +// does not overflow" restricted to the 0th iteration.  
Therefore we only need +// to check for (1) and (2). +// +// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T +// is `Delta` (defined below). +template <typename ExtendOpTy> +bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, +                                                const SCEV *Step, +                                                const Loop *L) { +  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType; + +  // We restrict `Start` to a constant to prevent SCEV from spending too much +  // time here.  It is correct (but more expensive) to continue with a +  // non-constant `Start` and do a general SCEV subtraction to compute +  // `PreStart` below. +  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start); +  if (!StartC) +    return false; + +  APInt StartAI = StartC->getAPInt(); + +  for (unsigned Delta : {-2, -1, 1, 2}) { +    const SCEV *PreStart = getConstant(StartAI - Delta); + +    FoldingSetNodeID ID; +    ID.AddInteger(scAddRecExpr); +    ID.AddPointer(PreStart); +    ID.AddPointer(Step); +    ID.AddPointer(L); +    void *IP = nullptr; +    const auto *PreAR = +      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + +    // Give up if we don't already have the add recurrence we need because +    // actually constructing an add recurrence is relatively expensive. +    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2) +      const SCEV *DeltaS = getConstant(StartC->getType(), Delta); +      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; +      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep( +          DeltaS, &Pred, this); +      if (Limit && isKnownPredicate(Pred, PreAR, Limit))  // proves (1) +        return true; +    } +  } + +  return false; +} + +// Finds an integer D for an expression (C + x + y + ...) such that the top +// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or +// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is +// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and +// the (C + x + y + ...) expression is \p WholeAddExpr. +static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, +                                            const SCEVConstant *ConstantTerm, +                                            const SCEVAddExpr *WholeAddExpr) { +  const APInt C = ConstantTerm->getAPInt(); +  const unsigned BitWidth = C.getBitWidth(); +  // Find number of trailing zeros of (x + y + ...) w/o the C first: +  uint32_t TZ = BitWidth; +  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I) +    TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I))); +  if (TZ) { +    // Set D to be as many least significant bits of C as possible while still +    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap: +    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C; +  } +  return APInt(BitWidth, 0); +} + +// Finds an integer D for an affine AddRec expression {C,+,x} such that the top +// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the +// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p +// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count. 
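A small numeric check of why the D chosen above is safe to split off; TZ = 3 and the 8-bit width are arbitrary illustrations, and the overload that follows applies the same idea to the constant start of an affine AddRec:

#include <cassert>

// If every non-constant operand is a multiple of 2^TZ and D is just the low
// TZ bits of the constant C, then adding D back onto (C - D + x + y + ...)
// can never carry out of the low TZ bits, so the top-level addition wraps
// neither signed nor unsigned.  Exhaustive check at 8 bits, TZ = 3.
static void checkNoWrapFromLowBits() {
  const unsigned TZ = 3;
  for (unsigned M = 0; M < 256; M += (1u << TZ)) // any value with TZ zero low bits
    for (unsigned D = 0; D < (1u << TZ); ++D)    // any D below 2^TZ
      assert(M + D <= 0xFF && (M + D) >> TZ == M >> TZ && "no carry, no wrap");
}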
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, +                                            const APInt &ConstantStart, +                                            const SCEV *Step) { +  const unsigned BitWidth = ConstantStart.getBitWidth(); +  const uint32_t TZ = SE.GetMinTrailingZeros(Step); +  if (TZ) +    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth) +                         : ConstantStart; +  return APInt(BitWidth, 0); +} + +const SCEV * +ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { +  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && +         "This is not an extending conversion!"); +  assert(isSCEVable(Ty) && +         "This is not a conversion to a SCEVable type!"); +  Ty = getEffectiveSCEVType(Ty); + +  // Fold if the operand is constant. +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) +    return getConstant( +      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty))); + +  // zext(zext(x)) --> zext(x) +  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) +    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); + +  // Before doing any expensive analysis, check to see if we've already +  // computed a SCEV for this Op and Ty. +  FoldingSetNodeID ID; +  ID.AddInteger(scZeroExtend); +  ID.AddPointer(Op); +  ID.AddPointer(Ty); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  if (Depth > MaxExtDepth) { +    SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), +                                                     Op, Ty); +    UniqueSCEVs.InsertNode(S, IP); +    addToLoopUseLists(S); +    return S; +  } + +  // zext(trunc(x)) --> zext(x) or x or trunc(x) +  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { +    // It's possible the bits taken off by the truncate were all zero bits. If +    // so, we should be able to simplify this further. +    const SCEV *X = ST->getOperand(); +    ConstantRange CR = getUnsignedRange(X); +    unsigned TruncBits = getTypeSizeInBits(ST->getType()); +    unsigned NewBits = getTypeSizeInBits(Ty); +    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains( +            CR.zextOrTrunc(NewBits))) +      return getTruncateOrZeroExtend(X, Ty); +  } + +  // If the input value is a chrec scev, and we can prove that the value +  // did not overflow the old, smaller, value, we can zero extend all of the +  // operands (often constants).  This allows analysis of something like +  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; } +  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) +    if (AR->isAffine()) { +      const SCEV *Start = AR->getStart(); +      const SCEV *Step = AR->getStepRecurrence(*this); +      unsigned BitWidth = getTypeSizeInBits(AR->getType()); +      const Loop *L = AR->getLoop(); + +      if (!AR->hasNoUnsignedWrap()) { +        auto NewFlags = proveNoWrapViaConstantRanges(AR); +        const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); +      } + +      // If we have special knowledge that this addrec won't overflow, +      // we don't need to do any further analysis. +      if (AR->hasNoUnsignedWrap()) +        return getAddRecExpr( +            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), +            getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + +      // Check whether the backedge-taken count is SCEVCouldNotCompute. 
+      // Note that this serves two purposes: It filters out loops that are +      // simply not analyzable, and it covers the case where this code is +      // being called from within backedge-taken count analysis, such that +      // attempting to ask for the backedge-taken count would likely result +      // in infinite recursion. In the later case, the analysis code will +      // cope with a conservative value, and it will take care to purge +      // that value once it has finished. +      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L); +      if (!isa<SCEVCouldNotCompute>(MaxBECount)) { +        // Manually compute the final value for AR, checking for +        // overflow. + +        // Check whether the backedge-taken count can be losslessly casted to +        // the addrec's type. The count is always unsigned. +        const SCEV *CastedMaxBECount = +          getTruncateOrZeroExtend(MaxBECount, Start->getType()); +        const SCEV *RecastedMaxBECount = +          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType()); +        if (MaxBECount == RecastedMaxBECount) { +          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2); +          // Check whether Start+Step*MaxBECount has no unsigned overflow. +          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step, +                                        SCEV::FlagAnyWrap, Depth + 1); +          const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul, +                                                          SCEV::FlagAnyWrap, +                                                          Depth + 1), +                                               WideTy, Depth + 1); +          const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1); +          const SCEV *WideMaxBECount = +            getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1); +          const SCEV *OperandExtendedAdd = +            getAddExpr(WideStart, +                       getMulExpr(WideMaxBECount, +                                  getZeroExtendExpr(Step, WideTy, Depth + 1), +                                  SCEV::FlagAnyWrap, Depth + 1), +                       SCEV::FlagAnyWrap, Depth + 1); +          if (ZAdd == OperandExtendedAdd) { +            // Cache knowledge of AR NUW, which is propagated to this AddRec. +            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); +            // Return the expression with the addrec on the outside. +            return getAddRecExpr( +                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, +                                                         Depth + 1), +                getZeroExtendExpr(Step, Ty, Depth + 1), L, +                AR->getNoWrapFlags()); +          } +          // Similar to above, only this time treat the step value as signed. +          // This covers loops that count down. +          OperandExtendedAdd = +            getAddExpr(WideStart, +                       getMulExpr(WideMaxBECount, +                                  getSignExtendExpr(Step, WideTy, Depth + 1), +                                  SCEV::FlagAnyWrap, Depth + 1), +                       SCEV::FlagAnyWrap, Depth + 1); +          if (ZAdd == OperandExtendedAdd) { +            // Cache knowledge of AR NW, which is propagated to this AddRec. +            // Negative step causes unsigned wrap, but it still can't self-wrap. 
+            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); +            // Return the expression with the addrec on the outside. +            return getAddRecExpr( +                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, +                                                         Depth + 1), +                getSignExtendExpr(Step, Ty, Depth + 1), L, +                AR->getNoWrapFlags()); +          } +        } +      } + +      // Normally, in the cases we can prove no-overflow via a +      // backedge guarding condition, we can also compute a backedge +      // taken count for the loop.  The exceptions are assumptions and +      // guards present in the loop -- SCEV is not great at exploiting +      // these to compute max backedge taken counts, but can still use +      // these to prove lack of overflow.  Use this fact to avoid +      // doing extra work that may not pay off. +      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || +          !AC.assumptions().empty()) { +        // If the backedge is guarded by a comparison with the pre-inc +        // value the addrec is safe. Also, if the entry is guarded by +        // a comparison with the start value and the backedge is +        // guarded by a comparison with the post-inc value, the addrec +        // is safe. +        if (isKnownPositive(Step)) { +          const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - +                                      getUnsignedRangeMax(Step)); +          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || +              isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) { +            // Cache knowledge of AR NUW, which is propagated to this +            // AddRec. +            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); +            // Return the expression with the addrec on the outside. +            return getAddRecExpr( +                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, +                                                         Depth + 1), +                getZeroExtendExpr(Step, Ty, Depth + 1), L, +                AR->getNoWrapFlags()); +          } +        } else if (isKnownNegative(Step)) { +          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - +                                      getSignedRangeMin(Step)); +          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || +              isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) { +            // Cache knowledge of AR NW, which is propagated to this +            // AddRec.  Negative step causes unsigned wrap, but it +            // still can't self-wrap. +            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); +            // Return the expression with the addrec on the outside. 
+            return getAddRecExpr( +                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, +                                                         Depth + 1), +                getSignExtendExpr(Step, Ty, Depth + 1), L, +                AR->getNoWrapFlags()); +          } +        } +      } + +      // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw> +      // if D + (C - D + Step * n) could be proven to not unsigned wrap +      // where D maximizes the number of trailing zeros of (C - D + Step * n) +      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) { +        const APInt &C = SC->getAPInt(); +        const APInt &D = extractConstantWithoutWrapping(*this, C, Step); +        if (D != 0) { +          const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); +          const SCEV *SResidual = +              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); +          const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); +          return getAddExpr(SZExtD, SZExtR, +                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), +                            Depth + 1); +        } +      } + +      if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { +        const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); +        return getAddRecExpr( +            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), +            getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); +      } +    } + +  // zext(A % B) --> zext(A) % zext(B) +  { +    const SCEV *LHS; +    const SCEV *RHS; +    if (matchURem(Op, LHS, RHS)) +      return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1), +                         getZeroExtendExpr(RHS, Ty, Depth + 1)); +  } + +  // zext(A / B) --> zext(A) / zext(B). +  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op)) +    return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1), +                       getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1)); + +  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { +    // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw> +    if (SA->hasNoUnsignedWrap()) { +      // If the addition does not unsign overflow then we can, by definition, +      // commute the zero extension with the addition operation. +      SmallVector<const SCEV *, 4> Ops; +      for (const auto *Op : SA->operands()) +        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); +      return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1); +    } + +    // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...)) +    // if D + (C - D + x + y + ...) could be proven to not unsigned wrap +    // where D maximizes the number of trailing zeros of (C - D + x + y + ...) +    // +    // Often address arithmetics contain expressions like +    // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))). +    // This transformation is useful while proving that such expressions are +    // equal or differ by a small constant amount, see LoadStoreVectorizer pass. 
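As a concrete instance of the split described in the comment above (and of the
extractConstantWithoutWrapping helper at the top of this hunk), here is a small
standalone sketch using plain 32-bit integers instead of APInt; the helper name
and the sample values are illustrative only, not part of the pass.

// D is the low-TZ-bits part of the constant C, where every value of the
// remaining sum (the step multiples) has at least TZ trailing zero bits.
#include <cassert>
#include <cstdint>

static uint32_t lowBits(uint32_t C, unsigned TZ) {
  return TZ >= 32 ? C : (C & ((1u << TZ) - 1));  // analogue of trunc + zext
}

int main() {
  // zext(5 + 4*n): the term 4*n always has at least TZ = 2 trailing zeros.
  const uint32_t C = 5;
  const unsigned TZ = 2;
  const uint32_t D = lowBits(C, TZ);        // D = 1
  assert(D == 1 && (C - D) % 4 == 0);       // C - D = 4 keeps the alignment
  // 5 + 4*n and 6 + 4*n now differ only in the extracted constants 1 and 2,
  // which is what lets a pass such as LoadStoreVectorizer relate the two
  // addresses by a small constant offset.
  return 0;
}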
+    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) { +      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); +      if (D != 0) { +        const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth); +        const SCEV *SResidual = +            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); +        const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1); +        return getAddExpr(SZExtD, SZExtR, +                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), +                          Depth + 1); +      } +    } +  } + +  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) { +    // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw> +    if (SM->hasNoUnsignedWrap()) { +      // If the multiply does not unsign overflow then we can, by definition, +      // commute the zero extension with the multiply operation. +      SmallVector<const SCEV *, 4> Ops; +      for (const auto *Op : SM->operands()) +        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1)); +      return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1); +    } + +    // zext(2^K * (trunc X to iN)) to iM -> +    // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw> +    // +    // Proof: +    // +    //     zext(2^K * (trunc X to iN)) to iM +    //   = zext((trunc X to iN) << K) to iM +    //   = zext((trunc X to i{N-K}) << K)<nuw> to iM +    //     (because shl removes the top K bits) +    //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM +    //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>. +    // +    if (SM->getNumOperands() == 2) +      if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0))) +        if (MulLHS->getAPInt().isPowerOf2()) +          if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) { +            int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) - +                               MulLHS->getAPInt().logBase2(); +            Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits); +            return getMulExpr( +                getZeroExtendExpr(MulLHS, Ty), +                getZeroExtendExpr( +                    getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty), +                SCEV::FlagNUW, Depth + 1); +          } +  } + +  // The cast wasn't folded; create an explicit cast node. +  // Recompute the insert position, as it may have been invalidated. +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), +                                                   Op, Ty); +  UniqueSCEVs.InsertNode(S, IP); +  addToLoopUseLists(S); +  return S; +} + +const SCEV * +ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { +  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && +         "This is not an extending conversion!"); +  assert(isSCEVable(Ty) && +         "This is not a conversion to a SCEVable type!"); +  Ty = getEffectiveSCEVType(Ty); + +  // Fold if the operand is constant. 
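The constant fold that follows simply re-evaluates the extension on the
constant's bit pattern. A minimal standalone illustration with fixed-width
integers (not the ConstantExpr/APInt machinery used by the code):

#include <cassert>
#include <cstdint>

int main() {
  const int8_t C = -7;                            // i8 constant, bits 0xF9
  const int32_t SExt = static_cast<int32_t>(C);   // sext i8 -7 to i32
  const uint32_t ZExt = static_cast<uint8_t>(C);  // zext i8 0xF9 to i32
  assert(SExt == -7);                             // sign bit replicated
  assert(ZExt == 0xF9u);                          // 249: high bits filled with 0
  return 0;
}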
+  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) +    return getConstant( +      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty))); + +  // sext(sext(x)) --> sext(x) +  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op)) +    return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1); + +  // sext(zext(x)) --> zext(x) +  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op)) +    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1); + +  // Before doing any expensive analysis, check to see if we've already +  // computed a SCEV for this Op and Ty. +  FoldingSetNodeID ID; +  ID.AddInteger(scSignExtend); +  ID.AddPointer(Op); +  ID.AddPointer(Ty); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  // Limit recursion depth. +  if (Depth > MaxExtDepth) { +    SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), +                                                     Op, Ty); +    UniqueSCEVs.InsertNode(S, IP); +    addToLoopUseLists(S); +    return S; +  } + +  // sext(trunc(x)) --> sext(x) or x or trunc(x) +  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { +    // It's possible the bits taken off by the truncate were all sign bits. If +    // so, we should be able to simplify this further. +    const SCEV *X = ST->getOperand(); +    ConstantRange CR = getSignedRange(X); +    unsigned TruncBits = getTypeSizeInBits(ST->getType()); +    unsigned NewBits = getTypeSizeInBits(Ty); +    if (CR.truncate(TruncBits).signExtend(NewBits).contains( +            CR.sextOrTrunc(NewBits))) +      return getTruncateOrSignExtend(X, Ty); +  } + +  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { +    // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> +    if (SA->hasNoSignedWrap()) { +      // If the addition does not sign overflow then we can, by definition, +      // commute the sign extension with the addition operation. +      SmallVector<const SCEV *, 4> Ops; +      for (const auto *Op : SA->operands()) +        Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1)); +      return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1); +    } + +    // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...)) +    // if D + (C - D + x + y + ...) could be proven to not signed wrap +    // where D maximizes the number of trailing zeros of (C - D + x + y + ...) +    // +    // For instance, this will bring two seemingly different expressions: +    //     1 + sext(5 + 20 * %x + 24 * %y)  and +    //         sext(6 + 20 * %x + 24 * %y) +    // to the same form: +    //     2 + sext(4 + 20 * %x + 24 * %y) +    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) { +      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA); +      if (D != 0) { +        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); +        const SCEV *SResidual = +            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth); +        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); +        return getAddExpr(SSExtD, SSExtR, +                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), +                          Depth + 1); +      } +    } +  } +  // If the input value is a chrec scev, and we can prove that the value +  // did not overflow the old, smaller, value, we can sign extend all of the +  // operands (often constants).  
+  // This allows analysis of something like
+  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
+    if (AR->isAffine()) {
+      const SCEV *Start = AR->getStart();
+      const SCEV *Step = AR->getStepRecurrence(*this);
+      unsigned BitWidth = getTypeSizeInBits(AR->getType());
+      const Loop *L = AR->getLoop();
+
+      if (!AR->hasNoSignedWrap()) {
+        auto NewFlags = proveNoWrapViaConstantRanges(AR);
+        const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+      }
+
+      // If we have special knowledge that this addrec won't overflow,
+      // we don't need to do any further analysis.
+      if (AR->hasNoSignedWrap())
+        return getAddRecExpr(
+            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
+            getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
+
+      // Check whether the backedge-taken count is SCEVCouldNotCompute.
+      // Note that this serves two purposes: It filters out loops that are
+      // simply not analyzable, and it covers the case where this code is
+      // being called from within backedge-taken count analysis, such that
+      // attempting to ask for the backedge-taken count would likely result
+      // in infinite recursion. In the later case, the analysis code will
+      // cope with a conservative value, and it will take care to purge
+      // that value once it has finished.
+      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
+      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
+        // Manually compute the final value for AR, checking for
+        // overflow.
+
+        // Check whether the backedge-taken count can be losslessly casted to
+        // the addrec's type. The count is always unsigned.
+        const SCEV *CastedMaxBECount =
+          getTruncateOrZeroExtend(MaxBECount, Start->getType());
+        const SCEV *RecastedMaxBECount =
+          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
+        if (MaxBECount == RecastedMaxBECount) {
+          Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
+          // Check whether Start+Step*MaxBECount has no signed overflow.
+          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
+                                        SCEV::FlagAnyWrap, Depth + 1);
+          const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
+                                                          SCEV::FlagAnyWrap,
+                                                          Depth + 1),
+                                               WideTy, Depth + 1);
+          const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
+          const SCEV *WideMaxBECount =
+            getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
+          const SCEV *OperandExtendedAdd =
+            getAddExpr(WideStart,
+                       getMulExpr(WideMaxBECount,
+                                  getSignExtendExpr(Step, WideTy, Depth + 1),
+                                  SCEV::FlagAnyWrap, Depth + 1),
+                       SCEV::FlagAnyWrap, Depth + 1);
+          if (SAdd == OperandExtendedAdd) {
+            // Cache knowledge of AR NSW, which is propagated to this AddRec.
+            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+            // Return the expression with the addrec on the outside.
+            return getAddRecExpr( +                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, +                                                         Depth + 1), +                getSignExtendExpr(Step, Ty, Depth + 1), L, +                AR->getNoWrapFlags()); +          } +          // Similar to above, only this time treat the step value as unsigned. +          // This covers loops that count up with an unsigned step. +          OperandExtendedAdd = +            getAddExpr(WideStart, +                       getMulExpr(WideMaxBECount, +                                  getZeroExtendExpr(Step, WideTy, Depth + 1), +                                  SCEV::FlagAnyWrap, Depth + 1), +                       SCEV::FlagAnyWrap, Depth + 1); +          if (SAdd == OperandExtendedAdd) { +            // If AR wraps around then +            // +            //    abs(Step) * MaxBECount > unsigned-max(AR->getType()) +            // => SAdd != OperandExtendedAdd +            // +            // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> +            // (SAdd == OperandExtendedAdd => AR is NW) + +            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); + +            // Return the expression with the addrec on the outside. +            return getAddRecExpr( +                getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, +                                                         Depth + 1), +                getZeroExtendExpr(Step, Ty, Depth + 1), L, +                AR->getNoWrapFlags()); +          } +        } +      } + +      // Normally, in the cases we can prove no-overflow via a +      // backedge guarding condition, we can also compute a backedge +      // taken count for the loop.  The exceptions are assumptions and +      // guards present in the loop -- SCEV is not great at exploiting +      // these to compute max backedge taken counts, but can still use +      // these to prove lack of overflow.  Use this fact to avoid +      // doing extra work that may not pay off. + +      if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || +          !AC.assumptions().empty()) { +        // If the backedge is guarded by a comparison with the pre-inc +        // value the addrec is safe. Also, if the entry is guarded by +        // a comparison with the start value and the backedge is +        // guarded by a comparison with the post-inc value, the addrec +        // is safe. +        ICmpInst::Predicate Pred; +        const SCEV *OverflowLimit = +            getSignedOverflowLimitForStep(Step, &Pred, this); +        if (OverflowLimit && +            (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || +             isKnownOnEveryIteration(Pred, AR, OverflowLimit))) { +          // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec. 
+          const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); +          return getAddRecExpr( +              getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), +              getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); +        } +      } + +      // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw> +      // if D + (C - D + Step * n) could be proven to not signed wrap +      // where D maximizes the number of trailing zeros of (C - D + Step * n) +      if (const auto *SC = dyn_cast<SCEVConstant>(Start)) { +        const APInt &C = SC->getAPInt(); +        const APInt &D = extractConstantWithoutWrapping(*this, C, Step); +        if (D != 0) { +          const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth); +          const SCEV *SResidual = +              getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags()); +          const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1); +          return getAddExpr(SSExtD, SSExtR, +                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW), +                            Depth + 1); +        } +      } + +      if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { +        const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW); +        return getAddRecExpr( +            getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), +            getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); +      } +    } + +  // If the input value is provably positive and we could not simplify +  // away the sext build a zext instead. +  if (isKnownNonNegative(Op)) +    return getZeroExtendExpr(Op, Ty, Depth + 1); + +  // The cast wasn't folded; create an explicit cast node. +  // Recompute the insert position, as it may have been invalidated. +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), +                                                   Op, Ty); +  UniqueSCEVs.InsertNode(S, IP); +  addToLoopUseLists(S); +  return S; +} + +/// getAnyExtendExpr - Return a SCEV for the given operand extended with +/// unspecified bits out to the given type. +const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, +                                              Type *Ty) { +  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && +         "This is not an extending conversion!"); +  assert(isSCEVable(Ty) && +         "This is not a conversion to a SCEVable type!"); +  Ty = getEffectiveSCEVType(Ty); + +  // Sign-extend negative constants. +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) +    if (SC->getAPInt().isNegative()) +      return getSignExtendExpr(Op, Ty); + +  // Peel off a truncate cast. +  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) { +    const SCEV *NewOp = T->getOperand(); +    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty)) +      return getAnyExtendExpr(NewOp, Ty); +    return getTruncateOrNoop(NewOp, Ty); +  } + +  // Next try a zext cast. If the cast is folded, use it. +  const SCEV *ZExt = getZeroExtendExpr(Op, Ty); +  if (!isa<SCEVZeroExtendExpr>(ZExt)) +    return ZExt; + +  // Next try a sext cast. If the cast is folded, use it. +  const SCEV *SExt = getSignExtendExpr(Op, Ty); +  if (!isa<SCEVSignExtendExpr>(SExt)) +    return SExt; + +  // Force the cast to be folded into the operands of an addrec. 
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) { +    SmallVector<const SCEV *, 4> Ops; +    for (const SCEV *Op : AR->operands()) +      Ops.push_back(getAnyExtendExpr(Op, Ty)); +    return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW); +  } + +  // If the expression is obviously signed, use the sext cast value. +  if (isa<SCEVSMaxExpr>(Op)) +    return SExt; + +  // Absent any other information, use the zext cast value. +  return ZExt; +} + +/// Process the given Ops list, which is a list of operands to be added under +/// the given scale, update the given map. This is a helper function for +/// getAddRecExpr. As an example of what it does, given a sequence of operands +/// that would form an add expression like this: +/// +///    m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r) +/// +/// where A and B are constants, update the map with these values: +/// +///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0) +/// +/// and add 13 + A*B*29 to AccumulatedConstant. +/// This will allow getAddRecExpr to produce this: +/// +///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B) +/// +/// This form often exposes folding opportunities that are hidden in +/// the original operand list. +/// +/// Return true iff it appears that any interesting folding opportunities +/// may be exposed. This helps getAddRecExpr short-circuit extra work in +/// the common case where no interesting opportunities are present, and +/// is also used as a check to avoid infinite recursion. +static bool +CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, +                             SmallVectorImpl<const SCEV *> &NewOps, +                             APInt &AccumulatedConstant, +                             const SCEV *const *Ops, size_t NumOperands, +                             const APInt &Scale, +                             ScalarEvolution &SE) { +  bool Interesting = false; + +  // Iterate over the add operands. They are sorted, with constants first. +  unsigned i = 0; +  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { +    ++i; +    // Pull a buried constant out to the outside. +    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero()) +      Interesting = true; +    AccumulatedConstant += Scale * C->getAPInt(); +  } + +  // Next comes everything else. We're especially interested in multiplies +  // here, but they're in the middle, so just visit the rest with one loop. +  for (; i != NumOperands; ++i) { +    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]); +    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) { +      APInt NewScale = +          Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt(); +      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) { +        // A multiplication of a constant with another add; recurse. +        const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1)); +        Interesting |= +          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, +                                       Add->op_begin(), Add->getNumOperands(), +                                       NewScale, SE); +      } else { +        // A multiplication of a constant with some other value. Update +        // the map. 
+        SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); +        const SCEV *Key = SE.getMulExpr(MulOps); +        auto Pair = M.insert({Key, NewScale}); +        if (Pair.second) { +          NewOps.push_back(Pair.first->first); +        } else { +          Pair.first->second += NewScale; +          // The map already had an entry for this value, which may indicate +          // a folding opportunity. +          Interesting = true; +        } +      } +    } else { +      // An ordinary operand. Update the map. +      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = +          M.insert({Ops[i], Scale}); +      if (Pair.second) { +        NewOps.push_back(Pair.first->first); +      } else { +        Pair.first->second += Scale; +        // The map already had an entry for this value, which may indicate +        // a folding opportunity. +        Interesting = true; +      } +    } +  } + +  return Interesting; +} + +// We're trying to construct a SCEV of type `Type' with `Ops' as operands and +// `OldFlags' as can't-wrap behavior.  Infer a more aggressive set of +// can't-overflow flags for the operation if possible. +static SCEV::NoWrapFlags +StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, +                      const SmallVectorImpl<const SCEV *> &Ops, +                      SCEV::NoWrapFlags Flags) { +  using namespace std::placeholders; + +  using OBO = OverflowingBinaryOperator; + +  bool CanAnalyze = +      Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; +  (void)CanAnalyze; +  assert(CanAnalyze && "don't call from other places!"); + +  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; +  SCEV::NoWrapFlags SignOrUnsignWrap = +      ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + +  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. +  auto IsKnownNonNegative = [&](const SCEV *S) { +    return SE->isKnownNonNegative(S); +  }; + +  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) +    Flags = +        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); + +  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + +  if (SignOrUnsignWrap != SignOrUnsignMask && +      (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 && +      isa<SCEVConstant>(Ops[0])) { + +    auto Opcode = [&] { +      switch (Type) { +      case scAddExpr: +        return Instruction::Add; +      case scMulExpr: +        return Instruction::Mul; +      default: +        llvm_unreachable("Unexpected SCEV op."); +      } +    }(); + +    const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt(); + +    // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow. +    if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { +      auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( +          Opcode, C, OBO::NoSignedWrap); +      if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) +        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); +    } + +    // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow. 
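Both region checks here reduce to simple interval arithmetic on the constant
operand. A hand-worked sketch for an i8 addition with a non-negative constant,
using plain integers rather than ConstantRange (the helper names below are
made up for illustration):

#include <cassert>

// Largest signed i8 value x such that x + C does not sign-overflow (C >= 0).
static int noSignedWrapMax(int C) { return 127 - C; }
// Largest unsigned i8 value x such that x + C does not unsigned-overflow.
static unsigned noUnsignedWrapMax(unsigned C) { return 255u - C; }

int main() {
  const int C = 10;
  assert(noSignedWrapMax(C) == 117);      // known range within [-128, 117] => nsw
  assert(noUnsignedWrapMax(10u) == 245u); // known range within [0, 245]    => nuw
  // If the other operand's known range is, say, [0, 100], both bounds hold
  // and the addition can be tagged <nsw><nuw>, mirroring the checks above
  // and below.
  return 0;
}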
+    if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
+      auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
+          Opcode, C, OBO::NoUnsignedWrap);
+      if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
+        Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
+    }
+  }
+
+  return Flags;
+}
+
+bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
+  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
+}
+
+/// Get a canonical add expression, or something simpler if possible.
+const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
+                                        SCEV::NoWrapFlags Flags,
+                                        unsigned Depth) {
+  assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
+         "only nuw or nsw allowed");
+  assert(!Ops.empty() && "Cannot get empty add!");
+  if (Ops.size() == 1) return Ops[0];
+#ifndef NDEBUG
+  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
+  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
+    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
+           "SCEVAddExpr operand types don't match!");
+#endif
+
+  // Sort by complexity, this groups all similar expression types together.
+  GroupByComplexity(Ops, &LI, DT);
+
+  Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
+
+  // If there are any constants, fold them together.
+  unsigned Idx = 0;
+  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+    ++Idx;
+    assert(Idx < Ops.size());
+    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+      // We found two constants, fold them together!
+      Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
+      if (Ops.size() == 2) return Ops[0];
+      Ops.erase(Ops.begin()+1);  // Erase the folded element
+      LHSC = cast<SCEVConstant>(Ops[0]);
+    }
+
+    // If we are left with a constant zero being added, strip it off.
+    if (LHSC->getValue()->isZero()) {
+      Ops.erase(Ops.begin());
+      --Idx;
+    }
+
+    if (Ops.size() == 1) return Ops[0];
+  }
+
+  // Limit recursion calls depth.
+  if (Depth > MaxArithDepth)
+    return getOrCreateAddExpr(Ops, Flags);
+
+  // Okay, check to see if the same value occurs in the operand list more than
+  // once.  If so, merge them together into an multiply expression.  Since we
+  // sorted the list, these values are required to be adjacent.
+  Type *Ty = Ops[0]->getType();
+  bool FoundMatch = false;
+  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
+    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
+      // Scan ahead to count how many equal operands there are.
+      unsigned Count = 2;
+      while (i+Count != e && Ops[i+Count] == Ops[i])
+        ++Count;
+      // Merge the values into a multiply.
+      const SCEV *Scale = getConstant(Ty, Count);
+      const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
+      if (Ops.size() == Count)
+        return Mul;
+      Ops[i] = Mul;
+      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
+      --i; e -= Count - 1;
+      FoundMatch = true;
+    }
+  if (FoundMatch)
+    return getAddExpr(Ops, Flags, Depth + 1);
+
+  // Check for truncates. If all the operands are truncated from the same
+  // type, see if factoring out the truncate would permit the result to be
+  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
+  // if the contents of the resulting outer trunc fold to something simple.
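The rewrite described in this comment is sound because reduction modulo 2^k
commutes with addition and multiplication; whether it is worth doing is decided
afterwards by checking that the wide sum folds to something simple. A quick
numeric spot-check with arbitrary sample values (plain integers, not SCEV):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t X = 0x12345678u, Y = 0x9ABCDEF0u;  // wide values
  const uint32_t N = 3, M = 5;                      // small multipliers
  const uint8_t Narrow = static_cast<uint8_t>(
      N * static_cast<uint8_t>(X) + M * static_cast<uint8_t>(Y));
  const uint8_t WideThenTrunc = static_cast<uint8_t>(N * X + M * Y);
  assert(Narrow == WideThenTrunc);  // n*trunc(x) + m*trunc(y) == trunc(n*x + m*y)
  return 0;
}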
+  auto FindTruncSrcType = [&]() -> Type * { +    // We're ultimately looking to fold an addrec of truncs and muls of only +    // constants and truncs, so if we find any other types of SCEV +    // as operands of the addrec then we bail and return nullptr here. +    // Otherwise, we return the type of the operand of a trunc that we find. +    if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx])) +      return T->getOperand()->getType(); +    if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { +      const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1); +      if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp)) +        return T->getOperand()->getType(); +    } +    return nullptr; +  }; +  if (auto *SrcType = FindTruncSrcType()) { +    SmallVector<const SCEV *, 8> LargeOps; +    bool Ok = true; +    // Check all the operands to see if they can be represented in the +    // source type of the truncate. +    for (unsigned i = 0, e = Ops.size(); i != e; ++i) { +      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) { +        if (T->getOperand()->getType() != SrcType) { +          Ok = false; +          break; +        } +        LargeOps.push_back(T->getOperand()); +      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) { +        LargeOps.push_back(getAnyExtendExpr(C, SrcType)); +      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) { +        SmallVector<const SCEV *, 8> LargeMulOps; +        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) { +          if (const SCEVTruncateExpr *T = +                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) { +            if (T->getOperand()->getType() != SrcType) { +              Ok = false; +              break; +            } +            LargeMulOps.push_back(T->getOperand()); +          } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) { +            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType)); +          } else { +            Ok = false; +            break; +          } +        } +        if (Ok) +          LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1)); +      } else { +        Ok = false; +        break; +      } +    } +    if (Ok) { +      // Evaluate the expression in the larger type. +      const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1); +      // If it folds to something simple, use it. Otherwise, don't. +      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) +        return getTruncateExpr(Fold, Ty); +    } +  } + +  // Skip past any other cast SCEVs. +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr) +    ++Idx; + +  // If there are add operands they would be next. +  if (Idx < Ops.size()) { +    bool DeletedAdd = false; +    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) { +      if (Ops.size() > AddOpsInlineThreshold || +          Add->getNumOperands() > AddOpsInlineThreshold) +        break; +      // If we have an add, expand the add operands onto the end of the operands +      // list. +      Ops.erase(Ops.begin()+Idx); +      Ops.append(Add->op_begin(), Add->op_end()); +      DeletedAdd = true; +    } + +    // If we deleted at least one add, we added operands to the end of the list, +    // and they are not necessarily sorted.  Recurse to resort and resimplify +    // any operands we just acquired. 
+    if (DeletedAdd) +      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +  } + +  // Skip over the add expression until we get to a multiply. +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) +    ++Idx; + +  // Check to see if there are any folding opportunities present with +  // operands multiplied by constant values. +  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) { +    uint64_t BitWidth = getTypeSizeInBits(Ty); +    DenseMap<const SCEV *, APInt> M; +    SmallVector<const SCEV *, 8> NewOps; +    APInt AccumulatedConstant(BitWidth, 0); +    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant, +                                     Ops.data(), Ops.size(), +                                     APInt(BitWidth, 1), *this)) { +      struct APIntCompare { +        bool operator()(const APInt &LHS, const APInt &RHS) const { +          return LHS.ult(RHS); +        } +      }; + +      // Some interesting folding opportunity is present, so its worthwhile to +      // re-generate the operands list. Group the operands by constant scale, +      // to avoid multiplying by the same constant scale multiple times. +      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists; +      for (const SCEV *NewOp : NewOps) +        MulOpLists[M.find(NewOp)->second].push_back(NewOp); +      // Re-generate the operands list. +      Ops.clear(); +      if (AccumulatedConstant != 0) +        Ops.push_back(getConstant(AccumulatedConstant)); +      for (auto &MulOp : MulOpLists) +        if (MulOp.first != 0) +          Ops.push_back(getMulExpr( +              getConstant(MulOp.first), +              getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1), +              SCEV::FlagAnyWrap, Depth + 1)); +      if (Ops.empty()) +        return getZero(Ty); +      if (Ops.size() == 1) +        return Ops[0]; +      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +    } +  } + +  // If we are adding something to a multiply expression, make sure the +  // something is not already an operand of the multiply.  If so, merge it into +  // the multiply. +  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) { +    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]); +    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) { +      const SCEV *MulOpSCEV = Mul->getOperand(MulOp); +      if (isa<SCEVConstant>(MulOpSCEV)) +        continue; +      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp) +        if (MulOpSCEV == Ops[AddOp]) { +          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1)) +          const SCEV *InnerMul = Mul->getOperand(MulOp == 0); +          if (Mul->getNumOperands() != 2) { +            // If the multiply has more than two operands, we must get the +            // Y*Z term. 
+            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), +                                                Mul->op_begin()+MulOp); +            MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); +            InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); +          } +          SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul}; +          const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); +          const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV, +                                            SCEV::FlagAnyWrap, Depth + 1); +          if (Ops.size() == 2) return OuterMul; +          if (AddOp < Idx) { +            Ops.erase(Ops.begin()+AddOp); +            Ops.erase(Ops.begin()+Idx-1); +          } else { +            Ops.erase(Ops.begin()+Idx); +            Ops.erase(Ops.begin()+AddOp-1); +          } +          Ops.push_back(OuterMul); +          return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +        } + +      // Check this multiply against other multiplies being added together. +      for (unsigned OtherMulIdx = Idx+1; +           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]); +           ++OtherMulIdx) { +        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]); +        // If MulOp occurs in OtherMul, we can fold the two multiplies +        // together. +        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands(); +             OMulOp != e; ++OMulOp) +          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) { +            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E)) +            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0); +            if (Mul->getNumOperands() != 2) { +              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(), +                                                  Mul->op_begin()+MulOp); +              MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end()); +              InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); +            } +            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0); +            if (OtherMul->getNumOperands() != 2) { +              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(), +                                                  OtherMul->op_begin()+OMulOp); +              MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end()); +              InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1); +            } +            SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2}; +            const SCEV *InnerMulSum = +                getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); +            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum, +                                              SCEV::FlagAnyWrap, Depth + 1); +            if (Ops.size() == 2) return OuterMul; +            Ops.erase(Ops.begin()+Idx); +            Ops.erase(Ops.begin()+OtherMulIdx-1); +            Ops.push_back(OuterMul); +            return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +          } +      } +    } +  } + +  // If there are any add recurrences in the operands list, see if any other +  // added values are loop invariant.  If so, we can fold them into the +  // recurrence. +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) +    ++Idx; + +  // Scan over all recurrences, trying to fold loop invariants into them. 
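The fold attempted in the loop below relies on a loop-invariant addend shifting
every iterate of an affine recurrence by the same amount. A tiny numeric sketch
with plain integers (sample values only, not SCEV objects):

#include <cassert>

static long affineAt(long Start, long Step, long i) { return Start + Step * i; }

int main() {
  const long Start = 7, Step = 3, LI = 100;
  for (long i = 0; i < 16; ++i)
    // LI + {Start,+,Step}  ==  {LI + Start,+,Step} at every iteration.
    assert(LI + affineAt(Start, Step, i) == affineAt(LI + Start, Step, i));
  return 0;
}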
+  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { +    // Scan all of the other operands to this add and add them to the vector if +    // they are loop invariant w.r.t. the recurrence. +    SmallVector<const SCEV *, 8> LIOps; +    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); +    const Loop *AddRecLoop = AddRec->getLoop(); +    for (unsigned i = 0, e = Ops.size(); i != e; ++i) +      if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { +        LIOps.push_back(Ops[i]); +        Ops.erase(Ops.begin()+i); +        --i; --e; +      } + +    // If we found some loop invariants, fold them into the recurrence. +    if (!LIOps.empty()) { +      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step} +      LIOps.push_back(AddRec->getStart()); + +      SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), +                                             AddRec->op_end()); +      // This follows from the fact that the no-wrap flags on the outer add +      // expression are applicable on the 0th iteration, when the add recurrence +      // will be equal to its start value. +      AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1); + +      // Build the new addrec. Propagate the NUW and NSW flags if both the +      // outer add and the inner addrec are guaranteed to have no overflow. +      // Always propagate NW. +      Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); +      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); + +      // If all of the other operands were loop invariant, we are done. +      if (Ops.size() == 1) return NewRec; + +      // Otherwise, add the folded AddRec by the non-invariant parts. +      for (unsigned i = 0;; ++i) +        if (Ops[i] == AddRec) { +          Ops[i] = NewRec; +          break; +        } +      return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +    } + +    // Okay, if there weren't any loop invariants to be folded, check to see if +    // there are multiple AddRec's with the same loop induction variable being +    // added together.  If so, we can fold them. +    for (unsigned OtherIdx = Idx+1; +         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); +         ++OtherIdx) { +      // We expect the AddRecExpr's to be sorted in reverse dominance order, +      // so that the 1st found AddRecExpr is dominated by all others. 
+      assert(DT.dominates( +           cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(), +           AddRec->getLoop()->getHeader()) && +        "AddRecExprs are not sorted in reverse dominance order?"); +      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { +        // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L> +        SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), +                                               AddRec->op_end()); +        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); +             ++OtherIdx) { +          const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); +          if (OtherAddRec->getLoop() == AddRecLoop) { +            for (unsigned i = 0, e = OtherAddRec->getNumOperands(); +                 i != e; ++i) { +              if (i >= AddRecOps.size()) { +                AddRecOps.append(OtherAddRec->op_begin()+i, +                                 OtherAddRec->op_end()); +                break; +              } +              SmallVector<const SCEV *, 2> TwoOps = { +                  AddRecOps[i], OtherAddRec->getOperand(i)}; +              AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1); +            } +            Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; +          } +        } +        // Step size has changed, so we cannot guarantee no self-wraparound. +        Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap); +        return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +      } +    } + +    // Otherwise couldn't fold anything into this recurrence.  Move onto the +    // next one. +  } + +  // Okay, it looks like we really DO need an add expr.  Check to see if we +  // already have one, otherwise create a new one. 
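getOrCreateAddExpr, defined just below, unique-ifies nodes through a FoldingSet
keyed on the opcode and the operand pointers, so structurally identical
expressions share one node. A much-simplified sketch of that hash-consing idea,
using std::map and a toy Node type rather than anything from LLVM:

#include <cassert>
#include <map>
#include <memory>
#include <vector>

struct Node { std::vector<const Node *> Ops; };

class Uniquer {
  std::map<std::vector<const Node *>, std::unique_ptr<Node>> Pool;

public:
  const Node *getOrCreateAdd(const std::vector<const Node *> &Ops) {
    auto It = Pool.find(Ops);
    if (It != Pool.end())
      return It->second.get();         // reuse the existing node
    auto N = std::make_unique<Node>();
    N->Ops = Ops;
    const Node *Raw = N.get();
    Pool.emplace(Ops, std::move(N));   // remember it for next time
    return Raw;
  }
};

int main() {
  Uniquer U;
  Node A, B;
  const Node *N1 = U.getOrCreateAdd({&A, &B});
  const Node *N2 = U.getOrCreateAdd({&A, &B});
  assert(N1 == N2);                    // same operands -> same node
  return 0;
}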
+  return getOrCreateAddExpr(Ops, Flags); +} + +const SCEV * +ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, +                                    SCEV::NoWrapFlags Flags) { +  FoldingSetNodeID ID; +  ID.AddInteger(scAddExpr); +  for (const SCEV *Op : Ops) +    ID.AddPointer(Op); +  void *IP = nullptr; +  SCEVAddExpr *S = +      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); +  if (!S) { +    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); +    std::uninitialized_copy(Ops.begin(), Ops.end(), O); +    S = new (SCEVAllocator) +        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); +    UniqueSCEVs.InsertNode(S, IP); +    addToLoopUseLists(S); +  } +  S->setNoWrapFlags(Flags); +  return S; +} + +const SCEV * +ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops, +                                    SCEV::NoWrapFlags Flags) { +  FoldingSetNodeID ID; +  ID.AddInteger(scMulExpr); +  for (unsigned i = 0, e = Ops.size(); i != e; ++i) +    ID.AddPointer(Ops[i]); +  void *IP = nullptr; +  SCEVMulExpr *S = +    static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); +  if (!S) { +    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); +    std::uninitialized_copy(Ops.begin(), Ops.end(), O); +    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), +                                        O, Ops.size()); +    UniqueSCEVs.InsertNode(S, IP); +    addToLoopUseLists(S); +  } +  S->setNoWrapFlags(Flags); +  return S; +} + +static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) { +  uint64_t k = i*j; +  if (j > 1 && k / j != i) Overflow = true; +  return k; +} + +/// Compute the result of "n choose k", the binomial coefficient.  If an +/// intermediate computation overflows, Overflow will be set and the return will +/// be garbage. Overflow is not cleared on absence of overflow. +static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { +  // We use the multiplicative formula: +  //     n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 . +  // At each iteration, we take the n-th term of the numeral and divide by the +  // (k-n)th term of the denominator.  This division will always produce an +  // integral result, and helps reduce the chance of overflow in the +  // intermediate computations. However, we can still overflow even when the +  // final result would fit. + +  if (n == 0 || n == k) return 1; +  if (k > n) return 0; + +  if (k > n/2) +    k = n-k; + +  uint64_t r = 1; +  for (uint64_t i = 1; i <= k; ++i) { +    r = umul_ov(r, n-(i-1), Overflow); +    r /= i; +  } +  return r; +} + +/// Determine if any of the operands in this SCEV are a constant or if +/// any of the add or multiply expressions in this SCEV contain a constant. +static bool containsConstantInAddMulChain(const SCEV *StartExpr) { +  struct FindConstantInAddMulChain { +    bool FoundConstant = false; + +    bool follow(const SCEV *S) { +      FoundConstant |= isa<SCEVConstant>(S); +      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S); +    } + +    bool isDone() const { +      return FoundConstant; +    } +  }; + +  FindConstantInAddMulChain F; +  SCEVTraversal<FindConstantInAddMulChain> ST(F); +  ST.visitAll(StartExpr); +  return F.FoundConstant; +} + +/// Get a canonical multiply expression, or something simpler if possible. 
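Before getMulExpr itself, a quick cross-check of the multiplicative scheme used
by the Choose() helper above: just before the division in step i the running
product equals i * C(n, i), so dividing by i is always exact. (Standalone
sketch restricted to small n; the overflow tracking done by umul_ov is omitted
here.)

#include <cassert>
#include <cstdint>

static uint64_t chooseMultiplicative(uint64_t n, uint64_t k) {
  if (k > n) return 0;
  if (k > n / 2) k = n - k;
  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = r * (n - (i - 1));
    r /= i;                 // exact: before this division r equals i * C(n, i)
  }
  return r;
}

static uint64_t choosePascal(uint64_t n, uint64_t k) {
  if (k > n) return 0;
  if (k == 0 || k == n) return 1;
  return choosePascal(n - 1, k - 1) + choosePascal(n - 1, k);
}

int main() {
  for (uint64_t n = 0; n <= 12; ++n)
    for (uint64_t k = 0; k <= n; ++k)
      assert(chooseMultiplicative(n, k) == choosePascal(n, k));
  return 0;
}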
+const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, +                                        SCEV::NoWrapFlags Flags, +                                        unsigned Depth) { +  assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && +         "only nuw or nsw allowed"); +  assert(!Ops.empty() && "Cannot get empty mul!"); +  if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG +  Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); +  for (unsigned i = 1, e = Ops.size(); i != e; ++i) +    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && +           "SCEVMulExpr operand types don't match!"); +#endif + +  // Sort by complexity, this groups all similar expression types together. +  GroupByComplexity(Ops, &LI, DT); + +  Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); + +  // Limit recursion calls depth. +  if (Depth > MaxArithDepth) +    return getOrCreateMulExpr(Ops, Flags); + +  // If there are any constants, fold them together. +  unsigned Idx = 0; +  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + +    if (Ops.size() == 2) +      // C1*(C2+V) -> C1*C2 + C1*V +      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) +        // If any of Add's ops are Adds or Muls with a constant, apply this +        // transformation as well. +        // +        // TODO: There are some cases where this transformation is not +        // profitable; for example, Add = (C0 + X) * Y + Z.  Maybe the scope of +        // this transformation should be narrowed down. +        if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) +          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0), +                                       SCEV::FlagAnyWrap, Depth + 1), +                            getMulExpr(LHSC, Add->getOperand(1), +                                       SCEV::FlagAnyWrap, Depth + 1), +                            SCEV::FlagAnyWrap, Depth + 1); + +    ++Idx; +    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { +      // We found two constants, fold them together! +      ConstantInt *Fold = +          ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt()); +      Ops[0] = getConstant(Fold); +      Ops.erase(Ops.begin()+1);  // Erase the folded element +      if (Ops.size() == 1) return Ops[0]; +      LHSC = cast<SCEVConstant>(Ops[0]); +    } + +    // If we are left with a constant one being multiplied, strip it off. +    if (cast<SCEVConstant>(Ops[0])->getValue()->isOne()) { +      Ops.erase(Ops.begin()); +      --Idx; +    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) { +      // If we have a multiply of zero, it will always be zero. +      return Ops[0]; +    } else if (Ops[0]->isAllOnesValue()) { +      // If we have a mul by -1 of an add, try distributing the -1 among the +      // add operands. 
+      if (Ops.size() == 2) { +        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) { +          SmallVector<const SCEV *, 4> NewOps; +          bool AnyFolded = false; +          for (const SCEV *AddOp : Add->operands()) { +            const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap, +                                         Depth + 1); +            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true; +            NewOps.push_back(Mul); +          } +          if (AnyFolded) +            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1); +        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) { +          // Negation preserves a recurrence's no self-wrap property. +          SmallVector<const SCEV *, 4> Operands; +          for (const SCEV *AddRecOp : AddRec->operands()) +            Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap, +                                          Depth + 1)); + +          return getAddRecExpr(Operands, AddRec->getLoop(), +                               AddRec->getNoWrapFlags(SCEV::FlagNW)); +        } +      } +    } + +    if (Ops.size() == 1) +      return Ops[0]; +  } + +  // Skip over the add expression until we get to a multiply. +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr) +    ++Idx; + +  // If there are mul operands inline them all into this expression. +  if (Idx < Ops.size()) { +    bool DeletedMul = false; +    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { +      if (Ops.size() > MulOpsInlineThreshold) +        break; +      // If we have an mul, expand the mul operands onto the end of the +      // operands list. +      Ops.erase(Ops.begin()+Idx); +      Ops.append(Mul->op_begin(), Mul->op_end()); +      DeletedMul = true; +    } + +    // If we deleted at least one mul, we added operands to the end of the +    // list, and they are not necessarily sorted.  Recurse to resort and +    // resimplify any operands we just acquired. +    if (DeletedMul) +      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +  } + +  // If there are any add recurrences in the operands list, see if any other +  // added values are loop invariant.  If so, we can fold them into the +  // recurrence. +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr) +    ++Idx; + +  // Scan over all recurrences, trying to fold loop invariants into them. +  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) { +    // Scan all of the other operands to this mul and add them to the vector +    // if they are loop invariant w.r.t. the recurrence. +    SmallVector<const SCEV *, 8> LIOps; +    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]); +    const Loop *AddRecLoop = AddRec->getLoop(); +    for (unsigned i = 0, e = Ops.size(); i != e; ++i) +      if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) { +        LIOps.push_back(Ops[i]); +        Ops.erase(Ops.begin()+i); +        --i; --e; +      } + +    // If we found some loop invariants, fold them into the recurrence. 
+    if (!LIOps.empty()) { +      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step} +      SmallVector<const SCEV *, 4> NewOps; +      NewOps.reserve(AddRec->getNumOperands()); +      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1); +      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) +        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i), +                                    SCEV::FlagAnyWrap, Depth + 1)); + +      // Build the new addrec. Propagate the NUW and NSW flags if both the +      // outer mul and the inner addrec are guaranteed to have no overflow. +      // +      // No self-wrap cannot be guaranteed after changing the step size, but +      // will be inferred if either NUW or NSW is true. +      Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW)); +      const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags); + +      // If all of the other operands were loop invariant, we are done. +      if (Ops.size() == 1) return NewRec; + +      // Otherwise, multiply the folded AddRec by the non-invariant parts. +      for (unsigned i = 0;; ++i) +        if (Ops[i] == AddRec) { +          Ops[i] = NewRec; +          break; +        } +      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); +    } + +    // Okay, if there weren't any loop invariants to be folded, check to see +    // if there are multiple AddRec's with the same loop induction variable +    // being multiplied together.  If so, we can fold them. + +    // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> +    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ +    //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z +    //   ]]],+,...up to x=2n}. +    // Note that the arguments to choose() are always integers with values +    // known at compile time, never SCEV objects. +    // +    // The implementation avoids pointless extra computations when the two +    // addrec's are of different length (mathematically, it's equivalent to +    // an infinite stream of zeros on the right). +    bool OpsModified = false; +    for (unsigned OtherIdx = Idx+1; +         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); +         ++OtherIdx) { +      const SCEVAddRecExpr *OtherAddRec = +        dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); +      if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) +        continue; + +      // Limit max number of arguments to avoid creation of unreasonably big +      // SCEVAddRecs with very complex operands. 
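For the simplest case of two affine recurrences on the same loop, the binomial
formula quoted above reduces to
  {a0,+,a1} * {b0,+,b1}  =  {a0*b0,+, a0*b1 + a1*b0 + a1*b1,+, 2*a1*b1},
where an addrec {c0,+,c1,+,c2} evaluates to c0 + c1*C(n,1) + c2*C(n,2) at
iteration n. A short numeric check with arbitrary sample coefficients:

#include <cassert>

int main() {
  const long a0 = 3, a1 = 5, b0 = 7, b1 = 2;
  for (long n = 0; n < 20; ++n) {
    const long A = a0 + a1 * n;             // {a0,+,a1} at iteration n
    const long B = b0 + b1 * n;             // {b0,+,b1} at iteration n
    const long c0 = a0 * b0;
    const long c1 = a0 * b1 + a1 * b0 + a1 * b1;
    const long c2 = 2 * a1 * b1;
    const long Prod = c0 + c1 * n + c2 * (n * (n - 1) / 2);  // binomial basis
    assert(A * B == Prod);
  }
  return 0;
}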
+      if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 > +          MaxAddRecSize) +        continue; + +      bool Overflow = false; +      Type *Ty = AddRec->getType(); +      bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; +      SmallVector<const SCEV*, 7> AddRecOps; +      for (int x = 0, xe = AddRec->getNumOperands() + +             OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { +        const SCEV *Term = getZero(Ty); +        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { +          uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); +          for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), +                 ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); +               z < ze && !Overflow; ++z) { +            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); +            uint64_t Coeff; +            if (LargerThan64Bits) +              Coeff = umul_ov(Coeff1, Coeff2, Overflow); +            else +              Coeff = Coeff1*Coeff2; +            const SCEV *CoeffTerm = getConstant(Ty, Coeff); +            const SCEV *Term1 = AddRec->getOperand(y-z); +            const SCEV *Term2 = OtherAddRec->getOperand(z); +            Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1, Term2, +                                               SCEV::FlagAnyWrap, Depth + 1), +                              SCEV::FlagAnyWrap, Depth + 1); +          } +        } +        AddRecOps.push_back(Term); +      } +      if (!Overflow) { +        const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), +                                              SCEV::FlagAnyWrap); +        if (Ops.size() == 2) return NewAddRec; +        Ops[Idx] = NewAddRec; +        Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; +        OpsModified = true; +        AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); +        if (!AddRec) +          break; +      } +    } +    if (OpsModified) +      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1); + +    // Otherwise couldn't fold anything into this recurrence.  Move onto the +    // next one. +  } + +  // Okay, it looks like we really DO need an mul expr.  Check to see if we +  // already have one, otherwise create a new one. +  return getOrCreateMulExpr(Ops, Flags); +} + +/// Represents an unsigned remainder expression based on unsigned division. +const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS, +                                         const SCEV *RHS) { +  assert(getEffectiveSCEVType(LHS->getType()) == +         getEffectiveSCEVType(RHS->getType()) && +         "SCEVURemExpr operand types don't match!"); + +  // Short-circuit easy cases +  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { +    // If constant is one, the result is trivial +    if (RHSC->getValue()->isOne()) +      return getZero(LHS->getType()); // X urem 1 --> 0 + +    // If constant is a power of two, fold into a zext(trunc(LHS)). +    if (RHSC->getAPInt().isPowerOf2()) { +      Type *FullTy = LHS->getType(); +      Type *TruncTy = +          IntegerType::get(getContext(), RHSC->getAPInt().logBase2()); +      return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy); +    } +  } + +  // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y) +  const SCEV *UDiv = getUDivExpr(LHS, RHS); +  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW); +  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW); +} + +/// Get a canonical unsigned division expression, or something simpler if +/// possible. 
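+/// For example, assuming the no-wrap checks below succeed, {0,+,4}<%L> /u 2
+/// simplifies to {0,+,2}<%L>, and a constant divisor folds with a constant
+/// dividend directly.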
+const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, +                                         const SCEV *RHS) { +  assert(getEffectiveSCEVType(LHS->getType()) == +         getEffectiveSCEVType(RHS->getType()) && +         "SCEVUDivExpr operand types don't match!"); + +  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { +    if (RHSC->getValue()->isOne()) +      return LHS;                               // X udiv 1 --> x +    // If the denominator is zero, the result of the udiv is undefined. Don't +    // try to analyze it, because the resolution chosen here may differ from +    // the resolution chosen in other parts of the compiler. +    if (!RHSC->getValue()->isZero()) { +      // Determine if the division can be folded into the operands of +      // its operands. +      // TODO: Generalize this to non-constants by using known-bits information. +      Type *Ty = LHS->getType(); +      unsigned LZ = RHSC->getAPInt().countLeadingZeros(); +      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1; +      // For non-power-of-two values, effectively round the value up to the +      // nearest power of two. +      if (!RHSC->getAPInt().isPowerOf2()) +        ++MaxShiftAmt; +      IntegerType *ExtTy = +        IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt); +      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS)) +        if (const SCEVConstant *Step = +            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) { +          // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded. +          const APInt &StepInt = Step->getAPInt(); +          const APInt &DivInt = RHSC->getAPInt(); +          if (!StepInt.urem(DivInt) && +              getZeroExtendExpr(AR, ExtTy) == +              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), +                            getZeroExtendExpr(Step, ExtTy), +                            AR->getLoop(), SCEV::FlagAnyWrap)) { +            SmallVector<const SCEV *, 4> Operands; +            for (const SCEV *Op : AR->operands()) +              Operands.push_back(getUDivExpr(Op, RHS)); +            return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW); +          } +          /// Get a canonical UDivExpr for a recurrence. +          /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0. +          // We can currently only fold X%N if X is constant. +          const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart()); +          if (StartC && !DivInt.urem(StepInt) && +              getZeroExtendExpr(AR, ExtTy) == +              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy), +                            getZeroExtendExpr(Step, ExtTy), +                            AR->getLoop(), SCEV::FlagAnyWrap)) { +            const APInt &StartInt = StartC->getAPInt(); +            const APInt &StartRem = StartInt.urem(StepInt); +            if (StartRem != 0) +              LHS = getAddRecExpr(getConstant(StartInt - StartRem), Step, +                                  AR->getLoop(), SCEV::FlagNW); +          } +        } +      // (A*B)/C --> A*(B/C) if safe and B/C can be folded. +      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) { +        SmallVector<const SCEV *, 4> Operands; +        for (const SCEV *Op : M->operands()) +          Operands.push_back(getZeroExtendExpr(Op, ExtTy)); +        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) +          // Find an operand that's safely divisible. 
+          for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) { +            const SCEV *Op = M->getOperand(i); +            const SCEV *Div = getUDivExpr(Op, RHSC); +            if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) { +              Operands = SmallVector<const SCEV *, 4>(M->op_begin(), +                                                      M->op_end()); +              Operands[i] = Div; +              return getMulExpr(Operands); +            } +          } +      } + +      // (A/B)/C --> A/(B*C) if safe and B*C can be folded. +      if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) { +        if (auto *DivisorConstant = +                dyn_cast<SCEVConstant>(OtherDiv->getRHS())) { +          bool Overflow = false; +          APInt NewRHS = +              DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow); +          if (Overflow) { +            return getConstant(RHSC->getType(), 0, false); +          } +          return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS)); +        } +      } + +      // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded. +      if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) { +        SmallVector<const SCEV *, 4> Operands; +        for (const SCEV *Op : A->operands()) +          Operands.push_back(getZeroExtendExpr(Op, ExtTy)); +        if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) { +          Operands.clear(); +          for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) { +            const SCEV *Op = getUDivExpr(A->getOperand(i), RHS); +            if (isa<SCEVUDivExpr>(Op) || +                getMulExpr(Op, RHS) != A->getOperand(i)) +              break; +            Operands.push_back(Op); +          } +          if (Operands.size() == A->getNumOperands()) +            return getAddExpr(Operands); +        } +      } + +      // Fold if both operands are constant. +      if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { +        Constant *LHSCV = LHSC->getValue(); +        Constant *RHSCV = RHSC->getValue(); +        return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV, +                                                                   RHSCV))); +      } +    } +  } + +  FoldingSetNodeID ID; +  ID.AddInteger(scUDivExpr); +  ID.AddPointer(LHS); +  ID.AddPointer(RHS); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), +                                             LHS, RHS); +  UniqueSCEVs.InsertNode(S, IP); +  addToLoopUseLists(S); +  return S; +} + +static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { +  APInt A = C1->getAPInt().abs(); +  APInt B = C2->getAPInt().abs(); +  uint32_t ABW = A.getBitWidth(); +  uint32_t BBW = B.getBitWidth(); + +  if (ABW > BBW) +    B = B.zext(ABW); +  else if (ABW < BBW) +    A = A.zext(BBW); + +  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B)); +} + +/// Get a canonical unsigned division expression, or something simpler if +/// possible. There is no representation for an exact udiv in SCEV IR, but we +/// can attempt to remove factors from the LHS and RHS.  We can't do this when +/// it's not exact because the udiv may be clearing bits. 
+const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, +                                              const SCEV *RHS) { +  // TODO: we could try to find factors in all sorts of things, but for now we +  // just deal with u/exact (multiply, constant). See SCEVDivision towards the +  // end of this file for inspiration. + +  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS); +  if (!Mul || !Mul->hasNoUnsignedWrap()) +    return getUDivExpr(LHS, RHS); + +  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) { +    // If the mulexpr multiplies by a constant, then that constant must be the +    // first element of the mulexpr. +    if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) { +      if (LHSCst == RHSCst) { +        SmallVector<const SCEV *, 2> Operands; +        Operands.append(Mul->op_begin() + 1, Mul->op_end()); +        return getMulExpr(Operands); +      } + +      // We can't just assume that LHSCst divides RHSCst cleanly, it could be +      // that there's a factor provided by one of the other terms. We need to +      // check. +      APInt Factor = gcd(LHSCst, RHSCst); +      if (!Factor.isIntN(1)) { +        LHSCst = +            cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor))); +        RHSCst = +            cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor))); +        SmallVector<const SCEV *, 2> Operands; +        Operands.push_back(LHSCst); +        Operands.append(Mul->op_begin() + 1, Mul->op_end()); +        LHS = getMulExpr(Operands); +        RHS = RHSCst; +        Mul = dyn_cast<SCEVMulExpr>(LHS); +        if (!Mul) +          return getUDivExactExpr(LHS, RHS); +      } +    } +  } + +  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) { +    if (Mul->getOperand(i) == RHS) { +      SmallVector<const SCEV *, 2> Operands; +      Operands.append(Mul->op_begin(), Mul->op_begin() + i); +      Operands.append(Mul->op_begin() + i + 1, Mul->op_end()); +      return getMulExpr(Operands); +    } +  } + +  return getUDivExpr(LHS, RHS); +} + +/// Get an add recurrence expression for the specified loop.  Simplify the +/// expression as much as possible. +const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, +                                           const Loop *L, +                                           SCEV::NoWrapFlags Flags) { +  SmallVector<const SCEV *, 4> Operands; +  Operands.push_back(Start); +  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step)) +    if (StepChrec->getLoop() == L) { +      Operands.append(StepChrec->op_begin(), StepChrec->op_end()); +      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW)); +    } + +  Operands.push_back(Step); +  return getAddRecExpr(Operands, L, Flags); +} + +/// Get an add recurrence expression for the specified loop.  Simplify the +/// expression as much as possible. 
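+/// For example, {X,+,0}<%L> simplifies to X, and a recurrence whose start is
+/// itself an addrec for an outer loop may be re-nested in order of loop
+/// depth.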
+const SCEV * +ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, +                               const Loop *L, SCEV::NoWrapFlags Flags) { +  if (Operands.size() == 1) return Operands[0]; +#ifndef NDEBUG +  Type *ETy = getEffectiveSCEVType(Operands[0]->getType()); +  for (unsigned i = 1, e = Operands.size(); i != e; ++i) +    assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy && +           "SCEVAddRecExpr operand types don't match!"); +  for (unsigned i = 0, e = Operands.size(); i != e; ++i) +    assert(isLoopInvariant(Operands[i], L) && +           "SCEVAddRecExpr operand is not loop-invariant!"); +#endif + +  if (Operands.back()->isZero()) { +    Operands.pop_back(); +    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X +  } + +  // It's tempting to want to call getMaxBackedgeTakenCount count here and +  // use that information to infer NUW and NSW flags. However, computing a +  // BE count requires calling getAddRecExpr, so we may not yet have a +  // meaningful BE count at this point (and if we don't, we'd be stuck +  // with a SCEVCouldNotCompute as the cached BE count). + +  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); + +  // Canonicalize nested AddRecs in by nesting them in order of loop depth. +  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) { +    const Loop *NestedLoop = NestedAR->getLoop(); +    if (L->contains(NestedLoop) +            ? (L->getLoopDepth() < NestedLoop->getLoopDepth()) +            : (!NestedLoop->contains(L) && +               DT.dominates(L->getHeader(), NestedLoop->getHeader()))) { +      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(), +                                                  NestedAR->op_end()); +      Operands[0] = NestedAR->getStart(); +      // AddRecs require their operands be loop-invariant with respect to their +      // loops. Don't perform this transformation if it would break this +      // requirement. +      bool AllInvariant = all_of( +          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); }); + +      if (AllInvariant) { +        // Create a recurrence for the outer loop with the same step size. +        // +        // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the +        // inner recurrence has the same property. +        SCEV::NoWrapFlags OuterFlags = +          maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); + +        NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); +        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { +          return isLoopInvariant(Op, NestedLoop); +        }); + +        if (AllInvariant) { +          // Ok, both add recurrences are valid after the transformation. +          // +          // The inner recurrence keeps its NW flag but only keeps NUW/NSW if +          // the outer recurrence has the same property. +          SCEV::NoWrapFlags InnerFlags = +            maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); +          return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); +        } +      } +      // Reset Operands to its original state. +      Operands[0] = NestedAR; +    } +  } + +  // Okay, it looks like we really DO need an addrec expr.  Check to see if we +  // already have one, otherwise create a new one. 
+  FoldingSetNodeID ID; +  ID.AddInteger(scAddRecExpr); +  for (unsigned i = 0, e = Operands.size(); i != e; ++i) +    ID.AddPointer(Operands[i]); +  ID.AddPointer(L); +  void *IP = nullptr; +  SCEVAddRecExpr *S = +    static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); +  if (!S) { +    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size()); +    std::uninitialized_copy(Operands.begin(), Operands.end(), O); +    S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), +                                           O, Operands.size(), L); +    UniqueSCEVs.InsertNode(S, IP); +    addToLoopUseLists(S); +  } +  S->setNoWrapFlags(Flags); +  return S; +} + +const SCEV * +ScalarEvolution::getGEPExpr(GEPOperator *GEP, +                            const SmallVectorImpl<const SCEV *> &IndexExprs) { +  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand()); +  // getSCEV(Base)->getType() has the same address space as Base->getType() +  // because SCEV::getType() preserves the address space. +  Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); +  // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP +  // instruction to its SCEV, because the Instruction may be guarded by control +  // flow and the no-overflow bits may not be valid for the expression in any +  // context. This can be fixed similarly to how these flags are handled for +  // adds. +  SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW +                                             : SCEV::FlagAnyWrap; + +  const SCEV *TotalOffset = getZero(IntPtrTy); +  // The array size is unimportant. The first thing we do on CurTy is getting +  // its element type. +  Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0); +  for (const SCEV *IndexExpr : IndexExprs) { +    // Compute the (potentially symbolic) offset in bytes for this index. +    if (StructType *STy = dyn_cast<StructType>(CurTy)) { +      // For a struct, add the member offset. +      ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue(); +      unsigned FieldNo = Index->getZExtValue(); +      const SCEV *FieldOffset = getOffsetOfExpr(IntPtrTy, STy, FieldNo); + +      // Add the field offset to the running total offset. +      TotalOffset = getAddExpr(TotalOffset, FieldOffset); + +      // Update CurTy to the type of the field at Index. +      CurTy = STy->getTypeAtIndex(Index); +    } else { +      // Update CurTy to its element type. +      CurTy = cast<SequentialType>(CurTy)->getElementType(); +      // For an array, add the element offset, explicitly scaled. +      const SCEV *ElementSize = getSizeOfExpr(IntPtrTy, CurTy); +      // Getelementptr indices are signed. +      IndexExpr = getTruncateOrSignExtend(IndexExpr, IntPtrTy); + +      // Multiply the index by the element size to compute the element offset. +      const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap); + +      // Add the element offset to the running total offset. +      TotalOffset = getAddExpr(TotalOffset, LocalOffset); +    } +  } + +  // Add the total offset from all the GEP indices to the base. 
+  return getAddExpr(BaseExpr, TotalOffset, Wrap); +} + +const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, +                                         const SCEV *RHS) { +  SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; +  return getSMaxExpr(Ops); +} + +const SCEV * +ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { +  assert(!Ops.empty() && "Cannot get empty smax!"); +  if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG +  Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); +  for (unsigned i = 1, e = Ops.size(); i != e; ++i) +    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && +           "SCEVSMaxExpr operand types don't match!"); +#endif + +  // Sort by complexity, this groups all similar expression types together. +  GroupByComplexity(Ops, &LI, DT); + +  // If there are any constants, fold them together. +  unsigned Idx = 0; +  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { +    ++Idx; +    assert(Idx < Ops.size()); +    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { +      // We found two constants, fold them together! +      ConstantInt *Fold = ConstantInt::get( +          getContext(), APIntOps::smax(LHSC->getAPInt(), RHSC->getAPInt())); +      Ops[0] = getConstant(Fold); +      Ops.erase(Ops.begin()+1);  // Erase the folded element +      if (Ops.size() == 1) return Ops[0]; +      LHSC = cast<SCEVConstant>(Ops[0]); +    } + +    // If we are left with a constant minimum-int, strip it off. +    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) { +      Ops.erase(Ops.begin()); +      --Idx; +    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) { +      // If we have an smax with a constant maximum-int, it will always be +      // maximum-int. +      return Ops[0]; +    } + +    if (Ops.size() == 1) return Ops[0]; +  } + +  // Find the first SMax +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr) +    ++Idx; + +  // Check to see if one of the operands is an SMax. If so, expand its operands +  // onto our operand list, and recurse to simplify. +  if (Idx < Ops.size()) { +    bool DeletedSMax = false; +    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) { +      Ops.erase(Ops.begin()+Idx); +      Ops.append(SMax->op_begin(), SMax->op_end()); +      DeletedSMax = true; +    } + +    if (DeletedSMax) +      return getSMaxExpr(Ops); +  } + +  // Okay, check to see if the same value occurs in the operand list twice.  If +  // so, delete one.  Since we sorted the list, these values are required to +  // be adjacent. +  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) +    //  X smax Y smax Y  -->  X smax Y +    //  X smax Y         -->  X, if X is always greater than Y +    if (Ops[i] == Ops[i+1] || +        isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) { +      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2); +      --i; --e; +    } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) { +      Ops.erase(Ops.begin()+i, Ops.begin()+i+1); +      --i; --e; +    } + +  if (Ops.size() == 1) return Ops[0]; + +  assert(!Ops.empty() && "Reduced smax down to nothing!"); + +  // Okay, it looks like we really DO need an smax expr.  Check to see if we +  // already have one, otherwise create a new one. 
+  FoldingSetNodeID ID; +  ID.AddInteger(scSMaxExpr); +  for (unsigned i = 0, e = Ops.size(); i != e; ++i) +    ID.AddPointer(Ops[i]); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); +  std::uninitialized_copy(Ops.begin(), Ops.end(), O); +  SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), +                                             O, Ops.size()); +  UniqueSCEVs.InsertNode(S, IP); +  addToLoopUseLists(S); +  return S; +} + +const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, +                                         const SCEV *RHS) { +  SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; +  return getUMaxExpr(Ops); +} + +const SCEV * +ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { +  assert(!Ops.empty() && "Cannot get empty umax!"); +  if (Ops.size() == 1) return Ops[0]; +#ifndef NDEBUG +  Type *ETy = getEffectiveSCEVType(Ops[0]->getType()); +  for (unsigned i = 1, e = Ops.size(); i != e; ++i) +    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy && +           "SCEVUMaxExpr operand types don't match!"); +#endif + +  // Sort by complexity, this groups all similar expression types together. +  GroupByComplexity(Ops, &LI, DT); + +  // If there are any constants, fold them together. +  unsigned Idx = 0; +  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { +    ++Idx; +    assert(Idx < Ops.size()); +    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { +      // We found two constants, fold them together! +      ConstantInt *Fold = ConstantInt::get( +          getContext(), APIntOps::umax(LHSC->getAPInt(), RHSC->getAPInt())); +      Ops[0] = getConstant(Fold); +      Ops.erase(Ops.begin()+1);  // Erase the folded element +      if (Ops.size() == 1) return Ops[0]; +      LHSC = cast<SCEVConstant>(Ops[0]); +    } + +    // If we are left with a constant minimum-int, strip it off. +    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) { +      Ops.erase(Ops.begin()); +      --Idx; +    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) { +      // If we have an umax with a constant maximum-int, it will always be +      // maximum-int. +      return Ops[0]; +    } + +    if (Ops.size() == 1) return Ops[0]; +  } + +  // Find the first UMax +  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr) +    ++Idx; + +  // Check to see if one of the operands is a UMax. If so, expand its operands +  // onto our operand list, and recurse to simplify. +  if (Idx < Ops.size()) { +    bool DeletedUMax = false; +    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) { +      Ops.erase(Ops.begin()+Idx); +      Ops.append(UMax->op_begin(), UMax->op_end()); +      DeletedUMax = true; +    } + +    if (DeletedUMax) +      return getUMaxExpr(Ops); +  } + +  // Okay, check to see if the same value occurs in the operand list twice.  If +  // so, delete one.  Since we sorted the list, these values are required to +  // be adjacent. 
+  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i) +    //  X umax Y umax Y  -->  X umax Y +    //  X umax Y         -->  X, if X is always greater than Y +    if (Ops[i] == Ops[i + 1] || isKnownViaNonRecursiveReasoning( +                                    ICmpInst::ICMP_UGE, Ops[i], Ops[i + 1])) { +      Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2); +      --i; --e; +    } else if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, Ops[i], +                                               Ops[i + 1])) { +      Ops.erase(Ops.begin() + i, Ops.begin() + i + 1); +      --i; --e; +    } + +  if (Ops.size() == 1) return Ops[0]; + +  assert(!Ops.empty() && "Reduced umax down to nothing!"); + +  // Okay, it looks like we really DO need a umax expr.  Check to see if we +  // already have one, otherwise create a new one. +  FoldingSetNodeID ID; +  ID.AddInteger(scUMaxExpr); +  for (unsigned i = 0, e = Ops.size(); i != e; ++i) +    ID.AddPointer(Ops[i]); +  void *IP = nullptr; +  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; +  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); +  std::uninitialized_copy(Ops.begin(), Ops.end(), O); +  SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), +                                             O, Ops.size()); +  UniqueSCEVs.InsertNode(S, IP); +  addToLoopUseLists(S); +  return S; +} + +const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS, +                                         const SCEV *RHS) { +  SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; +  return getSMinExpr(Ops); +} + +const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) { +  // ~smax(~x, ~y, ~z) == smin(x, y, z). +  SmallVector<const SCEV *, 2> NotOps; +  for (auto *S : Ops) +    NotOps.push_back(getNotSCEV(S)); +  return getNotSCEV(getSMaxExpr(NotOps)); +} + +const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, +                                         const SCEV *RHS) { +  SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; +  return getUMinExpr(Ops); +} + +const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) { +  assert(!Ops.empty() && "At least one operand must be!"); +  // Trivial case. +  if (Ops.size() == 1) +    return Ops[0]; + +  // ~umax(~x, ~y, ~z) == umin(x, y, z). +  SmallVector<const SCEV *, 2> NotOps; +  for (auto *S : Ops) +    NotOps.push_back(getNotSCEV(S)); +  return getNotSCEV(getUMaxExpr(NotOps)); +} + +const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { +  // We can bypass creating a target-independent +  // constant expression and then folding it back into a ConstantInt. +  // This is just a compile-time optimization. +  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); +} + +const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, +                                             StructType *STy, +                                             unsigned FieldNo) { +  // We can bypass creating a target-independent +  // constant expression and then folding it back into a ConstantInt. +  // This is just a compile-time optimization. +  return getConstant( +      IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); +} + +const SCEV *ScalarEvolution::getUnknown(Value *V) { +  // Don't attempt to do anything other than create a SCEVUnknown object +  // here.  
createSCEV only calls getUnknown after checking for all other +  // interesting possibilities, and any other code that calls getUnknown +  // is doing so in order to hide a value from SCEV canonicalization. + +  FoldingSetNodeID ID; +  ID.AddInteger(scUnknown); +  ID.AddPointer(V); +  void *IP = nullptr; +  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) { +    assert(cast<SCEVUnknown>(S)->getValue() == V && +           "Stale SCEVUnknown in uniquing map!"); +    return S; +  } +  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this, +                                            FirstUnknown); +  FirstUnknown = cast<SCEVUnknown>(S); +  UniqueSCEVs.InsertNode(S, IP); +  return S; +} + +//===----------------------------------------------------------------------===// +//            Basic SCEV Analysis and PHI Idiom Recognition Code +// + +/// Test if values of the given type are analyzable within the SCEV +/// framework. This primarily includes integer types, and it can optionally +/// include pointer types if the ScalarEvolution class has access to +/// target-specific information. +bool ScalarEvolution::isSCEVable(Type *Ty) const { +  // Integers and pointers are always SCEVable. +  return Ty->isIntOrPtrTy(); +} + +/// Return the size in bits of the specified type, for which isSCEVable must +/// return true. +uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { +  assert(isSCEVable(Ty) && "Type is not SCEVable!"); +  if (Ty->isPointerTy()) +    return getDataLayout().getIndexTypeSizeInBits(Ty); +  return getDataLayout().getTypeSizeInBits(Ty); +} + +/// Return a type with the same bitwidth as the given type and which represents +/// how SCEV will treat the given type, for which isSCEVable must return +/// true. For pointer types, this is the pointer-sized integer type. +Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { +  assert(isSCEVable(Ty) && "Type is not SCEVable!"); + +  if (Ty->isIntegerTy()) +    return Ty; + +  // The only other support type is pointer. +  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!"); +  return getDataLayout().getIntPtrType(Ty); +} + +Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const { +  return  getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2; +} + +const SCEV *ScalarEvolution::getCouldNotCompute() { +  return CouldNotCompute.get(); +} + +bool ScalarEvolution::checkValidity(const SCEV *S) const { +  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) { +    auto *SU = dyn_cast<SCEVUnknown>(S); +    return SU && SU->getValue() == nullptr; +  }); + +  return !ContainsNulls; +} + +bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { +  HasRecMapType::iterator I = HasRecMap.find(S); +  if (I != HasRecMap.end()) +    return I->second; + +  bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>); +  HasRecMap.insert({S, FoundAddRec}); +  return FoundAddRec; +} + +/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. +/// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an +/// offset I, then return {S', I}, else return {\p S, nullptr}. 
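+/// For example, (5 + %x) is split into {%x, 5}, while (%x + %y) and non-add
+/// expressions are returned unchanged with a null offset.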
+static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) { +  const auto *Add = dyn_cast<SCEVAddExpr>(S); +  if (!Add) +    return {S, nullptr}; + +  if (Add->getNumOperands() != 2) +    return {S, nullptr}; + +  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0)); +  if (!ConstOp) +    return {S, nullptr}; + +  return {Add->getOperand(1), ConstOp->getValue()}; +} + +/// Return the ValueOffsetPair set for \p S. \p S can be represented +/// by the value and offset from any ValueOffsetPair in the set. +SetVector<ScalarEvolution::ValueOffsetPair> * +ScalarEvolution::getSCEVValues(const SCEV *S) { +  ExprValueMapType::iterator SI = ExprValueMap.find_as(S); +  if (SI == ExprValueMap.end()) +    return nullptr; +#ifndef NDEBUG +  if (VerifySCEVMap) { +    // Check there is no dangling Value in the set returned. +    for (const auto &VE : SI->second) +      assert(ValueExprMap.count(VE.first)); +  } +#endif +  return &SI->second; +} + +/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) +/// cannot be used separately. eraseValueFromMap should be used to remove +/// V from ValueExprMap and ExprValueMap at the same time. +void ScalarEvolution::eraseValueFromMap(Value *V) { +  ValueExprMapType::iterator I = ValueExprMap.find_as(V); +  if (I != ValueExprMap.end()) { +    const SCEV *S = I->second; +    // Remove {V, 0} from the set of ExprValueMap[S] +    if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S)) +      SV->remove({V, nullptr}); + +    // Remove {V, Offset} from the set of ExprValueMap[Stripped] +    const SCEV *Stripped; +    ConstantInt *Offset; +    std::tie(Stripped, Offset) = splitAddExpr(S); +    if (Offset != nullptr) { +      if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped)) +        SV->remove({V, Offset}); +    } +    ValueExprMap.erase(V); +  } +} + +/// Check whether value has nuw/nsw/exact set but SCEV does not. +/// TODO: In reality it is better to check the poison recursevely +/// but this is better than nothing. +static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) { +  if (auto *I = dyn_cast<Instruction>(V)) { +    if (isa<OverflowingBinaryOperator>(I)) { +      if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) { +        if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap()) +          return true; +        if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap()) +          return true; +      } +    } else if (isa<PossiblyExactOperator>(I) && I->isExact()) +      return true; +  } +  return false; +} + +/// Return an existing SCEV if it exists, otherwise analyze the expression and +/// create a new one. +const SCEV *ScalarEvolution::getSCEV(Value *V) { +  assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + +  const SCEV *S = getExistingSCEV(V); +  if (S == nullptr) { +    S = createSCEV(V); +    // During PHI resolution, it is possible to create two SCEVs for the same +    // V, so it is needed to double check whether V->S is inserted into +    // ValueExprMap before insert S->{V, 0} into ExprValueMap. +    std::pair<ValueExprMapType::iterator, bool> Pair = +        ValueExprMap.insert({SCEVCallbackVH(V, this), S}); +    if (Pair.second && !SCEVLostPoisonFlags(S, V)) { +      ExprValueMap[S].insert({V, nullptr}); + +      // If S == Stripped + Offset, add Stripped -> {V, Offset} into +      // ExprValueMap. 
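+      // (For instance, if V computes (4 + 2 * %n), then (2 * %n) is also
+      // mapped to {V, 4}, subject to the restrictions below.)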
+      const SCEV *Stripped = S; +      ConstantInt *Offset = nullptr; +      std::tie(Stripped, Offset) = splitAddExpr(S); +      // If stripped is SCEVUnknown, don't bother to save +      // Stripped -> {V, offset}. It doesn't simplify and sometimes even +      // increase the complexity of the expansion code. +      // If V is GetElementPtrInst, don't save Stripped -> {V, offset} +      // because it may generate add/sub instead of GEP in SCEV expansion. +      if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) && +          !isa<GetElementPtrInst>(V)) +        ExprValueMap[Stripped].insert({V, Offset}); +    } +  } +  return S; +} + +const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { +  assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); + +  ValueExprMapType::iterator I = ValueExprMap.find_as(V); +  if (I != ValueExprMap.end()) { +    const SCEV *S = I->second; +    if (checkValidity(S)) +      return S; +    eraseValueFromMap(V); +    forgetMemoizedResults(S); +  } +  return nullptr; +} + +/// Return a SCEV corresponding to -V = -1*V +const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, +                                             SCEV::NoWrapFlags Flags) { +  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) +    return getConstant( +               cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue()))); + +  Type *Ty = V->getType(); +  Ty = getEffectiveSCEVType(Ty); +  return getMulExpr( +      V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags); +} + +/// Return a SCEV corresponding to ~V = -1-V +const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { +  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) +    return getConstant( +                cast<ConstantInt>(ConstantExpr::getNot(VC->getValue()))); + +  Type *Ty = V->getType(); +  Ty = getEffectiveSCEVType(Ty); +  const SCEV *AllOnes = +                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))); +  return getMinusSCEV(AllOnes, V); +} + +const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, +                                          SCEV::NoWrapFlags Flags, +                                          unsigned Depth) { +  // Fast path: X - X --> 0. +  if (LHS == RHS) +    return getZero(LHS->getType()); + +  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation +  // makes it so that we cannot make much use of NUW. +  auto AddFlags = SCEV::FlagAnyWrap; +  const bool RHSIsNotMinSigned = +      !getSignedRangeMin(RHS).isMinSignedValue(); +  if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { +    // Let M be the minimum representable signed value. Then (-1)*RHS +    // signed-wraps if and only if RHS is M. That can happen even for +    // a NSW subtraction because e.g. (-1)*M signed-wraps even though +    // -1 - M does not. So to transfer NSW from LHS - RHS to LHS + +    // (-1)*RHS, we need to prove that RHS != M. +    // +    // If LHS is non-negative and we know that LHS - RHS does not +    // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap +    // either by proving that RHS > M or that LHS >= 0. +    if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) { +      AddFlags = SCEV::FlagNSW; +    } +  } + +  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS - +  // RHS is NSW and LHS >= 0. +  // +  // The difficulty here is that the NSW flag may have been proven +  // relative to a loop that is to be found in a recurrence in LHS and +  // not in RHS. 
Applying NSW to (-1)*M may then let the NSW have a +  // larger scope than intended. +  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + +  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth); +} + +const SCEV * +ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { +  Type *SrcTy = V->getType(); +  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot truncate or zero extend with non-integer arguments!"); +  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) +    return V;  // No conversion +  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) +    return getTruncateExpr(V, Ty); +  return getZeroExtendExpr(V, Ty); +} + +const SCEV * +ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, +                                         Type *Ty) { +  Type *SrcTy = V->getType(); +  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot truncate or zero extend with non-integer arguments!"); +  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) +    return V;  // No conversion +  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty)) +    return getTruncateExpr(V, Ty); +  return getSignExtendExpr(V, Ty); +} + +const SCEV * +ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { +  Type *SrcTy = V->getType(); +  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot noop or zero extend with non-integer arguments!"); +  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && +         "getNoopOrZeroExtend cannot truncate!"); +  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) +    return V;  // No conversion +  return getZeroExtendExpr(V, Ty); +} + +const SCEV * +ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { +  Type *SrcTy = V->getType(); +  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot noop or sign extend with non-integer arguments!"); +  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && +         "getNoopOrSignExtend cannot truncate!"); +  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) +    return V;  // No conversion +  return getSignExtendExpr(V, Ty); +} + +const SCEV * +ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { +  Type *SrcTy = V->getType(); +  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot noop or any extend with non-integer arguments!"); +  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) && +         "getNoopOrAnyExtend cannot truncate!"); +  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) +    return V;  // No conversion +  return getAnyExtendExpr(V, Ty); +} + +const SCEV * +ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { +  Type *SrcTy = V->getType(); +  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +         "Cannot truncate or noop with non-integer arguments!"); +  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) && +         "getTruncateOrNoop cannot extend!"); +  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty)) +    return V;  // No conversion +  return getTruncateExpr(V, Ty); +} + +const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, +                                                        const SCEV *RHS) { +  const SCEV *PromotedLHS = LHS; +  const SCEV *PromotedRHS = RHS; + +  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType())) +    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType()); +  else +    PromotedLHS = getNoopOrZeroExtend(LHS, 
RHS->getType()); + +  return getUMaxExpr(PromotedLHS, PromotedRHS); +} + +const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, +                                                        const SCEV *RHS) { +  SmallVector<const SCEV *, 2> Ops = { LHS, RHS }; +  return getUMinFromMismatchedTypes(Ops); +} + +const SCEV *ScalarEvolution::getUMinFromMismatchedTypes( +    SmallVectorImpl<const SCEV *> &Ops) { +  assert(!Ops.empty() && "At least one operand must be!"); +  // Trivial case. +  if (Ops.size() == 1) +    return Ops[0]; + +  // Find the max type first. +  Type *MaxType = nullptr; +  for (auto *S : Ops) +    if (MaxType) +      MaxType = getWiderType(MaxType, S->getType()); +    else +      MaxType = S->getType(); + +  // Extend all ops to max type. +  SmallVector<const SCEV *, 2> PromotedOps; +  for (auto *S : Ops) +    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType)); + +  // Generate umin. +  return getUMinExpr(PromotedOps); +} + +const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { +  // A pointer operand may evaluate to a nonpointer expression, such as null. +  if (!V->getType()->isPointerTy()) +    return V; + +  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) { +    return getPointerBase(Cast->getOperand()); +  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { +    const SCEV *PtrOp = nullptr; +    for (const SCEV *NAryOp : NAry->operands()) { +      if (NAryOp->getType()->isPointerTy()) { +        // Cannot find the base of an expression with multiple pointer operands. +        if (PtrOp) +          return V; +        PtrOp = NAryOp; +      } +    } +    if (!PtrOp) +      return V; +    return getPointerBase(PtrOp); +  } +  return V; +} + +/// Push users of the given Instruction onto the given Worklist. +static void +PushDefUseChildren(Instruction *I, +                   SmallVectorImpl<Instruction *> &Worklist) { +  // Push the def-use children onto the Worklist stack. +  for (User *U : I->users()) +    Worklist.push_back(cast<Instruction>(U)); +} + +void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { +  SmallVector<Instruction *, 16> Worklist; +  PushDefUseChildren(PN, Worklist); + +  SmallPtrSet<Instruction *, 8> Visited; +  Visited.insert(PN); +  while (!Worklist.empty()) { +    Instruction *I = Worklist.pop_back_val(); +    if (!Visited.insert(I).second) +      continue; + +    auto It = ValueExprMap.find_as(static_cast<Value *>(I)); +    if (It != ValueExprMap.end()) { +      const SCEV *Old = It->second; + +      // Short-circuit the def-use traversal if the symbolic name +      // ceases to appear in expressions. +      if (Old != SymName && !hasOperand(Old, SymName)) +        continue; + +      // SCEVUnknown for a PHI either means that it has an unrecognized +      // structure, it's a PHI that's in the progress of being computed +      // by createNodeForPHI, or it's a single-value PHI. In the first case, +      // additional loop trip count information isn't going to change anything. +      // In the second case, createNodeForPHI will perform the necessary +      // updates on its own when it gets to that point. In the third, we do +      // want to forget the SCEVUnknown. +      if (!isa<PHINode>(I) || +          !isa<SCEVUnknown>(Old) || +          (I != PN && Old == SymName)) { +        eraseValueFromMap(It->first); +        forgetMemoizedResults(Old); +      } +    } + +    PushDefUseChildren(I, Worklist); +  } +} + +namespace { + +/// Takes SCEV S and Loop L. 
For each AddRec sub-expression, use its start +/// expression in case its Loop is L. If it is not L then +/// if IgnoreOtherLoops is true then use AddRec itself +/// otherwise rewrite cannot be done. +/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. +class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> { +public: +  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, +                             bool IgnoreOtherLoops = true) { +    SCEVInitRewriter Rewriter(L, SE); +    const SCEV *Result = Rewriter.visit(S); +    if (Rewriter.hasSeenLoopVariantSCEVUnknown()) +      return SE.getCouldNotCompute(); +    return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops +               ? SE.getCouldNotCompute() +               : Result; +  } + +  const SCEV *visitUnknown(const SCEVUnknown *Expr) { +    if (!SE.isLoopInvariant(Expr, L)) +      SeenLoopVariantSCEVUnknown = true; +    return Expr; +  } + +  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { +    // Only re-write AddRecExprs for this loop. +    if (Expr->getLoop() == L) +      return Expr->getStart(); +    SeenOtherLoops = true; +    return Expr; +  } + +  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; } + +  bool hasSeenOtherLoops() { return SeenOtherLoops; } + +private: +  explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) +      : SCEVRewriteVisitor(SE), L(L) {} + +  const Loop *L; +  bool SeenLoopVariantSCEVUnknown = false; +  bool SeenOtherLoops = false; +}; + +/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post +/// increment expression in case its Loop is L. If it is not L then +/// use AddRec itself. +/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done. +class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> { +public: +  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { +    SCEVPostIncRewriter Rewriter(L, SE); +    const SCEV *Result = Rewriter.visit(S); +    return Rewriter.hasSeenLoopVariantSCEVUnknown() +        ? SE.getCouldNotCompute() +        : Result; +  } + +  const SCEV *visitUnknown(const SCEVUnknown *Expr) { +    if (!SE.isLoopInvariant(Expr, L)) +      SeenLoopVariantSCEVUnknown = true; +    return Expr; +  } + +  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { +    // Only re-write AddRecExprs for this loop. +    if (Expr->getLoop() == L) +      return Expr->getPostIncExpr(SE); +    SeenOtherLoops = true; +    return Expr; +  } + +  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; } + +  bool hasSeenOtherLoops() { return SeenOtherLoops; } + +private: +  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE) +      : SCEVRewriteVisitor(SE), L(L) {} + +  const Loop *L; +  bool SeenLoopVariantSCEVUnknown = false; +  bool SeenOtherLoops = false; +}; + +/// This class evaluates the compare condition by matching it against the +/// condition of loop latch. If there is a match we assume a true value +/// for the condition while building SCEV nodes. 
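+/// For example, if the latch branches back to the header when %c is true,
+/// a 'select i1 %c, %a, %b' seen while building SCEVs for this loop is
+/// treated as %a.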
+class SCEVBackedgeConditionFolder +    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> { +public: +  static const SCEV *rewrite(const SCEV *S, const Loop *L, +                             ScalarEvolution &SE) { +    bool IsPosBECond = false; +    Value *BECond = nullptr; +    if (BasicBlock *Latch = L->getLoopLatch()) { +      BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator()); +      if (BI && BI->isConditional()) { +        assert(BI->getSuccessor(0) != BI->getSuccessor(1) && +               "Both outgoing branches should not target same header!"); +        BECond = BI->getCondition(); +        IsPosBECond = BI->getSuccessor(0) == L->getHeader(); +      } else { +        return S; +      } +    } +    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE); +    return Rewriter.visit(S); +  } + +  const SCEV *visitUnknown(const SCEVUnknown *Expr) { +    const SCEV *Result = Expr; +    bool InvariantF = SE.isLoopInvariant(Expr, L); + +    if (!InvariantF) { +      Instruction *I = cast<Instruction>(Expr->getValue()); +      switch (I->getOpcode()) { +      case Instruction::Select: { +        SelectInst *SI = cast<SelectInst>(I); +        Optional<const SCEV *> Res = +            compareWithBackedgeCondition(SI->getCondition()); +        if (Res.hasValue()) { +          bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne(); +          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue()); +        } +        break; +      } +      default: { +        Optional<const SCEV *> Res = compareWithBackedgeCondition(I); +        if (Res.hasValue()) +          Result = Res.getValue(); +        break; +      } +      } +    } +    return Result; +  } + +private: +  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond, +                                       bool IsPosBECond, ScalarEvolution &SE) +      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond), +        IsPositiveBECond(IsPosBECond) {} + +  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC); + +  const Loop *L; +  /// Loop back condition. +  Value *BackedgeCond = nullptr; +  /// Set to true if loop back is on positive branch condition. +  bool IsPositiveBECond; +}; + +Optional<const SCEV *> +SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) { + +  // If value matches the backedge condition for loop latch, +  // then return a constant evolution node based on loopback +  // branch taken. +  if (BackedgeCond == IC) +    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext())) +                            : SE.getZero(Type::getInt1Ty(SE.getContext())); +  return None; +} + +class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> { +public: +  static const SCEV *rewrite(const SCEV *S, const Loop *L, +                             ScalarEvolution &SE) { +    SCEVShiftRewriter Rewriter(L, SE); +    const SCEV *Result = Rewriter.visit(S); +    return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); +  } + +  const SCEV *visitUnknown(const SCEVUnknown *Expr) { +    // Only allow AddRecExprs for this loop. 
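+    // (Loop-invariant unknowns are fine; any loop-variant unknown makes the
+    // shift rewrite invalid.)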
+    if (!SE.isLoopInvariant(Expr, L)) +      Valid = false; +    return Expr; +  } + +  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { +    if (Expr->getLoop() == L && Expr->isAffine()) +      return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); +    Valid = false; +    return Expr; +  } + +  bool isValid() { return Valid; } + +private: +  explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) +      : SCEVRewriteVisitor(SE), L(L) {} + +  const Loop *L; +  bool Valid = true; +}; + +} // end anonymous namespace + +SCEV::NoWrapFlags +ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { +  if (!AR->isAffine()) +    return SCEV::FlagAnyWrap; + +  using OBO = OverflowingBinaryOperator; + +  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; + +  if (!AR->hasNoSignedWrap()) { +    ConstantRange AddRecRange = getSignedRange(AR); +    ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); + +    auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( +        Instruction::Add, IncRange, OBO::NoSignedWrap); +    if (NSWRegion.contains(AddRecRange)) +      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); +  } + +  if (!AR->hasNoUnsignedWrap()) { +    ConstantRange AddRecRange = getUnsignedRange(AR); +    ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this)); + +    auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( +        Instruction::Add, IncRange, OBO::NoUnsignedWrap); +    if (NUWRegion.contains(AddRecRange)) +      Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); +  } + +  return Result; +} + +namespace { + +/// Represents an abstract binary operation.  This may exist as a +/// normal instruction or constant expression, or may have been +/// derived from an expression tree. +struct BinaryOp { +  unsigned Opcode; +  Value *LHS; +  Value *RHS; +  bool IsNSW = false; +  bool IsNUW = false; + +  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or +  /// constant expression. +  Operator *Op = nullptr; + +  explicit BinaryOp(Operator *Op) +      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), +        Op(Op) { +    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { +      IsNSW = OBO->hasNoSignedWrap(); +      IsNUW = OBO->hasNoUnsignedWrap(); +    } +  } + +  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, +                    bool IsNUW = false) +      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {} +}; + +} // end anonymous namespace + +/// Try to map \p V into a BinaryOp, and return \c None on failure. +static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { +  auto *Op = dyn_cast<Operator>(V); +  if (!Op) +    return None; + +  // Implementation detail: all the cleverness here should happen without +  // creating new SCEV expressions -- our caller knowns tricks to avoid creating +  // SCEV expressions when possible, and we should not break that. + +  switch (Op->getOpcode()) { +  case Instruction::Add: +  case Instruction::Sub: +  case Instruction::Mul: +  case Instruction::UDiv: +  case Instruction::URem: +  case Instruction::And: +  case Instruction::Or: +  case Instruction::AShr: +  case Instruction::Shl: +    return BinaryOp(Op); + +  case Instruction::Xor: +    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) +      // If the RHS of the xor is a signmask, then this is just an add. 
+      // Instcombine turns add of signmask into xor as a strength reduction step. +      if (RHSC->getValue().isSignMask()) +        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); +    return BinaryOp(Op); + +  case Instruction::LShr: +    // Turn logical shift right of a constant into a unsigned divide. +    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) { +      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth(); + +      // If the shift count is not less than the bitwidth, the result of +      // the shift is undefined. Don't try to analyze it, because the +      // resolution chosen here may differ from the resolution chosen in +      // other parts of the compiler. +      if (SA->getValue().ult(BitWidth)) { +        Constant *X = +            ConstantInt::get(SA->getContext(), +                             APInt::getOneBitSet(BitWidth, SA->getZExtValue())); +        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X); +      } +    } +    return BinaryOp(Op); + +  case Instruction::ExtractValue: { +    auto *EVI = cast<ExtractValueInst>(Op); +    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0) +      break; + +    auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand()); +    if (!CI) +      break; + +    if (auto *F = CI->getCalledFunction()) +      switch (F->getIntrinsicID()) { +      case Intrinsic::sadd_with_overflow: +      case Intrinsic::uadd_with_overflow: +        if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) +          return BinaryOp(Instruction::Add, CI->getArgOperand(0), +                          CI->getArgOperand(1)); + +        // Now that we know that all uses of the arithmetic-result component of +        // CI are guarded by the overflow check, we can go ahead and pretend +        // that the arithmetic is non-overflowing. +        if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow) +          return BinaryOp(Instruction::Add, CI->getArgOperand(0), +                          CI->getArgOperand(1), /* IsNSW = */ true, +                          /* IsNUW = */ false); +        else +          return BinaryOp(Instruction::Add, CI->getArgOperand(0), +                          CI->getArgOperand(1), /* IsNSW = */ false, +                          /* IsNUW*/ true); +      case Intrinsic::ssub_with_overflow: +      case Intrinsic::usub_with_overflow: +        if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) +          return BinaryOp(Instruction::Sub, CI->getArgOperand(0), +                          CI->getArgOperand(1)); + +        // The same reasoning as sadd/uadd above. +        if (F->getIntrinsicID() == Intrinsic::ssub_with_overflow) +          return BinaryOp(Instruction::Sub, CI->getArgOperand(0), +                          CI->getArgOperand(1), /* IsNSW = */ true, +                          /* IsNUW = */ false); +        else +          return BinaryOp(Instruction::Sub, CI->getArgOperand(0), +                          CI->getArgOperand(1), /* IsNSW = */ false, +                          /* IsNUW = */ true); +      case Intrinsic::smul_with_overflow: +      case Intrinsic::umul_with_overflow: +        return BinaryOp(Instruction::Mul, CI->getArgOperand(0), +                        CI->getArgOperand(1)); +      default: +        break; +      } +    break; +  } + +  default: +    break; +  } + +  return None; +} + +/// Helper function to createAddRecFromPHIWithCasts. 
We have a phi +/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via +/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the +/// way. This function checks if \p Op, an operand of this SCEVAddExpr, +/// follows one of the following patterns: +/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) +/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) +/// If the SCEV expression of \p Op conforms with one of the expected patterns +/// we return the type of the truncation operation, and indicate whether the +/// truncated type should be treated as signed/unsigned by setting +/// \p Signed to true/false, respectively. +static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, +                               bool &Signed, ScalarEvolution &SE) { +  // The case where Op == SymbolicPHI (that is, with no type conversions on +  // the way) is handled by the regular add recurrence creating logic and +  // would have already been triggered in createAddRecForPHI. Reaching it here +  // means that createAddRecFromPHI had failed for this PHI before (e.g., +  // because one of the other operands of the SCEVAddExpr updating this PHI is +  // not invariant). +  // +  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in +  // this case predicates that allow us to prove that Op == SymbolicPHI will +  // be added. +  if (Op == SymbolicPHI) +    return nullptr; + +  unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType()); +  unsigned NewBits = SE.getTypeSizeInBits(Op->getType()); +  if (SourceBits != NewBits) +    return nullptr; + +  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op); +  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op); +  if (!SExt && !ZExt) +    return nullptr; +  const SCEVTruncateExpr *Trunc = +      SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand()) +           : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand()); +  if (!Trunc) +    return nullptr; +  const SCEV *X = Trunc->getOperand(); +  if (X != SymbolicPHI) +    return nullptr; +  Signed = SExt != nullptr; +  return Trunc->getType(); +} + +static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { +  if (!PN->getType()->isIntegerTy()) +    return nullptr; +  const Loop *L = LI.getLoopFor(PN->getParent()); +  if (!L || L->getHeader() != PN->getParent()) +    return nullptr; +  return L; +} + +// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the +// computation that updates the phi follows the following pattern: +//   (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum +// which correspond to a phi->trunc->sext/zext->add->phi update chain. +// If so, try to see if it can be rewritten as an AddRecExpr under some +// Predicates. If successful, return them as a pair. Also cache the results +// of the analysis. +// +// Example usage scenario: +//    Say the Rewriter is called for the following SCEV: +//         8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step) +//    where: +//         %X = phi i64 (%Start, %BEValue) +//    It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X), +//    and call this function with %SymbolicPHI = %X. 
+//
+//    The analysis will find that the value coming around the backedge has
+//    the following SCEV:
+//         BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
+//    Upon concluding that this matches the desired pattern, the function
+//    will return the pair {NewAddRec, SmallPredsVec} where:
+//         NewAddRec = {%Start,+,%Step}
+//         SmallPredsVec = {P1, P2, P3} as follows:
+//           P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
+//           P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
+//           P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
+//    The returned pair means that SymbolicPHI can be rewritten into NewAddRec
+//    under the predicates {P1,P2,P3}.
+//    This predicated rewrite will be cached in PredicatedSCEVRewrites:
+//         PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
+//
+// TODOs:
+//
+// 1) Extend the Induction descriptor to also support inductions that involve
+//    casts: When needed (namely, when we are called in the context of the
+//    vectorizer induction analysis), a Set of cast instructions will be
+//    populated by this method, and provided back to isInductionPHI. This is
+//    needed to allow the vectorizer to properly record them to be ignored by
+//    the cost model and to avoid vectorizing them (otherwise these casts,
+//    which are redundant under the runtime overflow checks, will be
+//    vectorized, which can be costly).
+//
+// 2) Support additional induction/PHISCEV patterns: We also want to support
+//    inductions where the sext-trunc / zext-trunc operations (partly) occur
+//    after the induction update operation (the induction increment):
+//
+//      (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
+//    which correspond to a phi->add->trunc->sext/zext->phi update chain.
+//
+//      (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
+//    which correspond to a phi->trunc->add->sext/zext->phi update chain.
+//
+// 3) Outline common code with createAddRecFromPHI to avoid duplication.
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
+  SmallVector<const SCEVPredicate *, 3> Predicates;
+
+  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
+  // return an AddRec expression under some predicate.
+
+  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  assert(L && "Expecting an integer loop header phi");
+
+  // The loop may have multiple entrances or multiple exits; we can analyze
+  // this phi as an addrec if it has a unique entry value and a unique
+  // backedge value.
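+  // For example, for a header phi such as
+  //   %x = phi i64 [ %start, %preheader ], [ %bevalue, %latch ]
+  // the unique entry value is %start and the unique backedge value is
+  // %bevalue (the names here are only illustrative).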
+  Value *BEValueV = nullptr, *StartValueV = nullptr; +  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { +    Value *V = PN->getIncomingValue(i); +    if (L->contains(PN->getIncomingBlock(i))) { +      if (!BEValueV) { +        BEValueV = V; +      } else if (BEValueV != V) { +        BEValueV = nullptr; +        break; +      } +    } else if (!StartValueV) { +      StartValueV = V; +    } else if (StartValueV != V) { +      StartValueV = nullptr; +      break; +    } +  } +  if (!BEValueV || !StartValueV) +    return None; + +  const SCEV *BEValue = getSCEV(BEValueV); + +  // If the value coming around the backedge is an add with the symbolic +  // value we just inserted, possibly with casts that we can ignore under +  // an appropriate runtime guard, then we found a simple induction variable! +  const auto *Add = dyn_cast<SCEVAddExpr>(BEValue); +  if (!Add) +    return None; + +  // If there is a single occurrence of the symbolic value, possibly +  // casted, replace it with a recurrence. +  unsigned FoundIndex = Add->getNumOperands(); +  Type *TruncTy = nullptr; +  bool Signed; +  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) +    if ((TruncTy = +             isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this))) +      if (FoundIndex == e) { +        FoundIndex = i; +        break; +      } + +  if (FoundIndex == Add->getNumOperands()) +    return None; + +  // Create an add with everything but the specified operand. +  SmallVector<const SCEV *, 8> Ops; +  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) +    if (i != FoundIndex) +      Ops.push_back(Add->getOperand(i)); +  const SCEV *Accum = getAddExpr(Ops); + +  // The runtime checks will not be valid if the step amount is +  // varying inside the loop. +  if (!isLoopInvariant(Accum, L)) +    return None; + +  // *** Part2: Create the predicates + +  // Analysis was successful: we have a phi-with-cast pattern for which we +  // can return an AddRec expression under the following predicates: +  // +  // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum) +  //     fits within the truncated type (does not overflow) for i = 0 to n-1. 
+  // P2: An Equal predicate that guarantees that +  //     Start = (Ext ix (Trunc iy (Start) to ix) to iy) +  // P3: An Equal predicate that guarantees that +  //     Accum = (Ext ix (Trunc iy (Accum) to ix) to iy) +  // +  // As we next prove, the above predicates guarantee that: +  //     Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy) +  // +  // +  // More formally, we want to prove that: +  //     Expr(i+1) = Start + (i+1) * Accum +  //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum +  // +  // Given that: +  // 1) Expr(0) = Start +  // 2) Expr(1) = Start + Accum +  //            = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2 +  // 3) Induction hypothesis (step i): +  //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum +  // +  // Proof: +  //  Expr(i+1) = +  //   = Start + (i+1)*Accum +  //   = (Start + i*Accum) + Accum +  //   = Expr(i) + Accum +  //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum +  //                                                             :: from step i +  // +  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum +  // +  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) +  //     + (Ext ix (Trunc iy (Accum) to ix) to iy) +  //     + Accum                                                     :: from P3 +  // +  //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) +  //     + Accum                            :: from P1: Ext(x)+Ext(y)=>Ext(x+y) +  // +  //   = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum +  //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum +  // +  // By induction, the same applies to all iterations 1<=i<n: +  // + +  // Create a truncated addrec for which we will add a no overflow check (P1). +  const SCEV *StartVal = getSCEV(StartValueV); +  const SCEV *PHISCEV = +      getAddRecExpr(getTruncateExpr(StartVal, TruncTy), +                    getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); + +  // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr. +  // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV +  // will be constant. +  // +  //  If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't +  // add P1. +  if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { +    SCEVWrapPredicate::IncrementWrapFlags AddedFlags = +        Signed ? SCEVWrapPredicate::IncrementNSSW +               : SCEVWrapPredicate::IncrementNUSW; +    const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags); +    Predicates.push_back(AddRecPred); +  } + +  // Create the Equal Predicates P2,P3: + +  // It is possible that the predicates P2 and/or P3 are computable at +  // compile time due to StartVal and/or Accum being constants. +  // If either one is, then we can check that now and escape if either P2 +  // or P3 is false. + +  // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy) +  // for each of StartVal and Accum +  auto getExtendedExpr = [&](const SCEV *Expr, +                             bool CreateSignExtend) -> const SCEV * { +    assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); +    const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); +    const SCEV *ExtendedExpr = +        CreateSignExtend ? 
getSignExtendExpr(TruncatedExpr, Expr->getType())
+                         : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+    return ExtendedExpr;
+  };
+
+  // Given:
+  //  ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
+  //               = getExtendedExpr(Expr)
+  // Determine whether the predicate P: Expr == ExtendedExpr
+  // is known to be false at compile time.
+  auto PredIsKnownFalse = [&](const SCEV *Expr,
+                              const SCEV *ExtendedExpr) -> bool {
+    return Expr != ExtendedExpr &&
+           isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
+  };
+
+  const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
+  if (PredIsKnownFalse(StartVal, StartExtended)) {
+    LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
+    return None;
+  }
+
+  // The Step is always Signed (because the overflow checks are either
+  // NSSW or NUSW).
+  const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
+  if (PredIsKnownFalse(Accum, AccumExtended)) {
+    LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
+    return None;
+  }
+
+  auto AppendPredicate = [&](const SCEV *Expr,
+                             const SCEV *ExtendedExpr) -> void {
+    if (Expr != ExtendedExpr &&
+        !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
+      const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
+      LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
+      Predicates.push_back(Pred);
+    }
+  };
+
+  AppendPredicate(StartVal, StartExtended);
+  AppendPredicate(Accum, AccumExtended);
+
+  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
+  // which the casts had been folded away. The caller can rewrite SymbolicPHI
+  // into NewAR if it will also add the runtime overflow checks specified in
+  // Predicates.
+  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
+
+  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
+      std::make_pair(NewAR, Predicates);
+  // Remember the result of the analysis for this SCEV at this location.
+  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
+  return PredRewrite;
+}
+
+Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
+ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
+  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
+  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
+  if (!L)
+    return None;
+
+  // Check to see if we already analyzed this PHI.
+  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L}); +  if (I != PredicatedSCEVRewrites.end()) { +    std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite = +        I->second; +    // Analysis was done before and failed to create an AddRec: +    if (Rewrite.first == SymbolicPHI) +      return None; +    // Analysis was done before and succeeded to create an AddRec under +    // a predicate: +    assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec"); +    assert(!(Rewrite.second).empty() && "Expected to find Predicates"); +    return Rewrite; +  } + +  Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> +    Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI); + +  // Record in the cache that the analysis failed +  if (!Rewrite) { +    SmallVector<const SCEVPredicate *, 3> Predicates; +    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates}; +    return None; +  } + +  return Rewrite; +} + +// FIXME: This utility is currently required because the Rewriter currently +// does not rewrite this expression: +// {0, +, (sext ix (trunc iy to ix) to iy)} +// into {0, +, %step}, +// even when the following Equal predicate exists: +// "%step == (sext ix (trunc iy to ix) to iy)". +bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( +    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const { +  if (AR1 == AR2) +    return true; + +  auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { +    if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) && +        !Preds.implies(SE.getEqualPredicate(Expr2, Expr1))) +      return false; +    return true; +  }; + +  if (!areExprsEqual(AR1->getStart(), AR2->getStart()) || +      !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE))) +    return false; +  return true; +} + +/// A helper function for createAddRecFromPHI to handle simple cases. +/// +/// This function tries to find an AddRec expression for the simplest (yet most +/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)). +/// If it fails, createAddRecFromPHI will use a more general, but slow, +/// technique for finding the AddRec expression. +const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN, +                                                      Value *BEValueV, +                                                      Value *StartValueV) { +  const Loop *L = LI.getLoopFor(PN->getParent()); +  assert(L && L->getHeader() == PN->getParent()); +  assert(BEValueV && StartValueV); + +  auto BO = MatchBinaryOp(BEValueV, DT); +  if (!BO) +    return nullptr; + +  if (BO->Opcode != Instruction::Add) +    return nullptr; + +  const SCEV *Accum = nullptr; +  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS)) +    Accum = getSCEV(BO->RHS); +  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS)) +    Accum = getSCEV(BO->LHS); + +  if (!Accum) +    return nullptr; + +  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; +  if (BO->IsNUW) +    Flags = setFlags(Flags, SCEV::FlagNUW); +  if (BO->IsNSW) +    Flags = setFlags(Flags, SCEV::FlagNSW); + +  const SCEV *StartVal = getSCEV(StartValueV); +  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + +  ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + +  // We can add Flags to the post-inc expression only if we +  // know that it is *undefined behavior* for BEValueV to +  // overflow. 
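+  // (If so, the call below also builds the post-increment recurrence
+  // {Start+Accum,+,Accum} with the same flags, purely to cache that no-wrap
+  // information for later queries; its result is intentionally discarded.)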
+  if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
+    if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
+      (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
+
+  return PHISCEV;
+}
+
+const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
+  const Loop *L = LI.getLoopFor(PN->getParent());
+  if (!L || L->getHeader() != PN->getParent())
+    return nullptr;
+
+  // The loop may have multiple entrances or multiple exits; we can analyze
+  // this phi as an addrec if it has a unique entry value and a unique
+  // backedge value.
+  Value *BEValueV = nullptr, *StartValueV = nullptr;
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+    Value *V = PN->getIncomingValue(i);
+    if (L->contains(PN->getIncomingBlock(i))) {
+      if (!BEValueV) {
+        BEValueV = V;
+      } else if (BEValueV != V) {
+        BEValueV = nullptr;
+        break;
+      }
+    } else if (!StartValueV) {
+      StartValueV = V;
+    } else if (StartValueV != V) {
+      StartValueV = nullptr;
+      break;
+    }
+  }
+  if (!BEValueV || !StartValueV)
+    return nullptr;
+
+  assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
+         "PHI node already processed?");
+
+  // First, try to find an AddRec expression without creating a fictitious
+  // symbolic value for PN.
+  if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
+    return S;
+
+  // Handle PHI node value symbolically.
+  const SCEV *SymbolicName = getUnknown(PN);
+  ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName});
+
+  // Using this symbolic name for the PHI, analyze the value coming around
+  // the back-edge.
+  const SCEV *BEValue = getSCEV(BEValueV);
+
+  // NOTE: If BEValue is loop invariant, we know that the PHI node just
+  // has a special value for the first iteration of the loop.
+
+  // If the value coming around the backedge is an add with the symbolic
+  // value we just inserted, then we found a simple induction variable!
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
+    // If there is a single occurrence of the symbolic value, replace it
+    // with a recurrence.
+    unsigned FoundIndex = Add->getNumOperands();
+    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+      if (Add->getOperand(i) == SymbolicName)
+        if (FoundIndex == e) {
+          FoundIndex = i;
+          break;
+        }
+
+    if (FoundIndex != Add->getNumOperands()) {
+      // Create an add with everything but the specified operand.
+      SmallVector<const SCEV *, 8> Ops;
+      for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
+        if (i != FoundIndex)
+          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
+                                                             L, *this));
+      const SCEV *Accum = getAddExpr(Ops);
+
+      // This is not a valid addrec if the step amount is varying each
+      // loop iteration, but is not itself an addrec in this loop.
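+      // (Accum being an addrec in this same loop is also accepted; such a phi
+      // simply evolves polynomially rather than linearly.)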
+      if (isLoopInvariant(Accum, L) || +          (isa<SCEVAddRecExpr>(Accum) && +           cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { +        SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + +        if (auto BO = MatchBinaryOp(BEValueV, DT)) { +          if (BO->Opcode == Instruction::Add && BO->LHS == PN) { +            if (BO->IsNUW) +              Flags = setFlags(Flags, SCEV::FlagNUW); +            if (BO->IsNSW) +              Flags = setFlags(Flags, SCEV::FlagNSW); +          } +        } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { +          // If the increment is an inbounds GEP, then we know the address +          // space cannot be wrapped around. We cannot make any guarantee +          // about signed or unsigned overflow because pointers are +          // unsigned but we may have a negative index from the base +          // pointer. We can guarantee that no unsigned wrap occurs if the +          // indices form a positive value. +          if (GEP->isInBounds() && GEP->getOperand(0) == PN) { +            Flags = setFlags(Flags, SCEV::FlagNW); + +            const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); +            if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) +              Flags = setFlags(Flags, SCEV::FlagNUW); +          } + +          // We cannot transfer nuw and nsw flags from subtraction +          // operations -- sub nuw X, Y is not the same as add nuw X, -Y +          // for instance. +        } + +        const SCEV *StartVal = getSCEV(StartValueV); +        const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); + +        // Okay, for the entire analysis of this edge we assumed the PHI +        // to be symbolic.  We now need to go back and purge all of the +        // entries for the scalars that use the symbolic expression. +        forgetSymbolicName(PN, SymbolicName); +        ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + +        // We can add Flags to the post-inc expression only if we +        // know that it is *undefined behavior* for BEValueV to +        // overflow. +        if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) +          if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) +            (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + +        return PHISCEV; +      } +    } +  } else { +    // Otherwise, this could be a loop like this: +    //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; } +    // In this case, j = {1,+,1}  and BEValue is j. +    // Because the other in-value of i (0) fits the evolution of BEValue +    // i really is an addrec evolution. +    // +    // We can generalize this saying that i is the shifted value of BEValue +    // by one iteration: +    //   PHI(f(0), f({1,+,1})) --> f({0,+,1}) +    const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); +    const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false); +    if (Shifted != getCouldNotCompute() && +        Start != getCouldNotCompute()) { +      const SCEV *StartVal = getSCEV(StartValueV); +      if (Start == StartVal) { +        // Okay, for the entire analysis of this edge we assumed the PHI +        // to be symbolic.  We now need to go back and purge all of the +        // entries for the scalars that use the symbolic expression. 
+        forgetSymbolicName(PN, SymbolicName);
+        ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted;
+        return Shifted;
+      }
+    }
+  }
+
+  // Remove the temporary PHI node SCEV that has been inserted while intending
+  // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
+  // as it would prevent later (possibly simpler) SCEV expressions from being
+  // added to the ValueExprMap.
+  eraseValueFromMap(PN);
+
+  return nullptr;
+}
+
+// Checks if the SCEV S is available at BB.  S is considered available at BB
+// if S can be materialized at BB without introducing a fault.
+static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
+                               BasicBlock *BB) {
+  struct CheckAvailable {
+    bool TraversalDone = false;
+    bool Available = true;
+
+    const Loop *L = nullptr;  // The loop BB is in (can be nullptr)
+    BasicBlock *BB = nullptr;
+    DominatorTree &DT;
+
+    CheckAvailable(const Loop *L, BasicBlock *BB, DominatorTree &DT)
+      : L(L), BB(BB), DT(DT) {}
+
+    bool setUnavailable() {
+      TraversalDone = true;
+      Available = false;
+      return false;
+    }
+
+    bool follow(const SCEV *S) {
+      switch (S->getSCEVType()) {
+      case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
+      case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
+        // These expressions are available if their operand(s) is/are.
+        return true;
+
+      case scAddRecExpr: {
+        // We allow add recurrences that are on the loop that BB is in, or some
+        // outer loop.  This guarantees availability because the value of the
+        // add recurrence at BB is simply the "current" value of the induction
+        // variable.  We can relax this in the future; for instance an add
+        // recurrence on a sibling dominating loop is also available at BB.
+        const auto *ARLoop = cast<SCEVAddRecExpr>(S)->getLoop();
+        if (L && (ARLoop == L || ARLoop->contains(L)))
+          return true;
+
+        return setUnavailable();
+      }
+
+      case scUnknown: {
+        // For SCEVUnknown, we check for simple dominance.
+        const auto *SU = cast<SCEVUnknown>(S);
+        Value *V = SU->getValue();
+
+        if (isa<Argument>(V))
+          return false;
+
+        if (isa<Instruction>(V) && DT.dominates(cast<Instruction>(V), BB))
+          return false;
+
+        return setUnavailable();
+      }
+
+      case scUDivExpr:
+      case scCouldNotCompute:
+        // We do not try to be smart about these at all.
+        return setUnavailable();
+      }
+      llvm_unreachable("switch should be fully covered!");
+    }
+
+    bool isDone() { return TraversalDone; }
+  };
+
+  CheckAvailable CA(L, BB, DT);
+  SCEVTraversal<CheckAvailable> ST(CA);
+
+  ST.visitAll(S);
+  return CA.Available;
+}
+
+// Try to match a control flow sequence that branches out at BI and merges back
+// at Merge into a "C ? LHS : RHS" select pattern.  Return true on a successful
+// match.
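+// The phi's incoming order does not have to match the branch's successor
+// order; if the edges dominate the uses the other way around, LHS and RHS are
+// returned swapped so that "C ? LHS : RHS" still describes the merge.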
+static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, +                          Value *&C, Value *&LHS, Value *&RHS) { +  C = BI->getCondition(); + +  BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0)); +  BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1)); + +  if (!LeftEdge.isSingleEdge()) +    return false; + +  assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()"); + +  Use &LeftUse = Merge->getOperandUse(0); +  Use &RightUse = Merge->getOperandUse(1); + +  if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) { +    LHS = LeftUse; +    RHS = RightUse; +    return true; +  } + +  if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) { +    LHS = RightUse; +    RHS = LeftUse; +    return true; +  } + +  return false; +} + +const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { +  auto IsReachable = +      [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; +  if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) { +    const Loop *L = LI.getLoopFor(PN->getParent()); + +    // We don't want to break LCSSA, even in a SCEV expression tree. +    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) +      if (LI.getLoopFor(PN->getIncomingBlock(i)) != L) +        return nullptr; + +    // Try to match +    // +    //  br %cond, label %left, label %right +    // left: +    //  br label %merge +    // right: +    //  br label %merge +    // merge: +    //  V = phi [ %x, %left ], [ %y, %right ] +    // +    // as "select %cond, %x, %y" + +    BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock(); +    assert(IDom && "At least the entry block should dominate PN"); + +    auto *BI = dyn_cast<BranchInst>(IDom->getTerminator()); +    Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr; + +    if (BI && BI->isConditional() && +        BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) && +        IsAvailableOnEntry(L, DT, getSCEV(LHS), PN->getParent()) && +        IsAvailableOnEntry(L, DT, getSCEV(RHS), PN->getParent())) +      return createNodeForSelectOrPHI(PN, Cond, LHS, RHS); +  } + +  return nullptr; +} + +const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { +  if (const SCEV *S = createAddRecFromPHI(PN)) +    return S; + +  if (const SCEV *S = createNodeFromSelectLikePHI(PN)) +    return S; + +  // If the PHI has a single incoming value, follow that value, unless the +  // PHI's incoming blocks are in a different loop, in which case doing so +  // risks breaking LCSSA form. Instcombine would normally zap these, but +  // it doesn't have DominatorTree information, so it may miss cases. +  if (Value *V = SimplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC})) +    if (LI.replacementPreservesLCSSAForm(PN, V)) +      return getSCEV(V); + +  // If it's not a loop phi, we can't handle it yet. +  return getUnknown(PN); +} + +const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, +                                                      Value *Cond, +                                                      Value *TrueVal, +                                                      Value *FalseVal) { +  // Handle "constant" branch or select. This can occur for instance when a +  // loop pass transforms an inner loop and moves on to process the outer loop. +  if (auto *CI = dyn_cast<ConstantInt>(Cond)) +    return getSCEV(CI->isOne() ? TrueVal : FalseVal); + +  // Try to match some simple smax or umax patterns. 
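+  // For example, "%r = select (icmp sgt %a, %b), %a, %b" becomes smax(%a, %b),
+  // and the offset forms handled below ("a >s b ? a+x : b+x") fold to
+  // smax(a, b)+x (the IR names here are only illustrative).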
+  auto *ICI = dyn_cast<ICmpInst>(Cond); +  if (!ICI) +    return getUnknown(I); + +  Value *LHS = ICI->getOperand(0); +  Value *RHS = ICI->getOperand(1); + +  switch (ICI->getPredicate()) { +  case ICmpInst::ICMP_SLT: +  case ICmpInst::ICMP_SLE: +    std::swap(LHS, RHS); +    LLVM_FALLTHROUGH; +  case ICmpInst::ICMP_SGT: +  case ICmpInst::ICMP_SGE: +    // a >s b ? a+x : b+x  ->  smax(a, b)+x +    // a >s b ? b+x : a+x  ->  smin(a, b)+x +    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { +      const SCEV *LS = getNoopOrSignExtend(getSCEV(LHS), I->getType()); +      const SCEV *RS = getNoopOrSignExtend(getSCEV(RHS), I->getType()); +      const SCEV *LA = getSCEV(TrueVal); +      const SCEV *RA = getSCEV(FalseVal); +      const SCEV *LDiff = getMinusSCEV(LA, LS); +      const SCEV *RDiff = getMinusSCEV(RA, RS); +      if (LDiff == RDiff) +        return getAddExpr(getSMaxExpr(LS, RS), LDiff); +      LDiff = getMinusSCEV(LA, RS); +      RDiff = getMinusSCEV(RA, LS); +      if (LDiff == RDiff) +        return getAddExpr(getSMinExpr(LS, RS), LDiff); +    } +    break; +  case ICmpInst::ICMP_ULT: +  case ICmpInst::ICMP_ULE: +    std::swap(LHS, RHS); +    LLVM_FALLTHROUGH; +  case ICmpInst::ICMP_UGT: +  case ICmpInst::ICMP_UGE: +    // a >u b ? a+x : b+x  ->  umax(a, b)+x +    // a >u b ? b+x : a+x  ->  umin(a, b)+x +    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType())) { +      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); +      const SCEV *RS = getNoopOrZeroExtend(getSCEV(RHS), I->getType()); +      const SCEV *LA = getSCEV(TrueVal); +      const SCEV *RA = getSCEV(FalseVal); +      const SCEV *LDiff = getMinusSCEV(LA, LS); +      const SCEV *RDiff = getMinusSCEV(RA, RS); +      if (LDiff == RDiff) +        return getAddExpr(getUMaxExpr(LS, RS), LDiff); +      LDiff = getMinusSCEV(LA, RS); +      RDiff = getMinusSCEV(RA, LS); +      if (LDiff == RDiff) +        return getAddExpr(getUMinExpr(LS, RS), LDiff); +    } +    break; +  case ICmpInst::ICMP_NE: +    // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x +    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && +        isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { +      const SCEV *One = getOne(I->getType()); +      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); +      const SCEV *LA = getSCEV(TrueVal); +      const SCEV *RA = getSCEV(FalseVal); +      const SCEV *LDiff = getMinusSCEV(LA, LS); +      const SCEV *RDiff = getMinusSCEV(RA, One); +      if (LDiff == RDiff) +        return getAddExpr(getUMaxExpr(One, LS), LDiff); +    } +    break; +  case ICmpInst::ICMP_EQ: +    // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x +    if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(I->getType()) && +        isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) { +      const SCEV *One = getOne(I->getType()); +      const SCEV *LS = getNoopOrZeroExtend(getSCEV(LHS), I->getType()); +      const SCEV *LA = getSCEV(TrueVal); +      const SCEV *RA = getSCEV(FalseVal); +      const SCEV *LDiff = getMinusSCEV(LA, One); +      const SCEV *RDiff = getMinusSCEV(RA, LS); +      if (LDiff == RDiff) +        return getAddExpr(getUMaxExpr(One, LS), LDiff); +    } +    break; +  default: +    break; +  } + +  return getUnknown(I); +} + +/// Expand GEP instructions into add and multiply operations. This allows them +/// to be analyzed by regular SCEV code. 
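+/// For example, "getelementptr i32, i32* %p, i64 %i" corresponds to the SCEV
+/// (%p + 4 * %i) on a target where i32 occupies 4 bytes (the names are only
+/// illustrative).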
+const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { +  // Don't attempt to analyze GEPs over unsized objects. +  if (!GEP->getSourceElementType()->isSized()) +    return getUnknown(GEP); + +  SmallVector<const SCEV *, 4> IndexExprs; +  for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) +    IndexExprs.push_back(getSCEV(*Index)); +  return getGEPExpr(GEP, IndexExprs); +} + +uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { +  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) +    return C->getAPInt().countTrailingZeros(); + +  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) +    return std::min(GetMinTrailingZeros(T->getOperand()), +                    (uint32_t)getTypeSizeInBits(T->getType())); + +  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) { +    uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); +    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) +               ? getTypeSizeInBits(E->getType()) +               : OpRes; +  } + +  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) { +    uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); +    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) +               ? getTypeSizeInBits(E->getType()) +               : OpRes; +  } + +  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { +    // The result is the min of all operands results. +    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); +    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) +      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); +    return MinOpRes; +  } + +  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { +    // The result is the sum of all operands results. +    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); +    uint32_t BitWidth = getTypeSizeInBits(M->getType()); +    for (unsigned i = 1, e = M->getNumOperands(); +         SumOpRes != BitWidth && i != e; ++i) +      SumOpRes = +          std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); +    return SumOpRes; +  } + +  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { +    // The result is the min of all operands results. +    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0)); +    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i) +      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i))); +    return MinOpRes; +  } + +  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) { +    // The result is the min of all operands results. +    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); +    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) +      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); +    return MinOpRes; +  } + +  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) { +    // The result is the min of all operands results. +    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); +    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i) +      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i))); +    return MinOpRes; +  } + +  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { +    // For a SCEVUnknown, ask ValueTracking. 
+    KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT); +    return Known.countMinTrailingZeros(); +  } + +  // SCEVUDivExpr +  return 0; +} + +uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { +  auto I = MinTrailingZerosCache.find(S); +  if (I != MinTrailingZerosCache.end()) +    return I->second; + +  uint32_t Result = GetMinTrailingZerosImpl(S); +  auto InsertPair = MinTrailingZerosCache.insert({S, Result}); +  assert(InsertPair.second && "Should insert a new key"); +  return InsertPair.first->second; +} + +/// Helper method to assign a range to V from metadata present in the IR. +static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { +  if (Instruction *I = dyn_cast<Instruction>(V)) +    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) +      return getConstantRangeFromMetadata(*MD); + +  return None; +} + +/// Determine the range for a particular SCEV.  If SignHint is +/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges +/// with a "cleaner" unsigned (resp. signed) representation. +const ConstantRange & +ScalarEvolution::getRangeRef(const SCEV *S, +                             ScalarEvolution::RangeSignHint SignHint) { +  DenseMap<const SCEV *, ConstantRange> &Cache = +      SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges +                                                       : SignedRanges; + +  // See if we've computed this range already. +  DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S); +  if (I != Cache.end()) +    return I->second; + +  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) +    return setRange(C, SignHint, ConstantRange(C->getAPInt())); + +  unsigned BitWidth = getTypeSizeInBits(S->getType()); +  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); + +  // If the value has known zeros, the maximum value will have those known zeros +  // as well. 
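+  // For instance, if S has three known trailing zero bits, the maximum of the
+  // conservative range below is clamped down to the largest representable
+  // multiple of 8.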
+  uint32_t TZ = GetMinTrailingZeros(S); +  if (TZ != 0) { +    if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) +      ConservativeResult = +          ConstantRange(APInt::getMinValue(BitWidth), +                        APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); +    else +      ConservativeResult = ConstantRange( +          APInt::getSignedMinValue(BitWidth), +          APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); +  } + +  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { +    ConstantRange X = getRangeRef(Add->getOperand(0), SignHint); +    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) +      X = X.add(getRangeRef(Add->getOperand(i), SignHint)); +    return setRange(Add, SignHint, ConservativeResult.intersectWith(X)); +  } + +  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { +    ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint); +    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) +      X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint)); +    return setRange(Mul, SignHint, ConservativeResult.intersectWith(X)); +  } + +  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) { +    ConstantRange X = getRangeRef(SMax->getOperand(0), SignHint); +    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) +      X = X.smax(getRangeRef(SMax->getOperand(i), SignHint)); +    return setRange(SMax, SignHint, ConservativeResult.intersectWith(X)); +  } + +  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) { +    ConstantRange X = getRangeRef(UMax->getOperand(0), SignHint); +    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) +      X = X.umax(getRangeRef(UMax->getOperand(i), SignHint)); +    return setRange(UMax, SignHint, ConservativeResult.intersectWith(X)); +  } + +  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) { +    ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint); +    ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint); +    return setRange(UDiv, SignHint, +                    ConservativeResult.intersectWith(X.udiv(Y))); +  } + +  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) { +    ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint); +    return setRange(ZExt, SignHint, +                    ConservativeResult.intersectWith(X.zeroExtend(BitWidth))); +  } + +  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) { +    ConstantRange X = getRangeRef(SExt->getOperand(), SignHint); +    return setRange(SExt, SignHint, +                    ConservativeResult.intersectWith(X.signExtend(BitWidth))); +  } + +  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { +    ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint); +    return setRange(Trunc, SignHint, +                    ConservativeResult.intersectWith(X.truncate(BitWidth))); +  } + +  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { +    // If there's no unsigned wrap, the value will never be less than its +    // initial value. +    if (AddRec->hasNoUnsignedWrap()) +      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart())) +        if (!C->getValue()->isZero()) +          ConservativeResult = ConservativeResult.intersectWith( +              ConstantRange(C->getAPInt(), APInt(BitWidth, 0))); + +    // If there's no signed wrap, and all the operands have the same sign or +    // zero, the value won't ever change sign. 
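+    // For example, {1,+,2}<nsw> can never become negative, so its range is
+    // intersected with the set of non-negative values below.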
+    if (AddRec->hasNoSignedWrap()) { +      bool AllNonNeg = true; +      bool AllNonPos = true; +      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { +        if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false; +        if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false; +      } +      if (AllNonNeg) +        ConservativeResult = ConservativeResult.intersectWith( +          ConstantRange(APInt(BitWidth, 0), +                        APInt::getSignedMinValue(BitWidth))); +      else if (AllNonPos) +        ConservativeResult = ConservativeResult.intersectWith( +          ConstantRange(APInt::getSignedMinValue(BitWidth), +                        APInt(BitWidth, 1))); +    } + +    // TODO: non-affine addrec +    if (AddRec->isAffine()) { +      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); +      if (!isa<SCEVCouldNotCompute>(MaxBECount) && +          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { +        auto RangeFromAffine = getRangeForAffineAR( +            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, +            BitWidth); +        if (!RangeFromAffine.isFullSet()) +          ConservativeResult = +              ConservativeResult.intersectWith(RangeFromAffine); + +        auto RangeFromFactoring = getRangeViaFactoring( +            AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, +            BitWidth); +        if (!RangeFromFactoring.isFullSet()) +          ConservativeResult = +              ConservativeResult.intersectWith(RangeFromFactoring); +      } +    } + +    return setRange(AddRec, SignHint, std::move(ConservativeResult)); +  } + +  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { +    // Check if the IR explicitly contains !range metadata. +    Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue()); +    if (MDRange.hasValue()) +      ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue()); + +    // Split here to avoid paying the compile-time cost of calling both +    // computeKnownBits and ComputeNumSignBits.  This restriction can be lifted +    // if needed. +    const DataLayout &DL = getDataLayout(); +    if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { +      // For a SCEVUnknown, ask ValueTracking. +      KnownBits Known = computeKnownBits(U->getValue(), DL, 0, &AC, nullptr, &DT); +      if (Known.One != ~Known.Zero + 1) +        ConservativeResult = +            ConservativeResult.intersectWith(ConstantRange(Known.One, +                                                           ~Known.Zero + 1)); +    } else { +      assert(SignHint == ScalarEvolution::HINT_RANGE_SIGNED && +             "generalize as needed!"); +      unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, &AC, nullptr, &DT); +      if (NS > 1) +        ConservativeResult = ConservativeResult.intersectWith( +            ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1), +                          APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1)); +    } + +    // A range of Phi is a subset of union of all ranges of its input. +    if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) { +      // Make sure that we do not run over cycled Phis. 
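+      // If this phi is already being processed further up the recursion,
+      // skip it; otherwise the mutually recursive getRangeRef calls between
+      // the phi and its operands would never terminate.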
+      if (PendingPhiRanges.insert(Phi).second) {
+        ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
+        for (auto &Op : Phi->operands()) {
+          auto OpRange = getRangeRef(getSCEV(Op), SignHint);
+          RangeFromOps = RangeFromOps.unionWith(OpRange);
+          // No point in continuing if we already have a full set.
+          if (RangeFromOps.isFullSet())
+            break;
+        }
+        ConservativeResult = ConservativeResult.intersectWith(RangeFromOps);
+        bool Erased = PendingPhiRanges.erase(Phi);
+        assert(Erased && "Failed to erase Phi properly?");
+        (void) Erased;
+      }
+    }
+
+    return setRange(U, SignHint, std::move(ConservativeResult));
+  }
+
+  return setRange(S, SignHint, std::move(ConservativeResult));
+}
+
+// Given a StartRange, Step and MaxBECount for an expression, compute a range
+// of values that the expression can take. Initially, the expression has a
+// value from StartRange and then is changed by Step up to MaxBECount times.
+// The Signed argument defines whether we treat Step as signed or unsigned.
+static ConstantRange getRangeForAffineARHelper(APInt Step,
+                                               const ConstantRange &StartRange,
+                                               const APInt &MaxBECount,
+                                               unsigned BitWidth, bool Signed) {
+  // If either Step or MaxBECount is 0, then the expression won't change, and we
+  // just need to return the initial range.
+  if (Step == 0 || MaxBECount == 0)
+    return StartRange;
+
+  // If we don't know anything about the initial value (i.e. StartRange is
+  // FullRange), then we don't know anything about the final range either.
+  // Return FullRange.
+  if (StartRange.isFullSet())
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  // If Step is signed and negative, then we use its absolute value, but we
+  // also note that we're moving in the opposite direction.
+  bool Descending = Signed && Step.isNegative();
+
+  if (Signed)
+    // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
+    // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
+    // These equations hold true due to the well-defined wrap-around behavior
+    // of APInt.
+    Step = Step.abs();
+
+  // Check if Offset is more than the full span of BitWidth. If it is, the
+  // expression is guaranteed to overflow.
+  if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
+    return ConstantRange(BitWidth, /* isFullSet = */ true);
+
+  // Offset is by how much the expression can change. Checks above guarantee no
+  // overflow here.
+  APInt Offset = Step * MaxBECount;
+
+  // Minimum value of the final range will match the minimal value of
+  // StartRange if the expression is increasing and will be decreased by Offset
+  // otherwise. Maximum value of the final range will match the maximal value
+  // of StartRange if the expression is decreasing and will be increased by
+  // Offset otherwise.
+  APInt StartLower = StartRange.getLower();
+  APInt StartUpper = StartRange.getUpper() - 1;
+  APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
+                                   : (StartUpper + std::move(Offset));
+
+  // It's possible that the new minimum/maximum value will fall into the
+  // initial range (due to wrap around). This means that the expression can
+  // take any value in this bitwidth, and we have to return full range.
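+  // For example, for an i8 expression with StartRange [0, 100) and Offset 170,
+  // the moved upper bound is 99 + 170 = 269, which wraps to 13 and lands back
+  // inside [0, 100), so all we can report is the full range.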
+  if (StartRange.contains(MovedBoundary)) +    return ConstantRange(BitWidth, /* isFullSet = */ true); + +  APInt NewLower = +      Descending ? std::move(MovedBoundary) : std::move(StartLower); +  APInt NewUpper = +      Descending ? std::move(StartUpper) : std::move(MovedBoundary); +  NewUpper += 1; + +  // If we end up with full range, return a proper full range. +  if (NewLower == NewUpper) +    return ConstantRange(BitWidth, /* isFullSet = */ true); + +  // No overflow detected, return [StartLower, StartUpper + Offset + 1) range. +  return ConstantRange(std::move(NewLower), std::move(NewUpper)); +} + +ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, +                                                   const SCEV *Step, +                                                   const SCEV *MaxBECount, +                                                   unsigned BitWidth) { +  assert(!isa<SCEVCouldNotCompute>(MaxBECount) && +         getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && +         "Precondition!"); + +  MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); +  APInt MaxBECountValue = getUnsignedRangeMax(MaxBECount); + +  // First, consider step signed. +  ConstantRange StartSRange = getSignedRange(Start); +  ConstantRange StepSRange = getSignedRange(Step); + +  // If Step can be both positive and negative, we need to find ranges for the +  // maximum absolute step values in both directions and union them. +  ConstantRange SR = +      getRangeForAffineARHelper(StepSRange.getSignedMin(), StartSRange, +                                MaxBECountValue, BitWidth, /* Signed = */ true); +  SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(), +                                              StartSRange, MaxBECountValue, +                                              BitWidth, /* Signed = */ true)); + +  // Next, consider step unsigned. +  ConstantRange UR = getRangeForAffineARHelper( +      getUnsignedRangeMax(Step), getUnsignedRange(Start), +      MaxBECountValue, BitWidth, /* Signed = */ false); + +  // Finally, intersect signed and unsigned ranges. +  return SR.intersectWith(UR); +} + +ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, +                                                    const SCEV *Step, +                                                    const SCEV *MaxBECount, +                                                    unsigned BitWidth) { +  //    RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) +  // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) + +  struct SelectPattern { +    Value *Condition = nullptr; +    APInt TrueValue; +    APInt FalseValue; + +    explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, +                           const SCEV *S) { +      Optional<unsigned> CastOp; +      APInt Offset(BitWidth, 0); + +      assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && +             "Should be!"); + +      // Peel off a constant offset: +      if (auto *SA = dyn_cast<SCEVAddExpr>(S)) { +        // In the future we could consider being smarter here and handle +        // {Start+Step,+,Step} too. 
+        if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0))) +          return; + +        Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt(); +        S = SA->getOperand(1); +      } + +      // Peel off a cast operation +      if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) { +        CastOp = SCast->getSCEVType(); +        S = SCast->getOperand(); +      } + +      using namespace llvm::PatternMatch; + +      auto *SU = dyn_cast<SCEVUnknown>(S); +      const APInt *TrueVal, *FalseVal; +      if (!SU || +          !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), +                                          m_APInt(FalseVal)))) { +        Condition = nullptr; +        return; +      } + +      TrueValue = *TrueVal; +      FalseValue = *FalseVal; + +      // Re-apply the cast we peeled off earlier +      if (CastOp.hasValue()) +        switch (*CastOp) { +        default: +          llvm_unreachable("Unknown SCEV cast type!"); + +        case scTruncate: +          TrueValue = TrueValue.trunc(BitWidth); +          FalseValue = FalseValue.trunc(BitWidth); +          break; +        case scZeroExtend: +          TrueValue = TrueValue.zext(BitWidth); +          FalseValue = FalseValue.zext(BitWidth); +          break; +        case scSignExtend: +          TrueValue = TrueValue.sext(BitWidth); +          FalseValue = FalseValue.sext(BitWidth); +          break; +        } + +      // Re-apply the constant offset we peeled off earlier +      TrueValue += Offset; +      FalseValue += Offset; +    } + +    bool isRecognized() { return Condition != nullptr; } +  }; + +  SelectPattern StartPattern(*this, BitWidth, Start); +  if (!StartPattern.isRecognized()) +    return ConstantRange(BitWidth, /* isFullSet = */ true); + +  SelectPattern StepPattern(*this, BitWidth, Step); +  if (!StepPattern.isRecognized()) +    return ConstantRange(BitWidth, /* isFullSet = */ true); + +  if (StartPattern.Condition != StepPattern.Condition) { +    // We don't handle this case today; but we could, by considering four +    // possibilities below instead of two. I'm not sure if there are cases where +    // that will help over what getRange already does, though. +    return ConstantRange(BitWidth, /* isFullSet = */ true); +  } + +  // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to +  // construct arbitrary general SCEV expressions here.  This function is called +  // from deep in the call stack, and calling getSCEV (on a sext instruction, +  // say) can end up caching a suboptimal value. + +  // FIXME: without the explicit `this` receiver below, MSVC errors out with +  // C2352 and C2512 (otherwise it isn't needed). + +  const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); +  const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); +  const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); +  const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); + +  ConstantRange TrueRange = +      this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); +  ConstantRange FalseRange = +      this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); + +  return TrueRange.unionWith(FalseRange); +} + +SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { +  if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap; +  const BinaryOperator *BinOp = cast<BinaryOperator>(V); + +  // Return early if there are no flags to propagate to the SCEV. 
+  SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; +  if (BinOp->hasNoUnsignedWrap()) +    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); +  if (BinOp->hasNoSignedWrap()) +    Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); +  if (Flags == SCEV::FlagAnyWrap) +    return SCEV::FlagAnyWrap; + +  return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; +} + +bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { +  // Here we check that I is in the header of the innermost loop containing I, +  // since we only deal with instructions in the loop header. The actual loop we +  // need to check later will come from an add recurrence, but getting that +  // requires computing the SCEV of the operands, which can be expensive. This +  // check we can do cheaply to rule out some cases early. +  Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent()); +  if (InnermostContainingLoop == nullptr || +      InnermostContainingLoop->getHeader() != I->getParent()) +    return false; + +  // Only proceed if we can prove that I does not yield poison. +  if (!programUndefinedIfFullPoison(I)) +    return false; + +  // At this point we know that if I is executed, then it does not wrap +  // according to at least one of NSW or NUW. If I is not executed, then we do +  // not know if the calculation that I represents would wrap. Multiple +  // instructions can map to the same SCEV. If we apply NSW or NUW from I to +  // the SCEV, we must guarantee no wrapping for that SCEV also when it is +  // derived from other instructions that map to the same SCEV. We cannot make +  // that guarantee for cases where I is not executed. So we need to find the +  // loop that I is considered in relation to and prove that I is executed for +  // every iteration of that loop. That implies that the value that I +  // calculates does not wrap anywhere in the loop, so then we can apply the +  // flags to the SCEV. +  // +  // We check isLoopInvariant to disambiguate in case we are adding recurrences +  // from different loops, so that we know which loop to prove that I is +  // executed in. +  for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) { +    // I could be an extractvalue from a call to an overflow intrinsic. +    // TODO: We can do better here in some cases. +    if (!isSCEVable(I->getOperand(OpIndex)->getType())) +      return false; +    const SCEV *Op = getSCEV(I->getOperand(OpIndex)); +    if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { +      bool AllOtherOpsLoopInvariant = true; +      for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands(); +           ++OtherOpIndex) { +        if (OtherOpIndex != OpIndex) { +          const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex)); +          if (!isLoopInvariant(OtherOp, AddRec->getLoop())) { +            AllOtherOpsLoopInvariant = false; +            break; +          } +        } +      } +      if (AllOtherOpsLoopInvariant && +          isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop())) +        return true; +    } +  } +  return false; +} + +bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { +  // If we know that \c I can never be poison period, then that's enough. 
+  if (isSCEVExprNeverPoison(I))
+    return true;
+
+  // For an add recurrence specifically, we assume that infinite loops without
+  // side effects are undefined behavior, and then reason as follows:
+  //
+  // If the add recurrence is poison in any iteration, it is poison on all
+  // future iterations (since incrementing poison yields poison). If the result
+  // of the add recurrence is fed into the loop latch condition and the loop
+  // does not contain any throws or exiting blocks other than the latch, we now
+  // have the ability to "choose" whether the backedge is taken or not (by
+  // choosing a sufficiently evil value for the poison feeding into the branch)
+  // for every iteration including and after the one in which \p I first became
+  // poison.  There are two possibilities (let's call the iteration in which \p
+  // I first became poison K):
+  //
+  //  1. In the set of iterations including and after K, the loop body executes
+  //     no side effects.  In this case executing the backedge an infinite
+  //     number of times will yield undefined behavior.
+  //
+  //  2. In the set of iterations including and after K, the loop body executes
+  //     at least one side effect.  In this case, that specific instance of side
+  //     effect is control dependent on poison, which also yields undefined
+  //     behavior.
+
+  auto *ExitingBB = L->getExitingBlock();
+  auto *LatchBB = L->getLoopLatch();
+  if (!ExitingBB || !LatchBB || ExitingBB != LatchBB)
+    return false;
+
+  SmallPtrSet<const Instruction *, 16> Pushed;
+  SmallVector<const Instruction *, 8> PoisonStack;
+
+  // We start by assuming \c I, the post-inc add recurrence, is poison.  Only
+  // things that are known to be fully poison under that assumption go on the
+  // PoisonStack.
+  Pushed.insert(I);
+  PoisonStack.push_back(I);
+
+  bool LatchControlDependentOnPoison = false;
+  while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
+    const Instruction *Poison = PoisonStack.pop_back_val();
+
+    for (auto *PoisonUser : Poison->users()) {
+      if (propagatesFullPoison(cast<Instruction>(PoisonUser))) {
+        if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
+          PoisonStack.push_back(cast<Instruction>(PoisonUser));
+      } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
+        assert(BI->isConditional() && "Only possibility!");
+        if (BI->getParent() == LatchBB) {
+          LatchControlDependentOnPoison = true;
+          break;
+        }
+      }
+    }
+  }
+
+  return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);
+}
+
+ScalarEvolution::LoopProperties
+ScalarEvolution::getLoopProperties(const Loop *L) {
+  using LoopProperties = ScalarEvolution::LoopProperties;
+
+  auto Itr = LoopPropertiesCache.find(L);
+  if (Itr == LoopPropertiesCache.end()) {
+    auto HasSideEffects = [](Instruction *I) {
+      if (auto *SI = dyn_cast<StoreInst>(I))
+        return !SI->isSimple();
+
+      return I->mayHaveSideEffects();
+    };
+
+    LoopProperties LP = {/* HasNoAbnormalExits */ true,
+                         /*HasNoSideEffects*/ true};
+
+    for (auto *BB : L->getBlocks())
+      for (auto &I : *BB) {
+        if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+          LP.HasNoAbnormalExits = false;
+        if (HasSideEffects(&I))
+          LP.HasNoSideEffects = false;
+        if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
+          break; // We're already as pessimistic as we can get.
+      } + +    auto InsertPair = LoopPropertiesCache.insert({L, LP}); +    assert(InsertPair.second && "We just checked!"); +    Itr = InsertPair.first; +  } + +  return Itr->second; +} + +const SCEV *ScalarEvolution::createSCEV(Value *V) { +  if (!isSCEVable(V->getType())) +    return getUnknown(V); + +  if (Instruction *I = dyn_cast<Instruction>(V)) { +    // Don't attempt to analyze instructions in blocks that aren't +    // reachable. Such instructions don't matter, and they aren't required +    // to obey basic rules for definitions dominating uses which this +    // analysis depends on. +    if (!DT.isReachableFromEntry(I->getParent())) +      return getUnknown(V); +  } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) +    return getConstant(CI); +  else if (isa<ConstantPointerNull>(V)) +    return getZero(V->getType()); +  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) +    return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); +  else if (!isa<ConstantExpr>(V)) +    return getUnknown(V); + +  Operator *U = cast<Operator>(V); +  if (auto BO = MatchBinaryOp(U, DT)) { +    switch (BO->Opcode) { +    case Instruction::Add: { +      // The simple thing to do would be to just call getSCEV on both operands +      // and call getAddExpr with the result. However if we're looking at a +      // bunch of things all added together, this can be quite inefficient, +      // because it leads to N-1 getAddExpr calls for N ultimate operands. +      // Instead, gather up all the operands and make a single getAddExpr call. +      // LLVM IR canonical form means we need only traverse the left operands. +      SmallVector<const SCEV *, 4> AddOps; +      do { +        if (BO->Op) { +          if (auto *OpSCEV = getExistingSCEV(BO->Op)) { +            AddOps.push_back(OpSCEV); +            break; +          } + +          // If a NUW or NSW flag can be applied to the SCEV for this +          // addition, then compute the SCEV for this addition by itself +          // with a separate call to getAddExpr. We need to do that +          // instead of pushing the operands of the addition onto AddOps, +          // since the flags are only known to apply to this particular +          // addition - they may not apply to other additions that can be +          // formed with operands from AddOps. 
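+          // Illustrative example (not from the original source): for
+          //   ((a + b) + c) + d
+          // with no wrap flags involved, the loop below pushes d, c, b and
+          // finally a onto AddOps and issues one getAddExpr({a, b, c, d})
+          // call instead of three nested ones.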
+          const SCEV *RHS = getSCEV(BO->RHS); +          SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); +          if (Flags != SCEV::FlagAnyWrap) { +            const SCEV *LHS = getSCEV(BO->LHS); +            if (BO->Opcode == Instruction::Sub) +              AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); +            else +              AddOps.push_back(getAddExpr(LHS, RHS, Flags)); +            break; +          } +        } + +        if (BO->Opcode == Instruction::Sub) +          AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); +        else +          AddOps.push_back(getSCEV(BO->RHS)); + +        auto NewBO = MatchBinaryOp(BO->LHS, DT); +        if (!NewBO || (NewBO->Opcode != Instruction::Add && +                       NewBO->Opcode != Instruction::Sub)) { +          AddOps.push_back(getSCEV(BO->LHS)); +          break; +        } +        BO = NewBO; +      } while (true); + +      return getAddExpr(AddOps); +    } + +    case Instruction::Mul: { +      SmallVector<const SCEV *, 4> MulOps; +      do { +        if (BO->Op) { +          if (auto *OpSCEV = getExistingSCEV(BO->Op)) { +            MulOps.push_back(OpSCEV); +            break; +          } + +          SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); +          if (Flags != SCEV::FlagAnyWrap) { +            MulOps.push_back( +                getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); +            break; +          } +        } + +        MulOps.push_back(getSCEV(BO->RHS)); +        auto NewBO = MatchBinaryOp(BO->LHS, DT); +        if (!NewBO || NewBO->Opcode != Instruction::Mul) { +          MulOps.push_back(getSCEV(BO->LHS)); +          break; +        } +        BO = NewBO; +      } while (true); + +      return getMulExpr(MulOps); +    } +    case Instruction::UDiv: +      return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); +    case Instruction::URem: +      return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); +    case Instruction::Sub: { +      SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; +      if (BO->Op) +        Flags = getNoWrapFlagsFromUB(BO->Op); +      return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); +    } +    case Instruction::And: +      // For an expression like x&255 that merely masks off the high bits, +      // use zext(trunc(x)) as the SCEV expression. +      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { +        if (CI->isZero()) +          return getSCEV(BO->RHS); +        if (CI->isMinusOne()) +          return getSCEV(BO->LHS); +        const APInt &A = CI->getValue(); + +        // Instcombine's ShrinkDemandedConstant may strip bits out of +        // constants, obscuring what would otherwise be a low-bits mask. +        // Use computeKnownBits to compute what ShrinkDemandedConstant +        // knew about to reconstruct a low-bits mask value. 
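+        // Illustrative example (not from the original source): for
+        //   %r = and i8 %x, 120    ; 0x78 == 0b01111000, so LZ = 1, TZ = 3
+        // the code below models %r as
+        //   (zext i4 (trunc (%x /u 8) to i4) to i8) * 8,
+        // i.e. a zext(trunc) of the middle bits scaled back into position.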
+        unsigned LZ = A.countLeadingZeros(); +        unsigned TZ = A.countTrailingZeros(); +        unsigned BitWidth = A.getBitWidth(); +        KnownBits Known(BitWidth); +        computeKnownBits(BO->LHS, Known, getDataLayout(), +                         0, &AC, nullptr, &DT); + +        APInt EffectiveMask = +            APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); +        if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) { +          const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ)); +          const SCEV *LHS = getSCEV(BO->LHS); +          const SCEV *ShiftedLHS = nullptr; +          if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) { +            if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) { +              // For an expression like (x * 8) & 8, simplify the multiply. +              unsigned MulZeros = OpC->getAPInt().countTrailingZeros(); +              unsigned GCD = std::min(MulZeros, TZ); +              APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD); +              SmallVector<const SCEV*, 4> MulOps; +              MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD))); +              MulOps.append(LHSMul->op_begin() + 1, LHSMul->op_end()); +              auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags()); +              ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt)); +            } +          } +          if (!ShiftedLHS) +            ShiftedLHS = getUDivExpr(LHS, MulCount); +          return getMulExpr( +              getZeroExtendExpr( +                  getTruncateExpr(ShiftedLHS, +                      IntegerType::get(getContext(), BitWidth - LZ - TZ)), +                  BO->LHS->getType()), +              MulCount); +        } +      } +      break; + +    case Instruction::Or: +      // If the RHS of the Or is a constant, we may have something like: +      // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop +      // optimizations will transparently handle this case. +      // +      // In order for this transformation to be safe, the LHS must be of the +      // form X*(2^n) and the Or constant must be less than 2^n. +      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { +        const SCEV *LHS = getSCEV(BO->LHS); +        const APInt &CIVal = CI->getValue(); +        if (GetMinTrailingZeros(LHS) >= +            (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { +          // Build a plain add SCEV. +          const SCEV *S = getAddExpr(LHS, getSCEV(CI)); +          // If the LHS of the add was an addrec and it has no-wrap flags, +          // transfer the no-wrap flags, since an or won't introduce a wrap. +          if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { +            const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); +            const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( +                OldAR->getNoWrapFlags()); +          } +          return S; +        } +      } +      break; + +    case Instruction::Xor: +      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { +        // If the RHS of xor is -1, then this is a not operation. +        if (CI->isMinusOne()) +          return getNotSCEV(getSCEV(BO->LHS)); + +        // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. +        // This is a variant of the check for xor with -1, and it handles +        // the case where instcombine has trimmed non-demanded bits out +        // of an xor with -1. 
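+        // Illustrative example (not from the original source): given
+        //   %m = and i32 %x, 255
+        //   %r = xor i32 %m, 255
+        // where %m is modeled as (zext i8 (trunc %x) to i32), the code below
+        // recognizes 255 as a low-bits mask and models %r as
+        //   (zext i8 (~(trunc %x)) to i32).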
+        if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) +          if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) +            if (LBO->getOpcode() == Instruction::And && +                LCI->getValue() == CI->getValue()) +              if (const SCEVZeroExtendExpr *Z = +                      dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { +                Type *UTy = BO->LHS->getType(); +                const SCEV *Z0 = Z->getOperand(); +                Type *Z0Ty = Z0->getType(); +                unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + +                // If C is a low-bits mask, the zero extend is serving to +                // mask off the high bits. Complement the operand and +                // re-apply the zext. +                if (CI->getValue().isMask(Z0TySize)) +                  return getZeroExtendExpr(getNotSCEV(Z0), UTy); + +                // If C is a single bit, it may be in the sign-bit position +                // before the zero-extend. In this case, represent the xor +                // using an add, which is equivalent, and re-apply the zext. +                APInt Trunc = CI->getValue().trunc(Z0TySize); +                if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && +                    Trunc.isSignMask()) +                  return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), +                                           UTy); +              } +      } +      break; + +    case Instruction::Shl: +      // Turn shift left of a constant amount into a multiply. +      if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) { +        uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); + +        // If the shift count is not less than the bitwidth, the result of +        // the shift is undefined. Don't try to analyze it, because the +        // resolution chosen here may differ from the resolution chosen in +        // other parts of the compiler. +        if (SA->getValue().uge(BitWidth)) +          break; + +        // It is currently not resolved how to interpret NSW for left +        // shift by BitWidth - 1, so we avoid applying flags in that +        // case. Remove this check (or this comment) once the situation +        // is resolved. See +        // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html +        // and http://reviews.llvm.org/D8890 . +        auto Flags = SCEV::FlagAnyWrap; +        if (BO->Op && SA->getValue().ult(BitWidth - 1)) +          Flags = getNoWrapFlagsFromUB(BO->Op); + +        Constant *X = ConstantInt::get( +            getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); +        return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); +      } +      break; + +    case Instruction::AShr: { +      // AShr X, C, where C is a constant. +      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); +      if (!CI) +        break; + +      Type *OuterTy = BO->LHS->getType(); +      uint64_t BitWidth = getTypeSizeInBits(OuterTy); +      // If the shift count is not less than the bitwidth, the result of +      // the shift is undefined. Don't try to analyze it, because the +      // resolution chosen here may differ from the resolution chosen in +      // other parts of the compiler. 
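+      // Illustrative example (not from the original source): past the guard
+      // below, the common sign-extend-in-register idiom
+      //   %s = shl i32 %x, 24
+      //   %r = ashr i32 %s, 24
+      // is recognized and modeled as (sext i8 (trunc %x) to i32).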
+      if (CI->getValue().uge(BitWidth)) +        break; + +      if (CI->isZero()) +        return getSCEV(BO->LHS); // shift by zero --> noop + +      uint64_t AShrAmt = CI->getZExtValue(); +      Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt); + +      Operator *L = dyn_cast<Operator>(BO->LHS); +      if (L && L->getOpcode() == Instruction::Shl) { +        // X = Shl A, n +        // Y = AShr X, m +        // Both n and m are constant. + +        const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0)); +        if (L->getOperand(1) == BO->RHS) +          // For a two-shift sext-inreg, i.e. n = m, +          // use sext(trunc(x)) as the SCEV expression. +          return getSignExtendExpr( +              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy); + +        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1)); +        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) { +          uint64_t ShlAmt = ShlAmtCI->getZExtValue(); +          if (ShlAmt > AShrAmt) { +            // When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV +            // expression. We already checked that ShlAmt < BitWidth, so +            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as +            // ShlAmt - AShrAmt < Amt. +            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, +                                            ShlAmt - AShrAmt); +            return getSignExtendExpr( +                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy), +                getConstant(Mul)), OuterTy); +          } +        } +      } +      break; +    } +    } +  } + +  switch (U->getOpcode()) { +  case Instruction::Trunc: +    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); + +  case Instruction::ZExt: +    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); + +  case Instruction::SExt: +    if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) { +      // The NSW flag of a subtract does not always survive the conversion to +      // A + (-1)*B.  By pushing sign extension onto its operands we are much +      // more likely to preserve NSW and allow later AddRec optimisations. +      // +      // NOTE: This is effectively duplicating this logic from getSignExtend: +      //   sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> +      // but by that point the NSW information has potentially been lost. +      if (BO->Opcode == Instruction::Sub && BO->IsNSW) { +        Type *Ty = U->getType(); +        auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); +        auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); +        return getMinusSCEV(V1, V2, SCEV::FlagNSW); +      } +    } +    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); + +  case Instruction::BitCast: +    // BitCasts are no-op casts so we just eliminate the cast. +    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) +      return getSCEV(U->getOperand(0)); +    break; + +  // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can +  // lead to pointer expressions which cannot safely be expanded to GEPs, +  // because ScalarEvolution doesn't respect the GEP aliasing rules when +  // simplifying integer expressions. + +  case Instruction::GetElementPtr: +    return createNodeForGEP(cast<GEPOperator>(U)); + +  case Instruction::PHI: +    return createNodeForPHI(cast<PHINode>(U)); + +  case Instruction::Select: +    // U can also be a select constant expr, which let fall through.  
Since +    // createNodeForSelect only works for a condition that is an `ICmpInst`, and +    // constant expressions cannot have instructions as operands, we'd have +    // returned getUnknown for a select constant expressions anyway. +    if (isa<Instruction>(U)) +      return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0), +                                      U->getOperand(1), U->getOperand(2)); +    break; + +  case Instruction::Call: +  case Instruction::Invoke: +    if (Value *RV = CallSite(U).getReturnedArgOperand()) +      return getSCEV(RV); +    break; +  } + +  return getUnknown(V); +} + +//===----------------------------------------------------------------------===// +//                   Iteration Count Computation Code +// + +static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { +  if (!ExitCount) +    return 0; + +  ConstantInt *ExitConst = ExitCount->getValue(); + +  // Guard against huge trip counts. +  if (ExitConst->getValue().getActiveBits() > 32) +    return 0; + +  // In case of integer overflow, this returns 0, which is correct. +  return ((unsigned)ExitConst->getZExtValue()) + 1; +} + +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) { +  if (BasicBlock *ExitingBB = L->getExitingBlock()) +    return getSmallConstantTripCount(L, ExitingBB); + +  // No trip count information for multiple exits. +  return 0; +} + +unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L, +                                                    BasicBlock *ExitingBlock) { +  assert(ExitingBlock && "Must pass a non-null exiting block!"); +  assert(L->isLoopExiting(ExitingBlock) && +         "Exiting block must actually branch out of the loop!"); +  const SCEVConstant *ExitCount = +      dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock)); +  return getConstantTripCount(ExitCount); +} + +unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) { +  const auto *MaxExitCount = +      dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L)); +  return getConstantTripCount(MaxExitCount); +} + +unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { +  if (BasicBlock *ExitingBB = L->getExitingBlock()) +    return getSmallConstantTripMultiple(L, ExitingBB); + +  // No trip multiple information for multiple exits. +  return 0; +} + +/// Returns the largest constant divisor of the trip count of this loop as a +/// normal unsigned value, if possible. This means that the actual trip count is +/// always a multiple of the returned value (don't forget the trip count could +/// very well be zero as well!). +/// +/// Returns 1 if the trip count is unknown or not guaranteed to be the +/// multiple of a constant (which is also the case if the trip count is simply +/// constant, use getSmallConstantTripCount for that case), Will also return 1 +/// if the trip count is very large (>= 2^32). +/// +/// As explained in the comments for getSmallConstantTripCount, this assumes +/// that control exits the loop via ExitingBlock. +unsigned +ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, +                                              BasicBlock *ExitingBlock) { +  assert(ExitingBlock && "Must pass a non-null exiting block!"); +  assert(L->isLoopExiting(ExitingBlock) && +         "Exiting block must actually branch out of the loop!"); +  const SCEV *ExitCount = getExitCount(L, ExitingBlock); +  if (ExitCount == getCouldNotCompute()) +    return 1; + +  // Get the trip count from the BE count by adding 1. 
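+  // Illustrative example (not from the original source): if the
+  // backedge-taken count is (4 * %n + 3), the trip count computed below is
+  // (4 * %n + 4), which has at least two trailing zero bits, so the returned
+  // trip multiple is 4.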
+  const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType())); + +  const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); +  if (!TC) +    // Attempt to factor more general cases. Returns the greatest power of +    // two divisor. If overflow happens, the trip count expression is still +    // divisible by the greatest power of 2 divisor returned. +    return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr)); + +  ConstantInt *Result = TC->getValue(); + +  // Guard against huge trip counts (this requires checking +  // for zero to handle the case where the trip count == -1 and the +  // addition wraps). +  if (!Result || Result->getValue().getActiveBits() > 32 || +      Result->getValue().getActiveBits() == 0) +    return 1; + +  return (unsigned)Result->getZExtValue(); +} + +/// Get the expression for the number of loop iterations for which this loop is +/// guaranteed not to exit via ExitingBlock. Otherwise return +/// SCEVCouldNotCompute. +const SCEV *ScalarEvolution::getExitCount(const Loop *L, +                                          BasicBlock *ExitingBlock) { +  return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); +} + +const SCEV * +ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, +                                                 SCEVUnionPredicate &Preds) { +  return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds); +} + +const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { +  return getBackedgeTakenInfo(L).getExact(L, this); +} + +/// Similar to getBackedgeTakenCount, except return the least SCEV value that is +/// known never to be less than the actual backedge taken count. +const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { +  return getBackedgeTakenInfo(L).getMax(this); +} + +bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) { +  return getBackedgeTakenInfo(L).isMaxOrZero(this); +} + +/// Push PHI nodes in the header of the given loop onto the given Worklist. +static void +PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { +  BasicBlock *Header = L->getHeader(); + +  // Push all Loop-header PHIs onto the Worklist stack. +  for (PHINode &PN : Header->phis()) +    Worklist.push_back(&PN); +} + +const ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { +  auto &BTI = getBackedgeTakenInfo(L); +  if (BTI.hasFullInfo()) +    return BTI; + +  auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); + +  if (!Pair.second) +    return Pair.first->second; + +  BackedgeTakenInfo Result = +      computeBackedgeTakenCount(L, /*AllowPredicates=*/true); + +  return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); +} + +const ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { +  // Initially insert an invalid entry for this loop. If the insertion +  // succeeds, proceed to actually compute a backedge-taken count and +  // update the value. The temporary CouldNotCompute value tells SCEV +  // code elsewhere that it shouldn't attempt to request a new +  // backedge-taken count, which could result in infinite recursion. +  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair = +      BackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); +  if (!Pair.second) +    return Pair.first->second; + +  // computeBackedgeTakenCount may allocate memory for its result. 
Inserting it
+  // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
+  // must be cleared in this scope.
+  BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
+
+  // In a product build with statistics disabled, these counters are unused.
+  (void)NumTripCountsComputed;
+  (void)NumTripCountsNotComputed;
+#if LLVM_ENABLE_STATS || !defined(NDEBUG)
+  const SCEV *BEExact = Result.getExact(L, this);
+  if (BEExact != getCouldNotCompute()) {
+    assert(isLoopInvariant(BEExact, L) &&
+           isLoopInvariant(Result.getMax(this), L) &&
+           "Computed backedge-taken count isn't loop invariant for loop!");
+    ++NumTripCountsComputed;
+  }
+  else if (Result.getMax(this) == getCouldNotCompute() &&
+           isa<PHINode>(L->getHeader()->begin())) {
+    // Only count loops that have phi nodes as not being computable.
+    ++NumTripCountsNotComputed;
+  }
+#endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
+
+  // Now that we know more about the trip count for this loop, forget any
+  // existing SCEV values for PHI nodes in this loop since they are only
+  // conservative estimates made without the benefit of trip count
+  // information. This is similar to the code in forgetLoop, except that
+  // it handles SCEVUnknown PHI nodes specially.
+  if (Result.hasAnyInfo()) {
+    SmallVector<Instruction *, 16> Worklist;
+    PushLoopPHIs(L, Worklist);
+
+    SmallPtrSet<Instruction *, 8> Discovered;
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.pop_back_val();
+
+      ValueExprMapType::iterator It =
+        ValueExprMap.find_as(static_cast<Value *>(I));
+      if (It != ValueExprMap.end()) {
+        const SCEV *Old = It->second;
+
+        // SCEVUnknown for a PHI either means that it has an unrecognized
+        // structure, or it's a PHI that's in the process of being computed
+        // by createNodeForPHI.  In the former case, additional loop trip
+        // count information isn't going to change anything. In the latter
+        // case, createNodeForPHI will perform the necessary updates on its
+        // own when it gets to that point.
+        if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
+          eraseValueFromMap(It->first);
+          forgetMemoizedResults(Old);
+        }
+        if (PHINode *PN = dyn_cast<PHINode>(I))
+          ConstantEvolutionLoopExitValue.erase(PN);
+      }
+
+      // Since we don't need to invalidate anything for correctness and we're
+      // only invalidating to make SCEV's results more precise, we get to stop
+      // early to avoid invalidating too much.  This is especially important in
+      // cases like:
+      //
+      //   %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
+      // loop0:
+      //   %pn0 = phi
+      //   ...
+      // loop1:
+      //   %pn1 = phi
+      //   ...
+      //
+      // where both loop0's and loop1's backedge taken counts use the SCEV
+      // expression for %v.  If we don't have the early stop below, then in
+      // cases like the above, getBackedgeTakenInfo(loop1) will clear out the
+      // trip count for loop0 and getBackedgeTakenInfo(loop0) will clear out
+      // the trip count for loop1, effectively nullifying SCEV's trip count
+      // cache.
+      for (auto *U : I->users())
+        if (auto *I = dyn_cast<Instruction>(U)) {
+          auto *LoopForUser = LI.getLoopFor(I->getParent());
+          if (LoopForUser && L->contains(LoopForUser) &&
+              Discovered.insert(I).second)
+            Worklist.push_back(I);
+        }
+    }
+  }
+
+  // Re-lookup the insert position, since the call to
+  // computeBackedgeTakenCount above could result in a
+  // recursive call to getBackedgeTakenInfo (on a different
+  // loop), which would invalidate the iterator computed
+  // earlier.
+  return BackedgeTakenCounts.find(L)->second = std::move(Result);
+}
+
+void ScalarEvolution::forgetLoop(const Loop *L) {
+  // Drop any stored trip count value.
+  auto RemoveLoopFromBackedgeMap =
+      [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) {
+        auto BTCPos = Map.find(L);
+        if (BTCPos != Map.end()) {
+          BTCPos->second.clear();
+          Map.erase(BTCPos);
+        }
+      };
+
+  SmallVector<const Loop *, 16> LoopWorklist(1, L);
+  SmallVector<Instruction *, 32> Worklist;
+  SmallPtrSet<Instruction *, 16> Visited;
+
+  // Iterate over all the loops and sub-loops to drop SCEV information.
+  while (!LoopWorklist.empty()) {
+    auto *CurrL = LoopWorklist.pop_back_val();
+
+    RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL);
+    RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL);
+
+    // Drop information about predicated SCEV rewrites for this loop.
+    for (auto I = PredicatedSCEVRewrites.begin();
+         I != PredicatedSCEVRewrites.end();) {
+      std::pair<const SCEV *, const Loop *> Entry = I->first;
+      if (Entry.second == CurrL)
+        PredicatedSCEVRewrites.erase(I++);
+      else
+        ++I;
+    }
+
+    auto LoopUsersItr = LoopUsers.find(CurrL);
+    if (LoopUsersItr != LoopUsers.end()) {
+      for (auto *S : LoopUsersItr->second)
+        forgetMemoizedResults(S);
+      LoopUsers.erase(LoopUsersItr);
+    }
+
+    // Drop information about expressions based on loop-header PHIs.
+    PushLoopPHIs(CurrL, Worklist);
+
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.pop_back_val();
+      if (!Visited.insert(I).second)
+        continue;
+
+      ValueExprMapType::iterator It =
+          ValueExprMap.find_as(static_cast<Value *>(I));
+      if (It != ValueExprMap.end()) {
+        eraseValueFromMap(It->first);
+        forgetMemoizedResults(It->second);
+        if (PHINode *PN = dyn_cast<PHINode>(I))
+          ConstantEvolutionLoopExitValue.erase(PN);
+      }
+
+      PushDefUseChildren(I, Worklist);
+    }
+
+    LoopPropertiesCache.erase(CurrL);
+    // Forget all contained loops too, to avoid dangling entries in the
+    // ValuesAtScopes map.
+    LoopWorklist.append(CurrL->begin(), CurrL->end());
+  }
+}
+
+void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
+  while (Loop *Parent = L->getParentLoop())
+    L = Parent;
+  forgetLoop(L);
+}
+
+void ScalarEvolution::forgetValue(Value *V) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I) return;
+
+  // Drop information about expressions based on loop-header PHIs.
+  SmallVector<Instruction *, 16> Worklist; +  Worklist.push_back(I); + +  SmallPtrSet<Instruction *, 8> Visited; +  while (!Worklist.empty()) { +    I = Worklist.pop_back_val(); +    if (!Visited.insert(I).second) +      continue; + +    ValueExprMapType::iterator It = +      ValueExprMap.find_as(static_cast<Value *>(I)); +    if (It != ValueExprMap.end()) { +      eraseValueFromMap(It->first); +      forgetMemoizedResults(It->second); +      if (PHINode *PN = dyn_cast<PHINode>(I)) +        ConstantEvolutionLoopExitValue.erase(PN); +    } + +    PushDefUseChildren(I, Worklist); +  } +} + +/// Get the exact loop backedge taken count considering all loop exits. A +/// computable result can only be returned for loops with all exiting blocks +/// dominating the latch. howFarToZero assumes that the limit of each loop test +/// is never skipped. This is a valid assumption as long as the loop exits via +/// that test. For precise results, it is the caller's responsibility to specify +/// the relevant loop exiting block using getExact(ExitingBlock, SE). +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE, +                                             SCEVUnionPredicate *Preds) const { +  // If any exits were not computable, the loop is not computable. +  if (!isComplete() || ExitNotTaken.empty()) +    return SE->getCouldNotCompute(); + +  const BasicBlock *Latch = L->getLoopLatch(); +  // All exiting blocks we have collected must dominate the only backedge. +  if (!Latch) +    return SE->getCouldNotCompute(); + +  // All exiting blocks we have gathered dominate loop's latch, so exact trip +  // count is simply a minimum out of all these calculated exit counts. +  SmallVector<const SCEV *, 2> Ops; +  for (auto &ENT : ExitNotTaken) { +    const SCEV *BECount = ENT.ExactNotTaken; +    assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!"); +    assert(SE->DT.dominates(ENT.ExitingBlock, Latch) && +           "We should only have known counts for exiting blocks that dominate " +           "latch!"); + +    Ops.push_back(BECount); + +    if (Preds && !ENT.hasAlwaysTruePredicate()) +      Preds->add(ENT.Predicate.get()); + +    assert((Preds || ENT.hasAlwaysTruePredicate()) && +           "Predicate should be always true!"); +  } + +  return SE->getUMinFromMismatchedTypes(Ops); +} + +/// Get the exact not taken count for this loop exit. +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, +                                             ScalarEvolution *SE) const { +  for (auto &ENT : ExitNotTaken) +    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) +      return ENT.ExactNotTaken; + +  return SE->getCouldNotCompute(); +} + +/// getMax - Get the max backedge taken count for the loop. 
+const SCEV * +ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { +  auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { +    return !ENT.hasAlwaysTruePredicate(); +  }; + +  if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) +    return SE->getCouldNotCompute(); + +  assert((isa<SCEVCouldNotCompute>(getMax()) || isa<SCEVConstant>(getMax())) && +         "No point in having a non-constant max backedge taken count!"); +  return getMax(); +} + +bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const { +  auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { +    return !ENT.hasAlwaysTruePredicate(); +  }; +  return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); +} + +bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, +                                                    ScalarEvolution *SE) const { +  if (getMax() && getMax() != SE->getCouldNotCompute() && +      SE->hasOperand(getMax(), S)) +    return true; + +  for (auto &ENT : ExitNotTaken) +    if (ENT.ExactNotTaken != SE->getCouldNotCompute() && +        SE->hasOperand(ENT.ExactNotTaken, S)) +      return true; + +  return false; +} + +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) +    : ExactNotTaken(E), MaxNotTaken(E) { +  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || +          isa<SCEVConstant>(MaxNotTaken)) && +         "No point in having a non-constant max backedge taken count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit( +    const SCEV *E, const SCEV *M, bool MaxOrZero, +    ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList) +    : ExactNotTaken(E), MaxNotTaken(M), MaxOrZero(MaxOrZero) { +  assert((isa<SCEVCouldNotCompute>(ExactNotTaken) || +          !isa<SCEVCouldNotCompute>(MaxNotTaken)) && +         "Exact is not allowed to be less precise than Max"); +  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || +          isa<SCEVConstant>(MaxNotTaken)) && +         "No point in having a non-constant max backedge taken count!"); +  for (auto *PredSet : PredSetList) +    for (auto *P : *PredSet) +      addPredicate(P); +} + +ScalarEvolution::ExitLimit::ExitLimit( +    const SCEV *E, const SCEV *M, bool MaxOrZero, +    const SmallPtrSetImpl<const SCEVPredicate *> &PredSet) +    : ExitLimit(E, M, MaxOrZero, {&PredSet}) { +  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || +          isa<SCEVConstant>(MaxNotTaken)) && +         "No point in having a non-constant max backedge taken count!"); +} + +ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M, +                                      bool MaxOrZero) +    : ExitLimit(E, M, MaxOrZero, None) { +  assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || +          isa<SCEVConstant>(MaxNotTaken)) && +         "No point in having a non-constant max backedge taken count!"); +} + +/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each +/// computable exit into a persistent ExitNotTakenInfo array. 
+ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( +    SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> +        &&ExitCounts, +    bool Complete, const SCEV *MaxCount, bool MaxOrZero) +    : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) { +  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; + +  ExitNotTaken.reserve(ExitCounts.size()); +  std::transform( +      ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken), +      [&](const EdgeExitInfo &EEI) { +        BasicBlock *ExitBB = EEI.first; +        const ExitLimit &EL = EEI.second; +        if (EL.Predicates.empty()) +          return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr); + +        std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate); +        for (auto *Pred : EL.Predicates) +          Predicate->add(Pred); + +        return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); +      }); +  assert((isa<SCEVCouldNotCompute>(MaxCount) || isa<SCEVConstant>(MaxCount)) && +         "No point in having a non-constant max backedge taken count!"); +} + +/// Invalidate this result and free the ExitNotTakenInfo array. +void ScalarEvolution::BackedgeTakenInfo::clear() { +  ExitNotTaken.clear(); +} + +/// Compute the number of times the backedge of the specified loop will execute. +ScalarEvolution::BackedgeTakenInfo +ScalarEvolution::computeBackedgeTakenCount(const Loop *L, +                                           bool AllowPredicates) { +  SmallVector<BasicBlock *, 8> ExitingBlocks; +  L->getExitingBlocks(ExitingBlocks); + +  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; + +  SmallVector<EdgeExitInfo, 4> ExitCounts; +  bool CouldComputeBECount = true; +  BasicBlock *Latch = L->getLoopLatch(); // may be NULL. +  const SCEV *MustExitMaxBECount = nullptr; +  const SCEV *MayExitMaxBECount = nullptr; +  bool MustExitMaxOrZero = false; + +  // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts +  // and compute maxBECount. +  // Do a union of all the predicates here. +  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { +    BasicBlock *ExitBB = ExitingBlocks[i]; +    ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); + +    assert((AllowPredicates || EL.Predicates.empty()) && +           "Predicated exit limit when predicates are not allowed!"); + +    // 1. For each exit that can be computed, add an entry to ExitCounts. +    // CouldComputeBECount is true only if all exits can be computed. +    if (EL.ExactNotTaken == getCouldNotCompute()) +      // We couldn't compute an exact value for this exit, so +      // we won't be able to compute an exact value for the loop. +      CouldComputeBECount = false; +    else +      ExitCounts.emplace_back(ExitBB, EL); + +    // 2. Derive the loop's MaxBECount from each exit's max number of +    // non-exiting iterations. Partition the loop exits into two kinds: +    // LoopMustExits and LoopMayExits. +    // +    // If the exit dominates the loop latch, it is a LoopMustExit otherwise it +    // is a LoopMayExit.  If any computable LoopMustExit is found, then +    // MaxBECount is the minimum EL.MaxNotTaken of computable +    // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum +    // EL.MaxNotTaken, where CouldNotCompute is considered greater than any +    // computable EL.MaxNotTaken. 
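+    // Illustrative example (not from the original source): with two exits,
+    // one dominating the latch with MaxNotTaken = 10 and another that does
+    // not dominate it with MaxNotTaken = 100, the loop's MaxBECount becomes
+    // 10, since the dominating exit bounds every path that reaches the
+    // backedge.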
+    if (EL.MaxNotTaken != getCouldNotCompute() && Latch && +        DT.dominates(ExitBB, Latch)) { +      if (!MustExitMaxBECount) { +        MustExitMaxBECount = EL.MaxNotTaken; +        MustExitMaxOrZero = EL.MaxOrZero; +      } else { +        MustExitMaxBECount = +            getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken); +      } +    } else if (MayExitMaxBECount != getCouldNotCompute()) { +      if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute()) +        MayExitMaxBECount = EL.MaxNotTaken; +      else { +        MayExitMaxBECount = +            getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken); +      } +    } +  } +  const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount : +    (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); +  // The loop backedge will be taken the maximum or zero times if there's +  // a single exit that must be taken the maximum or zero times. +  bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); +  return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, +                           MaxBECount, MaxOrZero); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, +                                      bool AllowPredicates) { +  assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); +  // If our exiting block does not dominate the latch, then its connection with +  // loop's exit limit may be far from trivial. +  const BasicBlock *Latch = L->getLoopLatch(); +  if (!Latch || !DT.dominates(ExitingBlock, Latch)) +    return getCouldNotCompute(); + +  bool IsOnlyExit = (L->getExitingBlock() != nullptr); +  TerminatorInst *Term = ExitingBlock->getTerminator(); +  if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { +    assert(BI->isConditional() && "If unconditional, it can't be in loop!"); +    bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); +    assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) && +           "It should have one successor in loop and one exit block!"); +    // Proceed to the next level to examine the exit condition expression. +    return computeExitLimitFromCond( +        L, BI->getCondition(), ExitIfTrue, +        /*ControlsExit=*/IsOnlyExit, AllowPredicates); +  } + +  if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) { +    // For switch, make sure that there is a single exit from the loop. +    BasicBlock *Exit = nullptr; +    for (auto *SBB : successors(ExitingBlock)) +      if (!L->contains(SBB)) { +        if (Exit) // Multiple exit successors. 
+          return getCouldNotCompute(); +        Exit = SBB; +      } +    assert(Exit && "Exiting block must have at least one exit"); +    return computeExitLimitFromSingleExitSwitch(L, SI, Exit, +                                                /*ControlsExit=*/IsOnlyExit); +  } + +  return getCouldNotCompute(); +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond( +    const Loop *L, Value *ExitCond, bool ExitIfTrue, +    bool ControlsExit, bool AllowPredicates) { +  ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates); +  return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue, +                                        ControlsExit, AllowPredicates); +} + +Optional<ScalarEvolution::ExitLimit> +ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond, +                                      bool ExitIfTrue, bool ControlsExit, +                                      bool AllowPredicates) { +  (void)this->L; +  (void)this->ExitIfTrue; +  (void)this->AllowPredicates; + +  assert(this->L == L && this->ExitIfTrue == ExitIfTrue && +         this->AllowPredicates == AllowPredicates && +         "Variance in assumed invariant key components!"); +  auto Itr = TripCountMap.find({ExitCond, ControlsExit}); +  if (Itr == TripCountMap.end()) +    return None; +  return Itr->second; +} + +void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond, +                                             bool ExitIfTrue, +                                             bool ControlsExit, +                                             bool AllowPredicates, +                                             const ExitLimit &EL) { +  assert(this->L == L && this->ExitIfTrue == ExitIfTrue && +         this->AllowPredicates == AllowPredicates && +         "Variance in assumed invariant key components!"); + +  auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL}); +  assert(InsertResult.second && "Expected successful insertion!"); +  (void)InsertResult; +  (void)ExitIfTrue; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( +    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, +    bool ControlsExit, bool AllowPredicates) { + +  if (auto MaybeEL = +          Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) +    return *MaybeEL; + +  ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue, +                                              ControlsExit, AllowPredicates); +  Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL); +  return EL; +} + +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( +    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, +    bool ControlsExit, bool AllowPredicates) { +  // Check if the controlling expression for this loop is an And or Or. +  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { +    if (BO->getOpcode() == Instruction::And) { +      // Recurse on the operands of the and. 
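+      // Illustrative example (not from the original source): for
+      //   %c = and i1 %c1, %c2
+      //   br i1 %c, label %body, label %exit     ; ExitIfTrue is false
+      // the loop keeps running only while both %c1 and %c2 hold, so (when
+      // both sub-counts are computable) the combined backedge-taken count
+      // below is the umin of the two.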
+      bool EitherMayExit = !ExitIfTrue; +      ExitLimit EL0 = computeExitLimitFromCondCached( +          Cache, L, BO->getOperand(0), ExitIfTrue, +          ControlsExit && !EitherMayExit, AllowPredicates); +      ExitLimit EL1 = computeExitLimitFromCondCached( +          Cache, L, BO->getOperand(1), ExitIfTrue, +          ControlsExit && !EitherMayExit, AllowPredicates); +      const SCEV *BECount = getCouldNotCompute(); +      const SCEV *MaxBECount = getCouldNotCompute(); +      if (EitherMayExit) { +        // Both conditions must be true for the loop to continue executing. +        // Choose the less conservative count. +        if (EL0.ExactNotTaken == getCouldNotCompute() || +            EL1.ExactNotTaken == getCouldNotCompute()) +          BECount = getCouldNotCompute(); +        else +          BECount = +              getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); +        if (EL0.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL1.MaxNotTaken; +        else if (EL1.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL0.MaxNotTaken; +        else +          MaxBECount = +              getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); +      } else { +        // Both conditions must be true at the same time for the loop to exit. +        // For now, be conservative. +        if (EL0.MaxNotTaken == EL1.MaxNotTaken) +          MaxBECount = EL0.MaxNotTaken; +        if (EL0.ExactNotTaken == EL1.ExactNotTaken) +          BECount = EL0.ExactNotTaken; +      } + +      // There are cases (e.g. PR26207) where computeExitLimitFromCond is able +      // to be more aggressive when computing BECount than when computing +      // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and +      // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken +      // to not. +      if (isa<SCEVCouldNotCompute>(MaxBECount) && +          !isa<SCEVCouldNotCompute>(BECount)) +        MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + +      return ExitLimit(BECount, MaxBECount, false, +                       {&EL0.Predicates, &EL1.Predicates}); +    } +    if (BO->getOpcode() == Instruction::Or) { +      // Recurse on the operands of the or. +      bool EitherMayExit = ExitIfTrue; +      ExitLimit EL0 = computeExitLimitFromCondCached( +          Cache, L, BO->getOperand(0), ExitIfTrue, +          ControlsExit && !EitherMayExit, AllowPredicates); +      ExitLimit EL1 = computeExitLimitFromCondCached( +          Cache, L, BO->getOperand(1), ExitIfTrue, +          ControlsExit && !EitherMayExit, AllowPredicates); +      const SCEV *BECount = getCouldNotCompute(); +      const SCEV *MaxBECount = getCouldNotCompute(); +      if (EitherMayExit) { +        // Both conditions must be false for the loop to continue executing. +        // Choose the less conservative count. 
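+        // Illustrative example (not from the original source): for
+        //   %c = or i1 %c1, %c2
+        //   br i1 %c, label %exit, label %body   ; ExitIfTrue is true
+        // the loop exits as soon as either condition becomes true, so the
+        // umin of the two sub-counts is taken here as well.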
+        if (EL0.ExactNotTaken == getCouldNotCompute() || +            EL1.ExactNotTaken == getCouldNotCompute()) +          BECount = getCouldNotCompute(); +        else +          BECount = +              getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); +        if (EL0.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL1.MaxNotTaken; +        else if (EL1.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL0.MaxNotTaken; +        else +          MaxBECount = +              getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); +      } else { +        // Both conditions must be false at the same time for the loop to exit. +        // For now, be conservative. +        if (EL0.MaxNotTaken == EL1.MaxNotTaken) +          MaxBECount = EL0.MaxNotTaken; +        if (EL0.ExactNotTaken == EL1.ExactNotTaken) +          BECount = EL0.ExactNotTaken; +      } + +      return ExitLimit(BECount, MaxBECount, false, +                       {&EL0.Predicates, &EL1.Predicates}); +    } +  } + +  // With an icmp, it may be feasible to compute an exact backedge-taken count. +  // Proceed to the next level to examine the icmp. +  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { +    ExitLimit EL = +        computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit); +    if (EL.hasFullInfo() || !AllowPredicates) +      return EL; + +    // Try again, but use SCEV predicates this time. +    return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit, +                                    /*AllowPredicates=*/true); +  } + +  // Check for a constant condition. These are normally stripped out by +  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to +  // preserve the CFG and is temporarily leaving constant conditions +  // in place. +  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) { +    if (ExitIfTrue == !CI->getZExtValue()) +      // The backedge is always taken. +      return getCouldNotCompute(); +    else +      // The backedge is never taken. +      return getZero(CI->getType()); +  } + +  // If it's not an integer or pointer comparison then compute it the hard way. +  return computeExitCountExhaustively(L, ExitCond, ExitIfTrue); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::computeExitLimitFromICmp(const Loop *L, +                                          ICmpInst *ExitCond, +                                          bool ExitIfTrue, +                                          bool ControlsExit, +                                          bool AllowPredicates) { +  // If the condition was exit on true, convert the condition to exit on false +  ICmpInst::Predicate Pred; +  if (!ExitIfTrue) +    Pred = ExitCond->getPredicate(); +  else +    Pred = ExitCond->getInversePredicate(); +  const ICmpInst::Predicate OriginalPred = Pred; + +  // Handle common loops like: for (X = "string"; *X; ++X) +  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0))) +    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { +      ExitLimit ItCnt = +        computeLoadConstantCompareExitLimit(LI, RHS, L, Pred); +      if (ItCnt.hasAnyInfo()) +        return ItCnt; +    } + +  const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); +  const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); + +  // Try to evaluate any dependencies out of the loop. 
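+  // Illustrative example (not from the original source): for a latch branch
+  // on `%cmp = icmp slt i32 %iv, %n` that stays in the loop while %cmp is
+  // true, the code below simplifies the operands and then dispatches to
+  // howManyLessThans on ({start,+,step}, %n).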
+  LHS = getSCEVAtScope(LHS, L); +  RHS = getSCEVAtScope(RHS, L); + +  // At this point, we would like to compute how many iterations of the +  // loop the predicate will return true for these inputs. +  if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { +    // If there is a loop-invariant, force it into the RHS. +    std::swap(LHS, RHS); +    Pred = ICmpInst::getSwappedPredicate(Pred); +  } + +  // Simplify the operands before analyzing them. +  (void)SimplifyICmpOperands(Pred, LHS, RHS); + +  // If we have a comparison of a chrec against a constant, try to use value +  // ranges to answer this query. +  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) +    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS)) +      if (AddRec->getLoop() == L) { +        // Form the constant range. +        ConstantRange CompRange = +            ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); + +        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); +        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; +      } + +  switch (Pred) { +  case ICmpInst::ICMP_NE: {                     // while (X != Y) +    // Convert to: while (X-Y != 0) +    ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, +                                AllowPredicates); +    if (EL.hasAnyInfo()) return EL; +    break; +  } +  case ICmpInst::ICMP_EQ: {                     // while (X == Y) +    // Convert to: while (X-Y == 0) +    ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); +    if (EL.hasAnyInfo()) return EL; +    break; +  } +  case ICmpInst::ICMP_SLT: +  case ICmpInst::ICMP_ULT: {                    // while (X < Y) +    bool IsSigned = Pred == ICmpInst::ICMP_SLT; +    ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, +                                    AllowPredicates); +    if (EL.hasAnyInfo()) return EL; +    break; +  } +  case ICmpInst::ICMP_SGT: +  case ICmpInst::ICMP_UGT: {                    // while (X > Y) +    bool IsSigned = Pred == ICmpInst::ICMP_SGT; +    ExitLimit EL = +        howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, +                            AllowPredicates); +    if (EL.hasAnyInfo()) return EL; +    break; +  } +  default: +    break; +  } + +  auto *ExhaustiveCount = +      computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + +  if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) +    return ExhaustiveCount; + +  return computeShiftCompareExitLimit(ExitCond->getOperand(0), +                                      ExitCond->getOperand(1), L, OriginalPred); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, +                                                      SwitchInst *Switch, +                                                      BasicBlock *ExitingBlock, +                                                      bool ControlsExit) { +  assert(!L->contains(ExitingBlock) && "Not an exiting block!"); + +  // Give up if the exit is the default dest of a switch. 
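+  // Illustrative example (not from the original source): for
+  //   switch i32 %iv, label %body [ i32 100, label %exit ]
+  // the exit count below is computed as howFarToZero(%iv - 100), i.e. the
+  // switch is treated like `while (%iv != 100)`.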
+  if (Switch->getDefaultDest() == ExitingBlock) +    return getCouldNotCompute(); + +  assert(L->contains(Switch->getDefaultDest()) && +         "Default case must not exit the loop!"); +  const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); +  const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); + +  // while (X != Y) --> while (X-Y != 0) +  ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); +  if (EL.hasAnyInfo()) +    return EL; + +  return getCouldNotCompute(); +} + +static ConstantInt * +EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, +                                ScalarEvolution &SE) { +  const SCEV *InVal = SE.getConstant(C); +  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE); +  assert(isa<SCEVConstant>(Val) && +         "Evaluation of SCEV at constant didn't fold correctly?"); +  return cast<SCEVConstant>(Val)->getValue(); +} + +/// Given an exit condition of 'icmp op load X, cst', try to see if we can +/// compute the backedge execution count. +ScalarEvolution::ExitLimit +ScalarEvolution::computeLoadConstantCompareExitLimit( +  LoadInst *LI, +  Constant *RHS, +  const Loop *L, +  ICmpInst::Predicate predicate) { +  if (LI->isVolatile()) return getCouldNotCompute(); + +  // Check to see if the loaded pointer is a getelementptr of a global. +  // TODO: Use SCEV instead of manually grubbing with GEPs. +  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0)); +  if (!GEP) return getCouldNotCompute(); + +  // Make sure that it is really a constant global we are gepping, with an +  // initializer, and make sure the first IDX is really 0. +  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0)); +  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || +      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) || +      !cast<Constant>(GEP->getOperand(1))->isNullValue()) +    return getCouldNotCompute(); + +  // Okay, we allow one non-constant index into the GEP instruction. +  Value *VarIdx = nullptr; +  std::vector<Constant*> Indexes; +  unsigned VarIdxNum = 0; +  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i) +    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { +      Indexes.push_back(CI); +    } else if (!isa<ConstantInt>(GEP->getOperand(i))) { +      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's. +      VarIdx = GEP->getOperand(i); +      VarIdxNum = i-2; +      Indexes.push_back(nullptr); +    } + +  // Loop-invariant loads may be a byproduct of loop optimization. Skip them. +  if (!VarIdx) +    return getCouldNotCompute(); + +  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant. +  // Check to see if X is a loop variant variable value now. +  const SCEV *Idx = getSCEV(VarIdx); +  Idx = getSCEVAtScope(Idx, L); + +  // We can only recognize very limited forms of loop index expressions, in +  // particular, only affine AddRec's like {C1,+,C2}. 
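+  // Illustrative example (not from the original source): for a loop like
+  //   for (i = 0; table[i] != 0; ++i)
+  // over a constant global `table`, the index is the affine AddRec {0,+,1},
+  // and the loop below constant-folds table[0], table[1], ... (up to
+  // MaxBruteForceIterations) until the loop test first fails, yielding an
+  // exact backedge-taken count.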
+  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
+  if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
+      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
+      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
+    return getCouldNotCompute();
+
+  unsigned MaxSteps = MaxBruteForceIterations;
+  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
+    ConstantInt *ItCst = ConstantInt::get(
+                           cast<IntegerType>(IdxExpr->getType()), IterationNum);
+    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
+
+    // Form the GEP offset.
+    Indexes[VarIdxNum] = Val;
+
+    Constant *Result = ConstantFoldLoadThroughGEPIndices(GV->getInitializer(),
+                                                         Indexes);
+    if (!Result) break;  // Cannot compute!
+
+    // Evaluate the condition for this iteration.
+    Result = ConstantExpr::getICmp(predicate, Result, RHS);
+    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
+    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
+      ++NumArrayLenItCounts;
+      return getConstant(ItCst);   // Found terminating iteration!
+    }
+  }
+  return getCouldNotCompute();
+}
+
+ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
+    Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
+  ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
+  if (!RHS)
+    return getCouldNotCompute();
+
+  const BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return getCouldNotCompute();
+
+  const BasicBlock *Predecessor = L->getLoopPredecessor();
+  if (!Predecessor)
+    return getCouldNotCompute();
+
+  // Return true if V is of the form "LHS `shift_op` <positive constant>".
+  // Return LHS in OutLHS and shift_op in OutOpCode.
+  auto MatchPositiveShift =
+      [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
+
+    using namespace PatternMatch;
+
+    ConstantInt *ShiftAmt;
+    if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::LShr;
+    else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::AShr;
+    else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
+      OutOpCode = Instruction::Shl;
+    else
+      return false;
+
+    return ShiftAmt->getValue().isStrictlyPositive();
+  };
+
+  // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
+  //
+  // loop:
+  //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
+  //   %iv.shifted = lshr i32 %iv, <positive constant>
+  //
+  // Return true on a successful match.  Return the corresponding PHI node (%iv
+  // above) in PNOut and the opcode of the shift operation in OpCodeOut.
+  auto MatchShiftRecurrence =
+      [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
+    Optional<Instruction::BinaryOps> PostShiftOpCode;
+
+    {
+      Instruction::BinaryOps OpC;
+      Value *V;
+
+      // If we encounter a shift instruction, "peel off" the shift operation,
+      // and remember that we did so.  Later when we inspect %iv's backedge
+      // value, we will make sure that the backedge value uses the same
+      // operation.
+      //
+      // Note: the peeled shift operation does not have to be the same
+      // instruction as the one feeding into the PHI's backedge value.  
We only +      // really care about it being the same *kind* of shift instruction -- +      // that's all that is required for our later inferences to hold. +      if (MatchPositiveShift(LHS, V, OpC)) { +        PostShiftOpCode = OpC; +        LHS = V; +      } +    } + +    PNOut = dyn_cast<PHINode>(LHS); +    if (!PNOut || PNOut->getParent() != L->getHeader()) +      return false; + +    Value *BEValue = PNOut->getIncomingValueForBlock(Latch); +    Value *OpLHS; + +    return +        // The backedge value for the PHI node must be a shift by a positive +        // amount +        MatchPositiveShift(BEValue, OpLHS, OpCodeOut) && + +        // of the PHI node itself +        OpLHS == PNOut && + +        // and the kind of shift should be match the kind of shift we peeled +        // off, if any. +        (!PostShiftOpCode.hasValue() || *PostShiftOpCode == OpCodeOut); +  }; + +  PHINode *PN; +  Instruction::BinaryOps OpCode; +  if (!MatchShiftRecurrence(LHS, PN, OpCode)) +    return getCouldNotCompute(); + +  const DataLayout &DL = getDataLayout(); + +  // The key rationale for this optimization is that for some kinds of shift +  // recurrences, the value of the recurrence "stabilizes" to either 0 or -1 +  // within a finite number of iterations.  If the condition guarding the +  // backedge (in the sense that the backedge is taken if the condition is true) +  // is false for the value the shift recurrence stabilizes to, then we know +  // that the backedge is taken only a finite number of times. + +  ConstantInt *StableValue = nullptr; +  switch (OpCode) { +  default: +    llvm_unreachable("Impossible case!"); + +  case Instruction::AShr: { +    // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most +    // bitwidth(K) iterations. +    Value *FirstValue = PN->getIncomingValueForBlock(Predecessor); +    KnownBits Known = computeKnownBits(FirstValue, DL, 0, nullptr, +                                       Predecessor->getTerminator(), &DT); +    auto *Ty = cast<IntegerType>(RHS->getType()); +    if (Known.isNonNegative()) +      StableValue = ConstantInt::get(Ty, 0); +    else if (Known.isNegative()) +      StableValue = ConstantInt::get(Ty, -1, true); +    else +      return getCouldNotCompute(); + +    break; +  } +  case Instruction::LShr: +  case Instruction::Shl: +    // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>} +    // stabilize to 0 in at most bitwidth(K) iterations. +    StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0); +    break; +  } + +  auto *Result = +      ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI); +  assert(Result->getType()->isIntegerTy(1) && +         "Otherwise cannot be an operand to a branch instruction"); + +  if (Result->isZeroValue()) { +    unsigned BitWidth = getTypeSizeInBits(RHS->getType()); +    const SCEV *UpperBound = +        getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); +    return ExitLimit(getCouldNotCompute(), UpperBound, false); +  } + +  return getCouldNotCompute(); +} + +/// Return true if we can constant fold an instruction of the specified type, +/// assuming that all operands were constants. 
+static bool CanConstantFold(const Instruction *I) {
+  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
+      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
+      isa<LoadInst>(I))
+    return true;
+
+  if (const CallInst *CI = dyn_cast<CallInst>(I))
+    if (const Function *F = CI->getCalledFunction())
+      return canConstantFoldCallTo(CI, F);
+  return false;
+}
+
+/// Determine whether this instruction can constant evolve within this loop
+/// assuming its operands can all constant evolve.
+static bool canConstantEvolve(Instruction *I, const Loop *L) {
+  // An instruction outside of the loop can't be derived from a loop PHI.
+  if (!L->contains(I)) return false;
+
+  if (isa<PHINode>(I)) {
+    // We don't currently keep track of the control flow needed to evaluate
+    // PHIs, so we cannot handle PHIs inside of loops.
+    return L->getHeader() == I->getParent();
+  }
+
+  // If we won't be able to constant fold this expression even if the operands
+  // are constants, bail early.
+  return CanConstantFold(I);
+}
+
+/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
+/// recursing through each instruction operand until reaching a loop header phi.
+static PHINode *
+getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
+                               DenseMap<Instruction *, PHINode *> &PHIMap,
+                               unsigned Depth) {
+  if (Depth > MaxConstantEvolvingDepth)
+    return nullptr;
+
+  // Otherwise, we can evaluate this instruction if all of its operands are
+  // constant or derived from a PHI node themselves.
+  PHINode *PHI = nullptr;
+  for (Value *Op : UseInst->operands()) {
+    if (isa<Constant>(Op)) continue;
+
+    Instruction *OpInst = dyn_cast<Instruction>(Op);
+    if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
+
+    PHINode *P = dyn_cast<PHINode>(OpInst);
+    if (!P)
+      // If this operand is already visited, reuse the prior result.
+      // We may have P != PHI if this is the deepest point at which the
+      // inconsistent paths meet.
+      P = PHIMap.lookup(OpInst);
+    if (!P) {
+      // Recurse and memoize the results, whether a phi is found or not.
+      // This recursive call invalidates pointers into PHIMap.
+      P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
+      PHIMap[OpInst] = P;
+    }
+    if (!P)
+      return nullptr;  // Not evolving from PHI
+    if (PHI && PHI != P)
+      return nullptr;  // Evolving from multiple different PHIs.
+    PHI = P;
+  }
+  // This is an expression evolving from a constant PHI!
+  return PHI;
+}
+
+/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
+/// in the loop that V is derived from.  We allow arbitrary operations along the
+/// way, but the operands of an operation must either be constants or a value
+/// derived from a constant PHI.  If this expression does not fit with these
+/// constraints, return null.
+static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I || !canConstantEvolve(I, L)) return nullptr;
+
+  if (PHINode *PN = dyn_cast<PHINode>(I))
+    return PN;
+
+  // Record non-constant instructions contained by the loop. 
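+  // (PHIMap memoizes, for each instruction visited, the unique loop-header PHI
+  // it evolves from, or null if there is none.  For example, an exit test such
+  // as "icmp eq (mul %iv, 3), 42" evolves from the single header PHI %iv.)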
+  DenseMap<Instruction *, PHINode *> PHIMap; +  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0); +} + +/// EvaluateExpression - Given an expression that passes the +/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node +/// in the loop has the value PHIVal.  If we can't fold this expression for some +/// reason, return null. +static Constant *EvaluateExpression(Value *V, const Loop *L, +                                    DenseMap<Instruction *, Constant *> &Vals, +                                    const DataLayout &DL, +                                    const TargetLibraryInfo *TLI) { +  // Convenient constant check, but redundant for recursive calls. +  if (Constant *C = dyn_cast<Constant>(V)) return C; +  Instruction *I = dyn_cast<Instruction>(V); +  if (!I) return nullptr; + +  if (Constant *C = Vals.lookup(I)) return C; + +  // An instruction inside the loop depends on a value outside the loop that we +  // weren't given a mapping for, or a value such as a call inside the loop. +  if (!canConstantEvolve(I, L)) return nullptr; + +  // An unmapped PHI can be due to a branch or another loop inside this loop, +  // or due to this not being the initial iteration through a loop where we +  // couldn't compute the evolution of this particular PHI last time. +  if (isa<PHINode>(I)) return nullptr; + +  std::vector<Constant*> Operands(I->getNumOperands()); + +  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { +    Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i)); +    if (!Operand) { +      Operands[i] = dyn_cast<Constant>(I->getOperand(i)); +      if (!Operands[i]) return nullptr; +      continue; +    } +    Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI); +    Vals[Operand] = C; +    if (!C) return nullptr; +    Operands[i] = C; +  } + +  if (CmpInst *CI = dyn_cast<CmpInst>(I)) +    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], +                                           Operands[1], DL, TLI); +  if (LoadInst *LI = dyn_cast<LoadInst>(I)) { +    if (!LI->isVolatile()) +      return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); +  } +  return ConstantFoldInstOperands(I, Operands, DL, TLI); +} + + +// If every incoming value to PN except the one for BB is a specific Constant, +// return that, else return nullptr. +static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) { +  Constant *IncomingVal = nullptr; + +  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { +    if (PN->getIncomingBlock(i) == BB) +      continue; + +    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i)); +    if (!CurrentVal) +      return nullptr; + +    if (IncomingVal != CurrentVal) { +      if (IncomingVal) +        return nullptr; +      IncomingVal = CurrentVal; +    } +  } + +  return IncomingVal; +} + +/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is +/// in the header of its containing loop, we know the loop executes a +/// constant number of times, and the PHI node is just a recurrence +/// involving constants, fold it. 
+Constant * +ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, +                                                   const APInt &BEs, +                                                   const Loop *L) { +  auto I = ConstantEvolutionLoopExitValue.find(PN); +  if (I != ConstantEvolutionLoopExitValue.end()) +    return I->second; + +  if (BEs.ugt(MaxBruteForceIterations)) +    return ConstantEvolutionLoopExitValue[PN] = nullptr;  // Not going to evaluate it. + +  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN]; + +  DenseMap<Instruction *, Constant *> CurrentIterVals; +  BasicBlock *Header = L->getHeader(); +  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); + +  BasicBlock *Latch = L->getLoopLatch(); +  if (!Latch) +    return nullptr; + +  for (PHINode &PHI : Header->phis()) { +    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch)) +      CurrentIterVals[&PHI] = StartCST; +  } +  if (!CurrentIterVals.count(PN)) +    return RetVal = nullptr; + +  Value *BEValue = PN->getIncomingValueForBlock(Latch); + +  // Execute the loop symbolically to determine the exit value. +  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) && +         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!"); + +  unsigned NumIterations = BEs.getZExtValue(); // must be in range +  unsigned IterationNum = 0; +  const DataLayout &DL = getDataLayout(); +  for (; ; ++IterationNum) { +    if (IterationNum == NumIterations) +      return RetVal = CurrentIterVals[PN];  // Got exit value! + +    // Compute the value of the PHIs for the next iteration. +    // EvaluateExpression adds non-phi values to the CurrentIterVals map. +    DenseMap<Instruction *, Constant *> NextIterVals; +    Constant *NextPHI = +        EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); +    if (!NextPHI) +      return nullptr;        // Couldn't evaluate! +    NextIterVals[PN] = NextPHI; + +    bool StoppedEvolving = NextPHI == CurrentIterVals[PN]; + +    // Also evaluate the other PHI nodes.  However, we don't get to stop if we +    // cease to be able to evaluate one of them or if they stop evolving, +    // because that doesn't necessarily prevent us from computing PN. +    SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute; +    for (const auto &I : CurrentIterVals) { +      PHINode *PHI = dyn_cast<PHINode>(I.first); +      if (!PHI || PHI == PN || PHI->getParent() != Header) continue; +      PHIsToCompute.emplace_back(PHI, I.second); +    } +    // We use two distinct loops because EvaluateExpression may invalidate any +    // iterators into CurrentIterVals. +    for (const auto &I : PHIsToCompute) { +      PHINode *PHI = I.first; +      Constant *&NextPHI = NextIterVals[PHI]; +      if (!NextPHI) {   // Not already computed. +        Value *BEValue = PHI->getIncomingValueForBlock(Latch); +        NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); +      } +      if (NextPHI != I.second) +        StoppedEvolving = false; +    } + +    // If all entries in CurrentIterVals == NextIterVals then we can stop +    // iterating, the loop can't continue to change. 
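+    // (Every later iteration would recompute exactly the same constants, so
+    // the value the PHI has now is also its value at iteration NumIterations.)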
+    if (StoppedEvolving) +      return RetVal = CurrentIterVals[PN]; + +    CurrentIterVals.swap(NextIterVals); +  } +} + +const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, +                                                          Value *Cond, +                                                          bool ExitWhen) { +  PHINode *PN = getConstantEvolvingPHI(Cond, L); +  if (!PN) return getCouldNotCompute(); + +  // If the loop is canonicalized, the PHI will have exactly two entries. +  // That's the only form we support here. +  if (PN->getNumIncomingValues() != 2) return getCouldNotCompute(); + +  DenseMap<Instruction *, Constant *> CurrentIterVals; +  BasicBlock *Header = L->getHeader(); +  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!"); + +  BasicBlock *Latch = L->getLoopLatch(); +  assert(Latch && "Should follow from NumIncomingValues == 2!"); + +  for (PHINode &PHI : Header->phis()) { +    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch)) +      CurrentIterVals[&PHI] = StartCST; +  } +  if (!CurrentIterVals.count(PN)) +    return getCouldNotCompute(); + +  // Okay, we find a PHI node that defines the trip count of this loop.  Execute +  // the loop symbolically to determine when the condition gets a value of +  // "ExitWhen". +  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis. +  const DataLayout &DL = getDataLayout(); +  for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){ +    auto *CondVal = dyn_cast_or_null<ConstantInt>( +        EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI)); + +    // Couldn't symbolically evaluate. +    if (!CondVal) return getCouldNotCompute(); + +    if (CondVal->getValue() == uint64_t(ExitWhen)) { +      ++NumBruteForceTripCountsComputed; +      return getConstant(Type::getInt32Ty(getContext()), IterationNum); +    } + +    // Update all the PHI nodes for the next iteration. +    DenseMap<Instruction *, Constant *> NextIterVals; + +    // Create a list of which PHIs we need to compute. We want to do this before +    // calling EvaluateExpression on them because that may invalidate iterators +    // into CurrentIterVals. +    SmallVector<PHINode *, 8> PHIsToCompute; +    for (const auto &I : CurrentIterVals) { +      PHINode *PHI = dyn_cast<PHINode>(I.first); +      if (!PHI || PHI->getParent() != Header) continue; +      PHIsToCompute.push_back(PHI); +    } +    for (PHINode *PHI : PHIsToCompute) { +      Constant *&NextPHI = NextIterVals[PHI]; +      if (NextPHI) continue;    // Already computed! + +      Value *BEValue = PHI->getIncomingValueForBlock(Latch); +      NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI); +    } +    CurrentIterVals.swap(NextIterVals); +  } + +  // Too many iterations were needed to evaluate. +  return getCouldNotCompute(); +} + +const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { +  SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = +      ValuesAtScopes[V]; +  // Check to see if we've folded this expression at this loop before. +  for (auto &LS : Values) +    if (LS.first == L) +      return LS.second ? LS.second : V; + +  Values.emplace_back(L, nullptr); + +  // Otherwise compute it. +  const SCEV *C = computeSCEVAtScope(V, L); +  for (auto &LS : reverse(ValuesAtScopes[V])) +    if (LS.first == L) { +      LS.second = C; +      break; +    } +  return C; +} + +/// This builds up a Constant using the ConstantExpr interface.  
That way, we +/// will return Constants for objects which aren't represented by a +/// SCEVConstant, because SCEVConstant is restricted to ConstantInt. +/// Returns NULL if the SCEV isn't representable as a Constant. +static Constant *BuildConstantFromSCEV(const SCEV *V) { +  switch (static_cast<SCEVTypes>(V->getSCEVType())) { +    case scCouldNotCompute: +    case scAddRecExpr: +      break; +    case scConstant: +      return cast<SCEVConstant>(V)->getValue(); +    case scUnknown: +      return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue()); +    case scSignExtend: { +      const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V); +      if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) +        return ConstantExpr::getSExt(CastOp, SS->getType()); +      break; +    } +    case scZeroExtend: { +      const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V); +      if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) +        return ConstantExpr::getZExt(CastOp, SZ->getType()); +      break; +    } +    case scTruncate: { +      const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V); +      if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) +        return ConstantExpr::getTrunc(CastOp, ST->getType()); +      break; +    } +    case scAddExpr: { +      const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); +      if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { +        if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { +          unsigned AS = PTy->getAddressSpace(); +          Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); +          C = ConstantExpr::getBitCast(C, DestPtrTy); +        } +        for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { +          Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); +          if (!C2) return nullptr; + +          // First pointer! +          if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { +            unsigned AS = C2->getType()->getPointerAddressSpace(); +            std::swap(C, C2); +            Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); +            // The offsets have been converted to bytes.  We can add bytes to an +            // i8* by GEP with the byte count in the first index. +            C = ConstantExpr::getBitCast(C, DestPtrTy); +          } + +          // Don't bother trying to sum two pointers. We probably can't +          // statically compute a load that results from it anyway. +          if (C2->getType()->isPointerTy()) +            return nullptr; + +          if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { +            if (PTy->getElementType()->isStructTy()) +              C2 = ConstantExpr::getIntegerCast( +                  C2, Type::getInt32Ty(C->getContext()), true); +            C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); +          } else +            C = ConstantExpr::getAdd(C, C2); +        } +        return C; +      } +      break; +    } +    case scMulExpr: { +      const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); +      if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { +        // Don't bother with pointers at all. 
+        if (C->getType()->isPointerTy()) return nullptr; +        for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { +          Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); +          if (!C2 || C2->getType()->isPointerTy()) return nullptr; +          C = ConstantExpr::getMul(C, C2); +        } +        return C; +      } +      break; +    } +    case scUDivExpr: { +      const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); +      if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) +        if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) +          if (LHS->getType() == RHS->getType()) +            return ConstantExpr::getUDiv(LHS, RHS); +      break; +    } +    case scSMaxExpr: +    case scUMaxExpr: +      break; // TODO: smax, umax. +  } +  return nullptr; +} + +const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { +  if (isa<SCEVConstant>(V)) return V; + +  // If this instruction is evolved from a constant-evolving PHI, compute the +  // exit value from the loop without using SCEVs. +  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) { +    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) { +      const Loop *LI = this->LI[I->getParent()]; +      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value. +        if (PHINode *PN = dyn_cast<PHINode>(I)) +          if (PN->getParent() == LI->getHeader()) { +            // Okay, there is no closed form solution for the PHI node.  Check +            // to see if the loop that contains it has a known backedge-taken +            // count.  If so, we may be able to force computation of the exit +            // value. +            const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI); +            if (const SCEVConstant *BTCC = +                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) { + +              // This trivial case can show up in some degenerate cases where +              // the incoming IR has not yet been fully simplified. +              if (BTCC->getValue()->isZero()) { +                Value *InitValue = nullptr; +                bool MultipleInitValues = false; +                for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { +                  if (!LI->contains(PN->getIncomingBlock(i))) { +                    if (!InitValue) +                      InitValue = PN->getIncomingValue(i); +                    else if (InitValue != PN->getIncomingValue(i)) { +                      MultipleInitValues = true; +                      break; +                    } +                  } +                  if (!MultipleInitValues && InitValue) +                    return getSCEV(InitValue); +                } +              } +              // Okay, we know how many times the containing loop executes.  If +              // this is a constant evolving PHI node, get the final value at +              // the specified iteration number. +              Constant *RV = +                  getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI); +              if (RV) return getSCEV(RV); +            } +          } + +      // Okay, this is an expression that we cannot symbolically evaluate +      // into a SCEV.  Check to see if it's possible to symbolically evaluate +      // the arguments into constants, and if so, try to constant propagate the +      // result.  This is particularly useful for computing loop exit values. 
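+      // For example (illustrative values): an operand whose SCEV is
+      // {3,+,5}<%inner>, where %inner's backedge-taken count is 4, folds at
+      // this scope to the constant 3 + 5*4 = 23, which may in turn allow the
+      // compare, load, or arithmetic below to fold to a constant.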
+      if (CanConstantFold(I)) { +        SmallVector<Constant *, 4> Operands; +        bool MadeImprovement = false; +        for (Value *Op : I->operands()) { +          if (Constant *C = dyn_cast<Constant>(Op)) { +            Operands.push_back(C); +            continue; +          } + +          // If any of the operands is non-constant and if they are +          // non-integer and non-pointer, don't even try to analyze them +          // with scev techniques. +          if (!isSCEVable(Op->getType())) +            return V; + +          const SCEV *OrigV = getSCEV(Op); +          const SCEV *OpV = getSCEVAtScope(OrigV, L); +          MadeImprovement |= OrigV != OpV; + +          Constant *C = BuildConstantFromSCEV(OpV); +          if (!C) return V; +          if (C->getType() != Op->getType()) +            C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false, +                                                              Op->getType(), +                                                              false), +                                      C, Op->getType()); +          Operands.push_back(C); +        } + +        // Check to see if getSCEVAtScope actually made an improvement. +        if (MadeImprovement) { +          Constant *C = nullptr; +          const DataLayout &DL = getDataLayout(); +          if (const CmpInst *CI = dyn_cast<CmpInst>(I)) +            C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], +                                                Operands[1], DL, &TLI); +          else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { +            if (!LI->isVolatile()) +              C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); +          } else +            C = ConstantFoldInstOperands(I, Operands, DL, &TLI); +          if (!C) return V; +          return getSCEV(C); +        } +      } +    } + +    // This is some other type of SCEVUnknown, just return it. +    return V; +  } + +  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) { +    // Avoid performing the look-up in the common case where the specified +    // expression has no loop-variant portions. +    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) { +      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); +      if (OpAtScope != Comm->getOperand(i)) { +        // Okay, at least one of these operands is loop variant but might be +        // foldable.  Build a new instance of the folded commutative expression. +        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(), +                                            Comm->op_begin()+i); +        NewOps.push_back(OpAtScope); + +        for (++i; i != e; ++i) { +          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L); +          NewOps.push_back(OpAtScope); +        } +        if (isa<SCEVAddExpr>(Comm)) +          return getAddExpr(NewOps); +        if (isa<SCEVMulExpr>(Comm)) +          return getMulExpr(NewOps); +        if (isa<SCEVSMaxExpr>(Comm)) +          return getSMaxExpr(NewOps); +        if (isa<SCEVUMaxExpr>(Comm)) +          return getUMaxExpr(NewOps); +        llvm_unreachable("Unknown commutative SCEV type!"); +      } +    } +    // If we got here, all operands are loop invariant. 
+    return Comm; +  } + +  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) { +    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L); +    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L); +    if (LHS == Div->getLHS() && RHS == Div->getRHS()) +      return Div;   // must be loop invariant +    return getUDivExpr(LHS, RHS); +  } + +  // If this is a loop recurrence for a loop that does not contain L, then we +  // are dealing with the final value computed by the loop. +  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) { +    // First, attempt to evaluate each operand. +    // Avoid performing the look-up in the common case where the specified +    // expression has no loop-variant portions. +    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { +      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L); +      if (OpAtScope == AddRec->getOperand(i)) +        continue; + +      // Okay, at least one of these operands is loop variant but might be +      // foldable.  Build a new instance of the folded commutative expression. +      SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(), +                                          AddRec->op_begin()+i); +      NewOps.push_back(OpAtScope); +      for (++i; i != e; ++i) +        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L)); + +      const SCEV *FoldedRec = +        getAddRecExpr(NewOps, AddRec->getLoop(), +                      AddRec->getNoWrapFlags(SCEV::FlagNW)); +      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec); +      // The addrec may be folded to a nonrecurrence, for example, if the +      // induction variable is multiplied by zero after constant folding. Go +      // ahead and return the folded value. +      if (!AddRec) +        return FoldedRec; +      break; +    } + +    // If the scope is outside the addrec's loop, evaluate it by using the +    // loop exit value of the addrec. +    if (!AddRec->getLoop()->contains(L)) { +      // To evaluate this recurrence, we need to know how many times the AddRec +      // loop iterates.  Compute this now. +      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop()); +      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec; + +      // Then, evaluate the AddRec. 
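+      // (For an affine {Start,+,Step} this is Start + Step * BackedgeTakenCount,
+      // i.e. the value the recurrence takes on the final iteration of its loop.)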
+      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this); +    } + +    return AddRec; +  } + +  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) { +    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); +    if (Op == Cast->getOperand()) +      return Cast;  // must be loop invariant +    return getZeroExtendExpr(Op, Cast->getType()); +  } + +  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) { +    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); +    if (Op == Cast->getOperand()) +      return Cast;  // must be loop invariant +    return getSignExtendExpr(Op, Cast->getType()); +  } + +  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) { +    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); +    if (Op == Cast->getOperand()) +      return Cast;  // must be loop invariant +    return getTruncateExpr(Op, Cast->getType()); +  } + +  llvm_unreachable("Unknown SCEV type!"); +} + +const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { +  return getSCEVAtScope(getSCEV(V), L); +} + +const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { +  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) +    return stripInjectiveFunctions(ZExt->getOperand()); +  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) +    return stripInjectiveFunctions(SExt->getOperand()); +  return S; +} + +/// Finds the minimum unsigned root of the following equation: +/// +///     A * X = B (mod N) +/// +/// where N = 2^BW and BW is the common bit width of A and B. The signedness of +/// A and B isn't important. +/// +/// If the equation does not have a solution, SCEVCouldNotCompute is returned. +static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, +                                               ScalarEvolution &SE) { +  uint32_t BW = A.getBitWidth(); +  assert(BW == SE.getTypeSizeInBits(B->getType())); +  assert(A != 0 && "A must be non-zero."); + +  // 1. D = gcd(A, N) +  // +  // The gcd of A and N may have only one prime factor: 2. The number of +  // trailing zeros in A is its multiplicity +  uint32_t Mult2 = A.countTrailingZeros(); +  // D = 2^Mult2 + +  // 2. Check if B is divisible by D. +  // +  // B is divisible by D if and only if the multiplicity of prime factor 2 for B +  // is not less than multiplicity of this prime factor for D. +  if (SE.GetMinTrailingZeros(B) < Mult2) +    return SE.getCouldNotCompute(); + +  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic +  // modulo (N / D). +  // +  // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent +  // (N / D) in general. The inverse itself always fits into BW bits, though, +  // so we immediately truncate it. +  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D +  APInt Mod(BW + 1, 0); +  Mod.setBit(BW - Mult2);  // Mod = N / D +  APInt I = AD.multiplicativeInverse(Mod).trunc(BW); + +  // 4. Compute the minimum unsigned root of the equation: +  // I * (B / D) mod (N / D) +  // To simplify the computation, we factor out the divide by D: +  // (I * B mod N) / D +  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2)); +  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); +} + +/// Find the roots of the quadratic equation for the given quadratic chrec +/// {L,+,M,+,N}.  This returns either the two roots (which might be the same) or +/// two SCEVCouldNotCompute objects. 
+static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> +SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { +  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); +  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); +  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); +  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); + +  // We currently can only solve this if the coefficients are constants. +  if (!LC || !MC || !NC) +    return None; + +  uint32_t BitWidth = LC->getAPInt().getBitWidth(); +  const APInt &L = LC->getAPInt(); +  const APInt &M = MC->getAPInt(); +  const APInt &N = NC->getAPInt(); +  APInt Two(BitWidth, 2); + +  // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C + +  // The A coefficient is N/2 +  APInt A = N.sdiv(Two); + +  // The B coefficient is M-N/2 +  APInt B = M; +  B -= A; // A is the same as N/2. + +  // The C coefficient is L. +  const APInt& C = L; + +  // Compute the B^2-4ac term. +  APInt SqrtTerm = B; +  SqrtTerm *= B; +  SqrtTerm -= 4 * (A * C); + +  if (SqrtTerm.isNegative()) { +    // The loop is provably infinite. +    return None; +  } + +  // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest +  // integer value or else APInt::sqrt() will assert. +  APInt SqrtVal = SqrtTerm.sqrt(); + +  // Compute the two solutions for the quadratic formula. +  // The divisions must be performed as signed divisions. +  APInt NegB = -std::move(B); +  APInt TwoA = std::move(A); +  TwoA <<= 1; +  if (TwoA.isNullValue()) +    return None; + +  LLVMContext &Context = SE.getContext(); + +  ConstantInt *Solution1 = +    ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); +  ConstantInt *Solution2 = +    ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + +  return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), +                        cast<SCEVConstant>(SE.getConstant(Solution2))); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, +                              bool AllowPredicates) { + +  // This is only used for loops with a "x != y" exit test. The exit condition +  // is now expressed as a single expression, V = x-y. So the exit test is +  // effectively V != 0.  We know and take advantage of the fact that this +  // expression only being used in a comparison by zero context. + +  SmallPtrSet<const SCEVPredicate *, 4> Predicates; +  // If the value is a constant +  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { +    // If the value is already zero, the branch will execute zero times. +    if (C->getValue()->isZero()) return C; +    return getCouldNotCompute();  // Otherwise it will loop infinitely. +  } + +  const SCEVAddRecExpr *AddRec = +      dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V)); + +  if (!AddRec && AllowPredicates) +    // Try to make this an AddRec using runtime tests, in the first X +    // iterations of this loop, where X is the SCEV expression found by the +    // algorithm below. +    AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates); + +  if (!AddRec || AddRec->getLoop() != L) +    return getCouldNotCompute(); + +  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of +  // the quadratic equation to solve it. 
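+  // For example (illustrative), {2,+,-2,+,2} has the closed form
+  // n^2 - 3n + 2 = (n - 1)(n - 2) at iteration n, so the smaller root, n = 1,
+  // is the first iteration at which the value becomes zero.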
+  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { +    if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { +      const SCEVConstant *R1 = Roots->first; +      const SCEVConstant *R2 = Roots->second; +      // Pick the smallest positive root value. +      if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( +              CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { +        if (!CB->getZExtValue()) +          std::swap(R1, R2); // R1 is the minimum root now. + +        // We can only use this value if the chrec ends up with an exact zero +        // value at this index.  When solving for "X*X != 5", for example, we +        // should not accept a root of 2. +        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); +        if (Val->isZero()) +          // We found a quadratic root! +          return ExitLimit(R1, R1, false, Predicates); +      } +    } +    return getCouldNotCompute(); +  } + +  // Otherwise we can only handle this if it is affine. +  if (!AddRec->isAffine()) +    return getCouldNotCompute(); + +  // If this is an affine expression, the execution count of this branch is +  // the minimum unsigned root of the following equation: +  // +  //     Start + Step*N = 0 (mod 2^BW) +  // +  // equivalent to: +  // +  //             Step*N = -Start (mod 2^BW) +  // +  // where BW is the common bit width of Start and Step. + +  // Get the initial value for the loop. +  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop()); +  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop()); + +  // For now we handle only constant steps. +  // +  // TODO: Handle a nonconstant Step given AddRec<NUW>. If the +  // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap +  // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step. +  // We have not yet seen any such cases. +  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step); +  if (!StepC || StepC->getValue()->isZero()) +    return getCouldNotCompute(); + +  // For positive steps (counting up until unsigned overflow): +  //   N = -Start/Step (as unsigned) +  // For negative steps (counting down to zero): +  //   N = Start/-Step +  // First compute the unsigned distance from zero in the direction of Step. +  bool CountDown = StepC->getAPInt().isNegative(); +  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start); + +  // Handle unitary steps, which cannot wraparound. +  // 1*N = -Start; -1*N = Start (mod 2^BW), so: +  //   N = Distance (as unsigned) +  if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) { +    APInt MaxBECount = getUnsignedRangeMax(Distance); + +    // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, +    // we end up with a loop whose backedge-taken count is n - 1.  Detect this +    // case, and see if we can improve the bound. +    // +    // Explicitly handling this here is necessary because getUnsignedRange +    // isn't context-sensitive; it doesn't know that we only care about the +    // range inside the loop. +    const SCEV *Zero = getZero(Distance->getType()); +    const SCEV *One = getOne(Distance->getType()); +    const SCEV *DistancePlusOne = getAddExpr(Distance, One); +    if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) { +      // If Distance + 1 doesn't overflow, we can compute the maximum distance +      // as "unsigned_max(Distance + 1) - 1". 
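+      // (For a rotated loop over an i32 count n with no other known bounds,
+      // this tightens the maximum backedge-taken count from UINT32_MAX to
+      // UINT32_MAX - 1.)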
+      ConstantRange CR = getUnsignedRange(DistancePlusOne); +      MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1); +    } +    return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); +  } + +  // If the condition controls loop exit (the loop exits only if the expression +  // is true) and the addition is no-wrap we can use unsigned divide to +  // compute the backedge count.  In this case, the step may not divide the +  // distance, but we don't care because if the condition is "missed" the loop +  // will have undefined behavior due to wrapping. +  if (ControlsExit && AddRec->hasNoSelfWrap() && +      loopHasNoAbnormalExits(AddRec->getLoop())) { +    const SCEV *Exact = +        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); +    const SCEV *Max = +        Exact == getCouldNotCompute() +            ? Exact +            : getConstant(getUnsignedRangeMax(Exact)); +    return ExitLimit(Exact, Max, false, Predicates); +  } + +  // Solve the general equation. +  const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(), +                                               getNegativeSCEV(Start), *this); +  const SCEV *M = E == getCouldNotCompute() +                      ? E +                      : getConstant(getUnsignedRangeMax(E)); +  return ExitLimit(E, M, false, Predicates); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { +  // Loops that look like: while (X == 0) are very strange indeed.  We don't +  // handle them yet except for the trivial case.  This could be expanded in the +  // future as needed. + +  // If the value is a constant, check to see if it is known to be non-zero +  // already.  If so, the backedge will execute zero times. +  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { +    if (!C->getValue()->isZero()) +      return getZero(C->getType()); +    return getCouldNotCompute();  // Otherwise it will loop infinitely. +  } + +  // We could implement others, but I really doubt anyone writes loops like +  // this, and if they did, they would already be constant folded. +  return getCouldNotCompute(); +} + +std::pair<BasicBlock *, BasicBlock *> +ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { +  // If the block has a unique predecessor, then there is no path from the +  // predecessor to the block that does not go through the direct edge +  // from the predecessor to the block. +  if (BasicBlock *Pred = BB->getSinglePredecessor()) +    return {Pred, BB}; + +  // A loop's header is defined to be a block that dominates the loop. +  // If the header has a unique predecessor outside the loop, it must be +  // a block that has exactly one successor that can reach the loop. +  if (Loop *L = LI.getLoopFor(BB)) +    return {L->getLoopPredecessor(), L->getHeader()}; + +  return {nullptr, nullptr}; +} + +/// SCEV structural equivalence is usually sufficient for testing whether two +/// expressions are equal, however for the purposes of looking for a condition +/// guarding a loop, it can be useful to be a little more general, since a +/// front-end may have replicated the controlling expression. +static bool HasSameValue(const SCEV *A, const SCEV *B) { +  // Quick check to see if they are the same SCEV. +  if (A == B) return true; + +  auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) { +    // Not all instructions that are "identical" compute the same value.  
For +    // instance, two distinct alloca instructions allocating the same type are +    // identical and do not read memory; but compute distinct values. +    return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A)); +  }; + +  // Otherwise, if they're both SCEVUnknown, it's possible that they hold +  // two different instructions with the same value. Check for this case. +  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A)) +    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B)) +      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue())) +        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue())) +          if (ComputesEqualValues(AI, BI)) +            return true; + +  // Otherwise assume they may have a different value. +  return false; +} + +bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, +                                           const SCEV *&LHS, const SCEV *&RHS, +                                           unsigned Depth) { +  bool Changed = false; + +  // If we hit the max recursion limit bail out. +  if (Depth >= 3) +    return false; + +  // Canonicalize a constant to the right side. +  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) { +    // Check for both operands constant. +    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { +      if (ConstantExpr::getICmp(Pred, +                                LHSC->getValue(), +                                RHSC->getValue())->isNullValue()) +        goto trivially_false; +      else +        goto trivially_true; +    } +    // Otherwise swap the operands to put the constant on the right. +    std::swap(LHS, RHS); +    Pred = ICmpInst::getSwappedPredicate(Pred); +    Changed = true; +  } + +  // If we're comparing an addrec with a value which is loop-invariant in the +  // addrec's loop, put the addrec on the left. Also make a dominance check, +  // as both operands could be addrecs loop-invariant in each other's loop. +  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) { +    const Loop *L = AR->getLoop(); +    if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) { +      std::swap(LHS, RHS); +      Pred = ICmpInst::getSwappedPredicate(Pred); +      Changed = true; +    } +  } + +  // If there's a constant operand, canonicalize comparisons with boundary +  // cases, and canonicalize *-or-equal comparisons to regular comparisons. +  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) { +    const APInt &RA = RC->getAPInt(); + +    bool SimplifiedByConstantRange = false; + +    if (!ICmpInst::isEquality(Pred)) { +      ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); +      if (ExactCR.isFullSet()) +        goto trivially_true; +      else if (ExactCR.isEmptySet()) +        goto trivially_false; + +      APInt NewRHS; +      CmpInst::Predicate NewPred; +      if (ExactCR.getEquivalentICmp(NewPred, NewRHS) && +          ICmpInst::isEquality(NewPred)) { +        // We were able to convert an inequality to an equality. +        Pred = NewPred; +        RHS = getConstant(NewRHS); +        Changed = SimplifiedByConstantRange = true; +      } +    } + +    if (!SimplifiedByConstantRange) { +      switch (Pred) { +      default: +        break; +      case ICmpInst::ICMP_EQ: +      case ICmpInst::ICMP_NE: +        // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. 
+        if (!RA) +          if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) +            if (const SCEVMulExpr *ME = +                    dyn_cast<SCEVMulExpr>(AE->getOperand(0))) +              if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && +                  ME->getOperand(0)->isAllOnesValue()) { +                RHS = AE->getOperand(1); +                LHS = ME->getOperand(1); +                Changed = true; +              } +        break; + + +        // The "Should have been caught earlier!" messages refer to the fact +        // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above +        // should have fired on the corresponding cases, and canonicalized the +        // check to trivially_true or trivially_false. + +      case ICmpInst::ICMP_UGE: +        assert(!RA.isMinValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_UGT; +        RHS = getConstant(RA - 1); +        Changed = true; +        break; +      case ICmpInst::ICMP_ULE: +        assert(!RA.isMaxValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_ULT; +        RHS = getConstant(RA + 1); +        Changed = true; +        break; +      case ICmpInst::ICMP_SGE: +        assert(!RA.isMinSignedValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_SGT; +        RHS = getConstant(RA - 1); +        Changed = true; +        break; +      case ICmpInst::ICMP_SLE: +        assert(!RA.isMaxSignedValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_SLT; +        RHS = getConstant(RA + 1); +        Changed = true; +        break; +      } +    } +  } + +  // Check for obvious equality. +  if (HasSameValue(LHS, RHS)) { +    if (ICmpInst::isTrueWhenEqual(Pred)) +      goto trivially_true; +    if (ICmpInst::isFalseWhenEqual(Pred)) +      goto trivially_false; +  } + +  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by +  // adding or subtracting 1 from one of the operands. 
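+  // For example, "X s<= Y" becomes "X s< Y + 1" when Y + 1 cannot sign-wrap,
+  // or "X - 1 s< Y" when X - 1 cannot, as the cases below check via the known
+  // signed/unsigned ranges of the operands.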
+  switch (Pred) { +  case ICmpInst::ICMP_SLE: +    if (!getSignedRangeMax(RHS).isMaxSignedValue()) { +      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, +                       SCEV::FlagNSW); +      Pred = ICmpInst::ICMP_SLT; +      Changed = true; +    } else if (!getSignedRangeMin(LHS).isMinSignedValue()) { +      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS, +                       SCEV::FlagNSW); +      Pred = ICmpInst::ICMP_SLT; +      Changed = true; +    } +    break; +  case ICmpInst::ICMP_SGE: +    if (!getSignedRangeMin(RHS).isMinSignedValue()) { +      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, +                       SCEV::FlagNSW); +      Pred = ICmpInst::ICMP_SGT; +      Changed = true; +    } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) { +      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, +                       SCEV::FlagNSW); +      Pred = ICmpInst::ICMP_SGT; +      Changed = true; +    } +    break; +  case ICmpInst::ICMP_ULE: +    if (!getUnsignedRangeMax(RHS).isMaxValue()) { +      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, +                       SCEV::FlagNUW); +      Pred = ICmpInst::ICMP_ULT; +      Changed = true; +    } else if (!getUnsignedRangeMin(LHS).isMinValue()) { +      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS); +      Pred = ICmpInst::ICMP_ULT; +      Changed = true; +    } +    break; +  case ICmpInst::ICMP_UGE: +    if (!getUnsignedRangeMin(RHS).isMinValue()) { +      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); +      Pred = ICmpInst::ICMP_UGT; +      Changed = true; +    } else if (!getUnsignedRangeMax(LHS).isMaxValue()) { +      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS, +                       SCEV::FlagNUW); +      Pred = ICmpInst::ICMP_UGT; +      Changed = true; +    } +    break; +  default: +    break; +  } + +  // TODO: More simplifications are possible here. + +  // Recursively simplify until we either hit a recursion limit or nothing +  // changes. +  if (Changed) +    return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); + +  return Changed; + +trivially_true: +  // Return 0 == 0. +  LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); +  Pred = ICmpInst::ICMP_EQ; +  return true; + +trivially_false: +  // Return 0 != 0. +  LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); +  Pred = ICmpInst::ICMP_NE; +  return true; +} + +bool ScalarEvolution::isKnownNegative(const SCEV *S) { +  return getSignedRangeMax(S).isNegative(); +} + +bool ScalarEvolution::isKnownPositive(const SCEV *S) { +  return getSignedRangeMin(S).isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonNegative(const SCEV *S) { +  return !getSignedRangeMin(S).isNegative(); +} + +bool ScalarEvolution::isKnownNonPositive(const SCEV *S) { +  return !getSignedRangeMax(S).isStrictlyPositive(); +} + +bool ScalarEvolution::isKnownNonZero(const SCEV *S) { +  return isKnownNegative(S) || isKnownPositive(S); +} + +std::pair<const SCEV *, const SCEV *> +ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) { +  // Compute SCEV on entry of loop L. +  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this); +  if (Start == getCouldNotCompute()) +    return { Start, Start }; +  // Compute post increment SCEV for loop L. 
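+  // (For an addrec {A,+,B}<L> this is its post-increment form {A+B,+,B}<L>,
+  // i.e. the value seen after the increment on each iteration.)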
+  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this); +  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute"); +  return { Start, PostInc }; +} + +bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred, +                                          const SCEV *LHS, const SCEV *RHS) { +  // First collect all loops. +  SmallPtrSet<const Loop *, 8> LoopsUsed; +  getUsedLoops(LHS, LoopsUsed); +  getUsedLoops(RHS, LoopsUsed); + +  if (LoopsUsed.empty()) +    return false; + +  // Domination relationship must be a linear order on collected loops. +#ifndef NDEBUG +  for (auto *L1 : LoopsUsed) +    for (auto *L2 : LoopsUsed) +      assert((DT.dominates(L1->getHeader(), L2->getHeader()) || +              DT.dominates(L2->getHeader(), L1->getHeader())) && +             "Domination relationship is not a linear order"); +#endif + +  const Loop *MDL = +      *std::max_element(LoopsUsed.begin(), LoopsUsed.end(), +                        [&](const Loop *L1, const Loop *L2) { +         return DT.properlyDominates(L1->getHeader(), L2->getHeader()); +       }); + +  // Get init and post increment value for LHS. +  auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS); +  // if LHS contains unknown non-invariant SCEV then bail out. +  if (SplitLHS.first == getCouldNotCompute()) +    return false; +  assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC"); +  // Get init and post increment value for RHS. +  auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS); +  // if RHS contains unknown non-invariant SCEV then bail out. +  if (SplitRHS.first == getCouldNotCompute()) +    return false; +  assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC"); +  // It is possible that init SCEV contains an invariant load but it does +  // not dominate MDL and is not available at MDL loop entry, so we should +  // check it here. +  if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) || +      !isAvailableAtLoopEntry(SplitRHS.first, MDL)) +    return false; + +  return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first) && +         isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second, +                                     SplitRHS.second); +} + +bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, +                                       const SCEV *LHS, const SCEV *RHS) { +  // Canonicalize the inputs first. +  (void)SimplifyICmpOperands(Pred, LHS, RHS); + +  if (isKnownViaInduction(Pred, LHS, RHS)) +    return true; + +  if (isKnownPredicateViaSplitting(Pred, LHS, RHS)) +    return true; + +  // Otherwise see what can be done with some simple reasoning. 
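+  // (Checks based on constant ranges and no-overflow facts; see
+  // isKnownPredicateViaConstantRanges and isKnownPredicateViaNoOverflow below.)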
+  return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS); +} + +bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred, +                                              const SCEVAddRecExpr *LHS, +                                              const SCEV *RHS) { +  const Loop *L = LHS->getLoop(); +  return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) && +         isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS); +} + +bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, +                                           ICmpInst::Predicate Pred, +                                           bool &Increasing) { +  bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing); + +#ifndef NDEBUG +  // Verify an invariant: inverting the predicate should turn a monotonically +  // increasing change to a monotonically decreasing one, and vice versa. +  bool IncreasingSwapped; +  bool ResultSwapped = isMonotonicPredicateImpl( +      LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped); + +  assert(Result == ResultSwapped && "should be able to analyze both!"); +  if (ResultSwapped) +    assert(Increasing == !IncreasingSwapped && +           "monotonicity should flip as we flip the predicate"); +#endif + +  return Result; +} + +bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, +                                               ICmpInst::Predicate Pred, +                                               bool &Increasing) { + +  // A zero step value for LHS means the induction variable is essentially a +  // loop invariant value. We don't really depend on the predicate actually +  // flipping from false to true (for increasing predicates, and the other way +  // around for decreasing predicates), all we care about is that *if* the +  // predicate changes then it only changes from false to true. +  // +  // A zero step value in itself is not very useful, but there may be places +  // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be +  // as general as possible. + +  switch (Pred) { +  default: +    return false; // Conservative answer + +  case ICmpInst::ICMP_UGT: +  case ICmpInst::ICMP_UGE: +  case ICmpInst::ICMP_ULT: +  case ICmpInst::ICMP_ULE: +    if (!LHS->hasNoUnsignedWrap()) +      return false; + +    Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; +    return true; + +  case ICmpInst::ICMP_SGT: +  case ICmpInst::ICMP_SGE: +  case ICmpInst::ICMP_SLT: +  case ICmpInst::ICMP_SLE: { +    if (!LHS->hasNoSignedWrap()) +      return false; + +    const SCEV *Step = LHS->getStepRecurrence(*this); + +    if (isKnownNonNegative(Step)) { +      Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE; +      return true; +    } + +    if (isKnownNonPositive(Step)) { +      Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; +      return true; +    } + +    return false; +  } + +  } + +  llvm_unreachable("switch has default clause!"); +} + +bool ScalarEvolution::isLoopInvariantPredicate( +    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, +    ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS, +    const SCEV *&InvariantRHS) { + +  // If there is a loop-invariant, force it into the RHS, otherwise bail out. 
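  // For illustration (hypothetical query): given "n s< {0,+,1}<L>" with n
  // loop-invariant, the swap below turns the query into "{0,+,1}<L> s> n" so
  // that the invariant operand always ends up on the RHS.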
+  if (!isLoopInvariant(RHS, L)) { +    if (!isLoopInvariant(LHS, L)) +      return false; + +    std::swap(LHS, RHS); +    Pred = ICmpInst::getSwappedPredicate(Pred); +  } + +  const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS); +  if (!ArLHS || ArLHS->getLoop() != L) +    return false; + +  bool Increasing; +  if (!isMonotonicPredicate(ArLHS, Pred, Increasing)) +    return false; + +  // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to +  // true as the loop iterates, and the backedge is control dependent on +  // "ArLHS `Pred` RHS" == true then we can reason as follows: +  // +  //   * if the predicate was false in the first iteration then the predicate +  //     is never evaluated again, since the loop exits without taking the +  //     backedge. +  //   * if the predicate was true in the first iteration then it will +  //     continue to be true for all future iterations since it is +  //     monotonically increasing. +  // +  // For both the above possibilities, we can replace the loop varying +  // predicate with its value on the first iteration of the loop (which is +  // loop invariant). +  // +  // A similar reasoning applies for a monotonically decreasing predicate, by +  // replacing true with false and false with true in the above two bullets. + +  auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred); + +  if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) +    return false; + +  InvariantPred = Pred; +  InvariantLHS = ArLHS->getStart(); +  InvariantRHS = RHS; +  return true; +} + +bool ScalarEvolution::isKnownPredicateViaConstantRanges( +    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { +  if (HasSameValue(LHS, RHS)) +    return ICmpInst::isTrueWhenEqual(Pred); + +  // This code is split out from isKnownPredicate because it is called from +  // within isLoopEntryGuardedByCond. + +  auto CheckRanges = +      [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) { +    return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS) +        .contains(RangeLHS); +  }; + +  // The check at the top of the function catches the case where the values are +  // known to be equal. +  if (Pred == CmpInst::ICMP_EQ) +    return false; + +  if (Pred == CmpInst::ICMP_NE) +    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) || +           CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) || +           isKnownNonZero(getMinusSCEV(LHS, RHS)); + +  if (CmpInst::isSigned(Pred)) +    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)); + +  return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)); +} + +bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, +                                                    const SCEV *LHS, +                                                    const SCEV *RHS) { +  // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer. +  // Return Y via OutY. 
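  // Worked example (hypothetical values): for the query "X s<= (X + 5)<nsw>",
  // the helper below decomposes RHS as LHS plus the constant 5, and the
  // ICMP_SLE case then succeeds because 5 is non-negative.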
+  auto MatchBinaryAddToConst =
+      [this](const SCEV *Result, const SCEV *X, APInt &OutY,
+             SCEV::NoWrapFlags ExpectedFlags) {
+    const SCEV *NonConstOp, *ConstOp;
+    SCEV::NoWrapFlags FlagsPresent;
+
+    if (!splitBinaryAdd(Result, ConstOp, NonConstOp, FlagsPresent) ||
+        !isa<SCEVConstant>(ConstOp) || NonConstOp != X)
+      return false;
+
+    OutY = cast<SCEVConstant>(ConstOp)->getAPInt();
+    return (FlagsPresent & ExpectedFlags) == ExpectedFlags;
+  };
+
+  APInt C;
+
+  switch (Pred) {
+  default:
+    break;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_SLE:
+    // X s<= (X + C)<nsw> if C >= 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) && C.isNonNegative())
+      return true;
+
+    // (X + C)<nsw> s<= X if C <= 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) &&
+        !C.isStrictlyPositive())
+      return true;
+    break;
+
+  case ICmpInst::ICMP_SGT:
+    std::swap(LHS, RHS);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::ICMP_SLT:
+    // X s< (X + C)<nsw> if C > 0
+    if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNSW) &&
+        C.isStrictlyPositive())
+      return true;
+
+    // (X + C)<nsw> s< X if C < 0
+    if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
+      return true;
+    break;
+  }
+
+  return false;
+}
+
+bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
+                                                   const SCEV *LHS,
+                                                   const SCEV *RHS) {
+  if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
+    return false;
+
+  // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
+  // on the stack can result in exponential time complexity.
+  SaveAndRestore<bool> Restore(ProvingSplitPredicate, true);
+
+  // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
+  //
+  // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
+  // isKnownPredicate.  isKnownPredicate is more powerful, but also more
+  // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
+  // interesting cases seen in practice.  We can consider "upgrading" L >= 0 to
+  // use isKnownPredicate later if needed.
+  return isKnownNonNegative(RHS) &&
+         isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
+         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
+}
+
+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  // No need to even try if we know the module has no guards.
+  if (!HasGuards)
+    return false;
+
+  return any_of(*BB, [&](Instruction &I) {
+    using namespace llvm::PatternMatch;
+
+    Value *Condition;
+    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+                         m_Value(Condition))) &&
+           isImpliedCond(Pred, LHS, RHS, Condition, false);
+  });
+}
+
+/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
+/// protected by a conditional between LHS and RHS.  This is used to
+/// eliminate casts.
+bool
+ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
+                                             ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS) {
+  // Interpret a null as meaning no loop, where there is obviously no guard
+  // (interprocedural conditions notwithstanding).
+  if (!L) return true;
+
+  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
+    return true;
+
+  BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return false;
+
+  BranchInst *LoopContinuePredicate =
+    dyn_cast<BranchInst>(Latch->getTerminator());
+  if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
+      isImpliedCond(Pred, LHS, RHS,
+                    LoopContinuePredicate->getCondition(),
+                    LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
+    return true;
+
+  // We don't want more than one activation of the following loops on the stack
+  // -- that can lead to O(n!) time complexity.
+  if (WalkingBEDominatingConds)
+    return false;
+
+  SaveAndRestore<bool> ClearOnExit(WalkingBEDominatingConds, true);
+
+  // See if we can exploit a trip count to prove the predicate.
+  const auto &BETakenInfo = getBackedgeTakenInfo(L);
+  const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
+  if (LatchBECount != getCouldNotCompute()) {
+    // We know that Latch branches back to the loop header exactly
+    // LatchBECount times.  This means the backedge condition at Latch is
+    // equivalent to  "{0,+,1} u< LatchBECount".
+    Type *Ty = LatchBECount->getType();
+    auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
+    const SCEV *LoopCounter =
+      getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
+    if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
+                      LatchBECount))
+      return true;
+  }
+
+  // Check conditions due to any @llvm.assume intrinsics.
+  for (auto &AssumeVH : AC.assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
+    if (!DT.dominates(CI, Latch->getTerminator()))
+      continue;
+
+    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+      return true;
+  }
+
+  // If the loop is not reachable from the entry block, we risk running into an
+  // infinite loop as we walk up into the dom tree.  These loops do not matter
+  // anyway, so we just return a conservative answer when we see them.
+  if (!DT.isReachableFromEntry(L->getHeader()))
+    return false;
+
+  if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
+    return true;
+
+  for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
+       DTN != HeaderDTN; DTN = DTN->getIDom()) {
+    assert(DTN && "should reach the loop header before reaching the root!");
+
+    BasicBlock *BB = DTN->getBlock();
+    if (isImpliedViaGuard(BB, Pred, LHS, RHS))
+      return true;
+
+    BasicBlock *PBB = BB->getSinglePredecessor();
+    if (!PBB)
+      continue;
+
+    BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
+    if (!ContinuePredicate || !ContinuePredicate->isConditional())
+      continue;
+
+    Value *Condition = ContinuePredicate->getCondition();
+
+    // If we have an edge `E` within the loop body that dominates the only
+    // latch, the condition guarding `E` also guards the backedge.  This
+    // reasoning works only for loops with a single latch.
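    // For illustration (hypothetical loop shape): if the loop body begins with
    // "if (c) break;", the fall-through edge of that branch dominates the
    // latch, so "!c" holds whenever the backedge is taken; the inversion is
    // requested by the Inverse argument computed below from
    // "BB != ContinuePredicate->getSuccessor(0)".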
+ +    BasicBlockEdge DominatingEdge(PBB, BB); +    if (DominatingEdge.isSingleEdge()) { +      // We're constructively (and conservatively) enumerating edges within the +      // loop body that dominate the latch.  The dominator tree better agree +      // with us on this: +      assert(DT.dominates(DominatingEdge, Latch) && "should be!"); + +      if (isImpliedCond(Pred, LHS, RHS, Condition, +                        BB != ContinuePredicate->getSuccessor(0))) +        return true; +    } +  } + +  return false; +} + +bool +ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, +                                          ICmpInst::Predicate Pred, +                                          const SCEV *LHS, const SCEV *RHS) { +  // Interpret a null as meaning no loop, where there is obviously no guard +  // (interprocedural conditions notwithstanding). +  if (!L) return false; + +  // Both LHS and RHS must be available at loop entry. +  assert(isAvailableAtLoopEntry(LHS, L) && +         "LHS is not available at Loop Entry"); +  assert(isAvailableAtLoopEntry(RHS, L) && +         "RHS is not available at Loop Entry"); + +  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS)) +    return true; + +  // If we cannot prove strict comparison (e.g. a > b), maybe we can prove +  // the facts (a >= b && a != b) separately. A typical situation is when the +  // non-strict comparison is known from ranges and non-equality is known from +  // dominating predicates. If we are proving strict comparison, we always try +  // to prove non-equality and non-strict comparison separately. +  auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred); +  const bool ProvingStrictComparison = (Pred != NonStrictPredicate); +  bool ProvedNonStrictComparison = false; +  bool ProvedNonEquality = false; + +  if (ProvingStrictComparison) { +    ProvedNonStrictComparison = +        isKnownViaNonRecursiveReasoning(NonStrictPredicate, LHS, RHS); +    ProvedNonEquality = +        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, LHS, RHS); +    if (ProvedNonStrictComparison && ProvedNonEquality) +      return true; +  } + +  // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard. +  auto ProveViaGuard = [&](BasicBlock *Block) { +    if (isImpliedViaGuard(Block, Pred, LHS, RHS)) +      return true; +    if (ProvingStrictComparison) { +      if (!ProvedNonStrictComparison) +        ProvedNonStrictComparison = +            isImpliedViaGuard(Block, NonStrictPredicate, LHS, RHS); +      if (!ProvedNonEquality) +        ProvedNonEquality = +            isImpliedViaGuard(Block, ICmpInst::ICMP_NE, LHS, RHS); +      if (ProvedNonStrictComparison && ProvedNonEquality) +        return true; +    } +    return false; +  }; + +  // Try to prove (Pred, LHS, RHS) using isImpliedCond. 
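  // For illustration: to establish the strict "X u> 0" it suffices that the
  // non-strict "X u>= 0" holds (trivially, from ranges) and that "X != 0" is
  // implied by some dominating branch or guard; those two facts are
  // accumulated separately by ProveViaGuard above and ProveViaCond below.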
+  auto ProveViaCond = [&](Value *Condition, bool Inverse) {
+    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+      return true;
+    if (ProvingStrictComparison) {
+      if (!ProvedNonStrictComparison)
+        ProvedNonStrictComparison =
+            isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+      if (!ProvedNonEquality)
+        ProvedNonEquality =
+            isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+      if (ProvedNonStrictComparison && ProvedNonEquality)
+        return true;
+    }
+    return false;
+  };
+
+  // Starting at the loop predecessor, climb up the predecessor chain, as long
+  // as there are predecessors that can be found that have unique successors
+  // leading to the original header.
+  for (std::pair<BasicBlock *, BasicBlock *>
+         Pair(L->getLoopPredecessor(), L->getHeader());
+       Pair.first;
+       Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+
+    if (ProveViaGuard(Pair.first))
+      return true;
+
+    BranchInst *LoopEntryPredicate =
+      dyn_cast<BranchInst>(Pair.first->getTerminator());
+    if (!LoopEntryPredicate ||
+        LoopEntryPredicate->isUnconditional())
+      continue;
+
+    if (ProveViaCond(LoopEntryPredicate->getCondition(),
+                     LoopEntryPredicate->getSuccessor(0) != Pair.second))
+      return true;
+  }
+
+  // Check conditions due to any @llvm.assume intrinsics.
+  for (auto &AssumeVH : AC.assumptions()) {
+    if (!AssumeVH)
+      continue;
+    auto *CI = cast<CallInst>(AssumeVH);
+    if (!DT.dominates(CI, L->getHeader()))
+      continue;
+
+    if (ProveViaCond(CI->getArgOperand(0), false))
+      return true;
+  }
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
+                                    const SCEV *LHS, const SCEV *RHS,
+                                    Value *FoundCondValue,
+                                    bool Inverse) {
+  if (!PendingLoopPredicates.insert(FoundCondValue).second)
+    return false;
+
+  auto ClearOnExit =
+      make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
+
+  // Recursively handle And and Or conditions.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
+    if (BO->getOpcode() == Instruction::And) {
+      if (!Inverse)
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+    } else if (BO->getOpcode() == Instruction::Or) {
+      if (Inverse)
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+    }
+  }
+
+  ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
+  if (!ICI) return false;
+
+  // Now that we have found a conditional branch that dominates the loop or
+  // controls the loop latch, check to see if it is the comparison we are
+  // looking for.
+  ICmpInst::Predicate FoundPred; +  if (Inverse) +    FoundPred = ICI->getInversePredicate(); +  else +    FoundPred = ICI->getPredicate(); + +  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); +  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); + +  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS); +} + +bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, +                                    const SCEV *RHS, +                                    ICmpInst::Predicate FoundPred, +                                    const SCEV *FoundLHS, +                                    const SCEV *FoundRHS) { +  // Balance the types. +  if (getTypeSizeInBits(LHS->getType()) < +      getTypeSizeInBits(FoundLHS->getType())) { +    if (CmpInst::isSigned(Pred)) { +      LHS = getSignExtendExpr(LHS, FoundLHS->getType()); +      RHS = getSignExtendExpr(RHS, FoundLHS->getType()); +    } else { +      LHS = getZeroExtendExpr(LHS, FoundLHS->getType()); +      RHS = getZeroExtendExpr(RHS, FoundLHS->getType()); +    } +  } else if (getTypeSizeInBits(LHS->getType()) > +      getTypeSizeInBits(FoundLHS->getType())) { +    if (CmpInst::isSigned(FoundPred)) { +      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType()); +      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType()); +    } else { +      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType()); +      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType()); +    } +  } + +  // Canonicalize the query to match the way instcombine will have +  // canonicalized the comparison. +  if (SimplifyICmpOperands(Pred, LHS, RHS)) +    if (LHS == RHS) +      return CmpInst::isTrueWhenEqual(Pred); +  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS)) +    if (FoundLHS == FoundRHS) +      return CmpInst::isFalseWhenEqual(FoundPred); + +  // Check to see if we can make the LHS or RHS match. +  if (LHS == FoundRHS || RHS == FoundLHS) { +    if (isa<SCEVConstant>(RHS)) { +      std::swap(FoundLHS, FoundRHS); +      FoundPred = ICmpInst::getSwappedPredicate(FoundPred); +    } else { +      std::swap(LHS, RHS); +      Pred = ICmpInst::getSwappedPredicate(Pred); +    } +  } + +  // Check whether the found predicate is the same as the desired predicate. +  if (FoundPred == Pred) +    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + +  // Check whether swapping the found predicate makes it the same as the +  // desired predicate. +  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { +    if (isa<SCEVConstant>(RHS)) +      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS); +    else +      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), +                                   RHS, LHS, FoundLHS, FoundRHS); +  } + +  // Unsigned comparison is the same as signed comparison when both the operands +  // are non-negative. +  if (CmpInst::isUnsigned(FoundPred) && +      CmpInst::getSignedPredicate(FoundPred) == Pred && +      isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) +    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS); + +  // Check if we can make progress by sharpening ranges. 
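  // Worked example (hypothetical values): if the dominating condition is
  // "V != 0" and getUnsignedRangeMin(V) is 0, then V is known to be u>= 1
  // (SharperMin = Min + 1), and the block below retries the implication with
  // that sharper bound.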
+  if (FoundPred == ICmpInst::ICMP_NE && +      (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) { + +    const SCEVConstant *C = nullptr; +    const SCEV *V = nullptr; + +    if (isa<SCEVConstant>(FoundLHS)) { +      C = cast<SCEVConstant>(FoundLHS); +      V = FoundRHS; +    } else { +      C = cast<SCEVConstant>(FoundRHS); +      V = FoundLHS; +    } + +    // The guarding predicate tells us that C != V. If the known range +    // of V is [C, t), we can sharpen the range to [C + 1, t).  The +    // range we consider has to correspond to same signedness as the +    // predicate we're interested in folding. + +    APInt Min = ICmpInst::isSigned(Pred) ? +        getSignedRangeMin(V) : getUnsignedRangeMin(V); + +    if (Min == C->getAPInt()) { +      // Given (V >= Min && V != Min) we conclude V >= (Min + 1). +      // This is true even if (Min + 1) wraps around -- in case of +      // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). + +      APInt SharperMin = Min + 1; + +      switch (Pred) { +        case ICmpInst::ICMP_SGE: +        case ICmpInst::ICMP_UGE: +          // We know V `Pred` SharperMin.  If this implies LHS `Pred` +          // RHS, we're done. +          if (isImpliedCondOperands(Pred, LHS, RHS, V, +                                    getConstant(SharperMin))) +            return true; +          LLVM_FALLTHROUGH; + +        case ICmpInst::ICMP_SGT: +        case ICmpInst::ICMP_UGT: +          // We know from the range information that (V `Pred` Min || +          // V == Min).  We know from the guarding condition that !(V +          // == Min).  This gives us +          // +          //       V `Pred` Min || V == Min && !(V == Min) +          //   =>  V `Pred` Min +          // +          // If V `Pred` Min implies LHS `Pred` RHS, we're done. + +          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min))) +            return true; +          LLVM_FALLTHROUGH; + +        default: +          // No change +          break; +      } +    } +  } + +  // Check whether the actual condition is beyond sufficient. +  if (FoundPred == ICmpInst::ICMP_EQ) +    if (ICmpInst::isTrueWhenEqual(Pred)) +      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS)) +        return true; +  if (Pred == ICmpInst::ICMP_NE) +    if (!ICmpInst::isTrueWhenEqual(FoundPred)) +      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS)) +        return true; + +  // Otherwise assume the worst. +  return false; +} + +bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr, +                                     const SCEV *&L, const SCEV *&R, +                                     SCEV::NoWrapFlags &Flags) { +  const auto *AE = dyn_cast<SCEVAddExpr>(Expr); +  if (!AE || AE->getNumOperands() != 2) +    return false; + +  L = AE->getOperand(0); +  R = AE->getOperand(1); +  Flags = AE->getNoWrapFlags(); +  return true; +} + +Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More, +                                                           const SCEV *Less) { +  // We avoid subtracting expressions here because this function is usually +  // fairly deep in the call stack (i.e. is called many times). 
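  // Worked examples (hypothetical values): for More = {5,+,1}<L> and
  // Less = {2,+,1}<L>, both add-recs are stripped to their starts and the
  // result is 3; for More = (X + 7) and Less = X, the second splitBinaryAdd
  // match below yields 7.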
+ +  if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) { +    const auto *LAR = cast<SCEVAddRecExpr>(Less); +    const auto *MAR = cast<SCEVAddRecExpr>(More); + +    if (LAR->getLoop() != MAR->getLoop()) +      return None; + +    // We look at affine expressions only; not for correctness but to keep +    // getStepRecurrence cheap. +    if (!LAR->isAffine() || !MAR->isAffine()) +      return None; + +    if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) +      return None; + +    Less = LAR->getStart(); +    More = MAR->getStart(); + +    // fall through +  } + +  if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) { +    const auto &M = cast<SCEVConstant>(More)->getAPInt(); +    const auto &L = cast<SCEVConstant>(Less)->getAPInt(); +    return M - L; +  } + +  SCEV::NoWrapFlags Flags; +  const SCEV *LLess = nullptr, *RLess = nullptr; +  const SCEV *LMore = nullptr, *RMore = nullptr; +  const SCEVConstant *C1 = nullptr, *C2 = nullptr; +  // Compare (X + C1) vs X. +  if (splitBinaryAdd(Less, LLess, RLess, Flags)) +    if ((C1 = dyn_cast<SCEVConstant>(LLess))) +      if (RLess == More) +        return -(C1->getAPInt()); + +  // Compare X vs (X + C2). +  if (splitBinaryAdd(More, LMore, RMore, Flags)) +    if ((C2 = dyn_cast<SCEVConstant>(LMore))) +      if (RMore == Less) +        return C2->getAPInt(); + +  // Compare (X + C1) vs (X + C2). +  if (C1 && C2 && RLess == RMore) +    return C2->getAPInt() - C1->getAPInt(); + +  return None; +} + +bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( +    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, +    const SCEV *FoundLHS, const SCEV *FoundRHS) { +  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT) +    return false; + +  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS); +  if (!AddRecLHS) +    return false; + +  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS); +  if (!AddRecFoundLHS) +    return false; + +  // We'd like to let SCEV reason about control dependencies, so we constrain +  // both the inequalities to be about add recurrences on the same loop.  This +  // way we can use isLoopEntryGuardedByCond later. + +  const Loop *L = AddRecFoundLHS->getLoop(); +  if (L != AddRecLHS->getLoop()) +    return false; + +  //  FoundLHS u< FoundRHS u< -C =>  (FoundLHS + C) u< (FoundRHS + C) ... (1) +  // +  //  FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C) +  //                                                                  ... (2) +  // +  // Informal proof for (2), assuming (1) [*]: +  // +  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**] +  // +  // Then +  // +  //       FoundLHS s< FoundRHS s< INT_MIN - C +  // <=>  (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ] +  // <=>  (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ] +  // <=>  (FoundLHS + INT_MIN + C + INT_MIN) s< +  //                        (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ] +  // <=>  FoundLHS + C s< FoundRHS + C +  // +  // [*]: (1) can be proved by ruling out overflow. +  // +  // [**]: This can be proved by analyzing all the four possibilities: +  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and +  //    (A s>= 0, B s>= 0). +  // +  // Note: +  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C" +  // will not sign underflow.  For instance, say FoundLHS = (i8 -128), FoundRHS +  // = (i8 -127) and C = (i8 -100).  
Then INT_MIN - C = (i8 -28), and FoundRHS +  // s< (INT_MIN - C).  Lack of sign overflow / underflow in "FoundRHS + C" is +  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS + +  // C)". + +  Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS); +  Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS); +  if (!LDiff || !RDiff || *LDiff != *RDiff) +    return false; + +  if (LDiff->isMinValue()) +    return true; + +  APInt FoundRHSLimit; + +  if (Pred == CmpInst::ICMP_ULT) { +    FoundRHSLimit = -(*RDiff); +  } else { +    assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); +    FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff; +  } + +  // Try to prove (1) or (2), as needed. +  return isAvailableAtLoopEntry(FoundRHS, L) && +         isLoopEntryGuardedByCond(L, Pred, FoundRHS, +                                  getConstant(FoundRHSLimit)); +} + +bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, +                                        const SCEV *LHS, const SCEV *RHS, +                                        const SCEV *FoundLHS, +                                        const SCEV *FoundRHS, unsigned Depth) { +  const PHINode *LPhi = nullptr, *RPhi = nullptr; + +  auto ClearOnExit = make_scope_exit([&]() { +    if (LPhi) { +      bool Erased = PendingMerges.erase(LPhi); +      assert(Erased && "Failed to erase LPhi!"); +      (void)Erased; +    } +    if (RPhi) { +      bool Erased = PendingMerges.erase(RPhi); +      assert(Erased && "Failed to erase RPhi!"); +      (void)Erased; +    } +  }); + +  // Find respective Phis and check that they are not being pending. +  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) +    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) { +      if (!PendingMerges.insert(Phi).second) +        return false; +      LPhi = Phi; +    } +  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS)) +    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) { +      // If we detect a loop of Phi nodes being processed by this method, for +      // example: +      // +      //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ] +      //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ] +      // +      // we don't want to deal with a case that complex, so return conservative +      // answer false. +      if (!PendingMerges.insert(Phi).second) +        return false; +      RPhi = Phi; +    } + +  // If none of LHS, RHS is a Phi, nothing to do here. +  if (!LPhi && !RPhi) +    return false; + +  // If there is a SCEVUnknown Phi we are interested in, make it left. +  if (!LPhi) { +    std::swap(LHS, RHS); +    std::swap(FoundLHS, FoundRHS); +    std::swap(LPhi, RPhi); +    Pred = ICmpInst::getSwappedPredicate(Pred); +  } + +  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!"); +  const BasicBlock *LBB = LPhi->getParent(); +  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS); + +  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) { +    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) || +           isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) || +           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth); +  }; + +  if (RPhi && RPhi->getParent() == LBB) { +    // Case one: RHS is also a SCEVUnknown Phi from the same basic block. 
+    // If we compare two Phis from the same block, and for each entry block +    // the predicate is true for incoming values from this block, then the +    // predicate is also true for the Phis. +    for (const BasicBlock *IncBB : predecessors(LBB)) { +      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); +      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB)); +      if (!ProvedEasily(L, R)) +        return false; +    } +  } else if (RAR && RAR->getLoop()->getHeader() == LBB) { +    // Case two: RHS is also a Phi from the same basic block, and it is an +    // AddRec. It means that there is a loop which has both AddRec and Unknown +    // PHIs, for it we can compare incoming values of AddRec from above the loop +    // and latch with their respective incoming values of LPhi. +    // TODO: Generalize to handle loops with many inputs in a header. +    if (LPhi->getNumIncomingValues() != 2) return false; + +    auto *RLoop = RAR->getLoop(); +    auto *Predecessor = RLoop->getLoopPredecessor(); +    assert(Predecessor && "Loop with AddRec with no predecessor?"); +    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor)); +    if (!ProvedEasily(L1, RAR->getStart())) +      return false; +    auto *Latch = RLoop->getLoopLatch(); +    assert(Latch && "Loop with AddRec with no latch?"); +    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch)); +    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this))) +      return false; +  } else { +    // In all other cases go over inputs of LHS and compare each of them to RHS, +    // the predicate is true for (LHS, RHS) if it is true for all such pairs. +    // At this point RHS is either a non-Phi, or it is a Phi from some block +    // different from LBB. +    for (const BasicBlock *IncBB : predecessors(LBB)) { +      // Check that RHS is available in this block. +      if (!dominates(RHS, IncBB)) +        return false; +      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); +      if (!ProvedEasily(L, RHS)) +        return false; +    } +  } +  return true; +} + +bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, +                                            const SCEV *LHS, const SCEV *RHS, +                                            const SCEV *FoundLHS, +                                            const SCEV *FoundRHS) { +  if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) +    return true; + +  if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS)) +    return true; + +  return isImpliedCondOperandsHelper(Pred, LHS, RHS, +                                     FoundLHS, FoundRHS) || +         // ~x < ~y --> x > y +         isImpliedCondOperandsHelper(Pred, LHS, RHS, +                                     getNotSCEV(FoundRHS), +                                     getNotSCEV(FoundLHS)); +} + +/// If Expr computes ~A, return A else return nullptr +static const SCEV *MatchNotExpr(const SCEV *Expr) { +  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr); +  if (!Add || Add->getNumOperands() != 2 || +      !Add->getOperand(0)->isAllOnesValue()) +    return nullptr; + +  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1)); +  if (!AddRHS || AddRHS->getNumOperands() != 2 || +      !AddRHS->getOperand(0)->isAllOnesValue()) +    return nullptr; + +  return AddRHS->getOperand(1); +} + +/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values? 
+template<typename MaxExprType> +static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr, +                              const SCEV *Candidate) { +  const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr); +  if (!MaxExpr) return false; + +  return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end(); +} + +/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values? +template<typename MaxExprType> +static bool IsMinConsistingOf(ScalarEvolution &SE, +                              const SCEV *MaybeMinExpr, +                              const SCEV *Candidate) { +  const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr); +  if (!MaybeMaxExpr) +    return false; + +  return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate)); +} + +static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, +                                           ICmpInst::Predicate Pred, +                                           const SCEV *LHS, const SCEV *RHS) { +  // If both sides are affine addrecs for the same loop, with equal +  // steps, and we know the recurrences don't wrap, then we only +  // need to check the predicate on the starting values. + +  if (!ICmpInst::isRelational(Pred)) +    return false; + +  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS); +  if (!LAR) +    return false; +  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS); +  if (!RAR) +    return false; +  if (LAR->getLoop() != RAR->getLoop()) +    return false; +  if (!LAR->isAffine() || !RAR->isAffine()) +    return false; + +  if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE)) +    return false; + +  SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? +                         SCEV::FlagNSW : SCEV::FlagNUW; +  if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW)) +    return false; + +  return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart()); +} + +/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max +/// expression? +static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, +                                        ICmpInst::Predicate Pred, +                                        const SCEV *LHS, const SCEV *RHS) { +  switch (Pred) { +  default: +    return false; + +  case ICmpInst::ICMP_SGE: +    std::swap(LHS, RHS); +    LLVM_FALLTHROUGH; +  case ICmpInst::ICMP_SLE: +    return +      // min(A, ...) <= A +      IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) || +      // A <= max(A, ...) +      IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS); + +  case ICmpInst::ICMP_UGE: +    std::swap(LHS, RHS); +    LLVM_FALLTHROUGH; +  case ICmpInst::ICMP_ULE: +    return +      // min(A, ...) <= A +      IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) || +      // A <= max(A, ...) 
+      IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
+  }
+
+  llvm_unreachable("covered switch fell through?!");
+}
+
+bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS,
+                                             unsigned Depth) {
+  assert(getTypeSizeInBits(LHS->getType()) ==
+             getTypeSizeInBits(RHS->getType()) &&
+         "LHS and RHS have different sizes?");
+  assert(getTypeSizeInBits(FoundLHS->getType()) ==
+             getTypeSizeInBits(FoundRHS->getType()) &&
+         "FoundLHS and FoundRHS have different sizes?");
+  // We want to avoid hurting compile time by analyzing overly large trees.
+  if (Depth > MaxSCEVOperationsImplicationDepth)
+    return false;
+  // We only want to work with the ICMP_SGT comparison so far.
+  // TODO: Extend to ICMP_UGT?
+  if (Pred == ICmpInst::ICMP_SLT) {
+    Pred = ICmpInst::ICMP_SGT;
+    std::swap(LHS, RHS);
+    std::swap(FoundLHS, FoundRHS);
+  }
+  if (Pred != ICmpInst::ICMP_SGT)
+    return false;
+
+  auto GetOpFromSExt = [&](const SCEV *S) {
+    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
+      return Ext->getOperand();
+    // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
+    // the constant in some cases.
+    return S;
+  };
+
+  // Acquire values from extensions.
+  auto *OrigLHS = LHS;
+  auto *OrigFoundLHS = FoundLHS;
+  LHS = GetOpFromSExt(LHS);
+  FoundLHS = GetOpFromSExt(FoundLHS);
+
+  // Check whether the SGT predicate can be proved trivially or using the
+  // found context.
+  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
+    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
+                                  FoundRHS, Depth + 1);
+  };
+
+  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
+    // We want to avoid creation of any new non-constant SCEV. Since we are
+    // going to compare the operands to RHS, we should be certain that we don't
+    // need any size extensions for this. So let's decline all cases when the
+    // sizes of types of LHS and RHS do not match.
+    // TODO: Maybe try to get RHS from sext to catch more cases?
+    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
+      return false;
+
+    // Should not overflow.
+    if (!LHSAddExpr->hasNoSignedWrap())
+      return false;
+
+    auto *LL = LHSAddExpr->getOperand(0);
+    auto *LR = LHSAddExpr->getOperand(1);
+    auto *MinusOne = getNegativeSCEV(getOne(RHS->getType()));
+
+    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
+    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
+      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
+    };
+    // Try to prove the following rule:
+    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
+    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
+    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
+      return true;
+  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
+    Value *LL, *LR;
+    // FIXME: Once we have SDiv implemented, we can get rid of this matching.
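    // For illustration (hypothetical IR): this path handles an LHS such as
    // "%q = sdiv i32 %n, 4", which SCEV models only as a SCEVUnknown; the
    // pattern match below recovers the numerator %n and the constant
    // denominator 4 so that the division rules can be applied.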
+ +    using namespace llvm::PatternMatch; + +    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) { +      // Rules for division. +      // We are going to perform some comparisons with Denominator and its +      // derivative expressions. In general case, creating a SCEV for it may +      // lead to a complex analysis of the entire graph, and in particular it +      // can request trip count recalculation for the same loop. This would +      // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid +      // this, we only want to create SCEVs that are constants in this section. +      // So we bail if Denominator is not a constant. +      if (!isa<ConstantInt>(LR)) +        return false; + +      auto *Denominator = cast<SCEVConstant>(getSCEV(LR)); + +      // We want to make sure that LHS = FoundLHS / Denominator. If it is so, +      // then a SCEV for the numerator already exists and matches with FoundLHS. +      auto *Numerator = getExistingSCEV(LL); +      if (!Numerator || Numerator->getType() != FoundLHS->getType()) +        return false; + +      // Make sure that the numerator matches with FoundLHS and the denominator +      // is positive. +      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator)) +        return false; + +      auto *DTy = Denominator->getType(); +      auto *FRHSTy = FoundRHS->getType(); +      if (DTy->isPointerTy() != FRHSTy->isPointerTy()) +        // One of types is a pointer and another one is not. We cannot extend +        // them properly to a wider type, so let us just reject this case. +        // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help +        // to avoid this check. +        return false; + +      // Given that: +      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0. +      auto *WTy = getWiderType(DTy, FRHSTy); +      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy); +      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy); + +      // Try to prove the following rule: +      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS). +      // For example, given that FoundLHS > 2. It means that FoundLHS is at +      // least 3. If we divide it by Denominator < 4, we will have at least 1. +      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2)); +      if (isKnownNonPositive(RHS) && +          IsSGTViaContext(FoundRHSExt, DenomMinusTwo)) +        return true; + +      // Try to prove the following rule: +      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS). +      // For example, given that FoundLHS > -3. Then FoundLHS is at least -2. +      // If we divide it by Denominator > 2, then: +      // 1. If FoundLHS is negative, then the result is 0. +      // 2. If FoundLHS is non-negative, then the result is non-negative. +      // Anyways, the result is non-negative. +      auto *MinusOne = getNegativeSCEV(getOne(WTy)); +      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); +      if (isKnownNegative(RHS) && +          IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) +        return true; +    } +  } + +  // If our expression contained SCEVUnknown Phis, and we split it down and now +  // need to prove something for them, try to prove the predicate for every +  // possible incoming values of those Phis. 
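  // For illustration (hypothetical IR): if LHS is
  // "%p = phi i32 [ %a, %bb1 ], [ %b, %bb2 ]", the call below tries to prove
  // the predicate separately for %a and for %b (against the corresponding
  // form of RHS); only if every incoming value can be proved is the fact
  // concluded for %p itself.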
+  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
+    return true;
+
+  return false;
+}
+
+bool
+ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
+                                           const SCEV *LHS, const SCEV *RHS) {
+  return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
+         IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
+         IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
+         isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
+}
+
+bool
+ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
+                                             const SCEV *LHS, const SCEV *RHS,
+                                             const SCEV *FoundLHS,
+                                             const SCEV *FoundRHS) {
+  switch (Pred) {
+  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_SGE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+      return true;
+    break;
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGE:
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+      return true;
+    break;
+  }
+
+  // Maybe it can be proved via operations?
+  if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
+  return false;
+}
+
+bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
+                                                     const SCEV *LHS,
+                                                     const SCEV *RHS,
+                                                     const SCEV *FoundLHS,
+                                                     const SCEV *FoundRHS) {
+  if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
+    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
+    // reduce the compile time impact of this optimization.
+    return false;
+
+  Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
+  if (!Addend)
+    return false;
+
+  const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
+
+  // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
+  // antecedent "`FoundLHS` `Pred` `FoundRHS`".
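  // Worked example (hypothetical i32 values): with Pred = u<, FoundRHS = 8 and
  // Addend = 2, FoundLHSRange is [0, 8) and LHSRange becomes [2, 10); for
  // RHS = 10 the satisfying region is [0, 10), which contains [2, 10), so
  // "FoundLHS u< 8" implies "LHS u< 10".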
+  ConstantRange FoundLHSRange = +      ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); + +  // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`: +  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend)); + +  // We can also compute the range of values for `LHS` that satisfy the +  // consequent, "`LHS` `Pred` `RHS`": +  const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt(); +  ConstantRange SatisfyingLHSRange = +      ConstantRange::makeSatisfyingICmpRegion(Pred, ConstRHS); + +  // The antecedent implies the consequent if every value of `LHS` that +  // satisfies the antecedent also satisfies the consequent. +  return SatisfyingLHSRange.contains(LHSRange); +} + +bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, +                                         bool IsSigned, bool NoWrap) { +  assert(isKnownPositive(Stride) && "Positive stride expected!"); + +  if (NoWrap) return false; + +  unsigned BitWidth = getTypeSizeInBits(RHS->getType()); +  const SCEV *One = getOne(Stride->getType()); + +  if (IsSigned) { +    APInt MaxRHS = getSignedRangeMax(RHS); +    APInt MaxValue = APInt::getSignedMaxValue(BitWidth); +    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); + +    // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow! +    return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS); +  } + +  APInt MaxRHS = getUnsignedRangeMax(RHS); +  APInt MaxValue = APInt::getMaxValue(BitWidth); +  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); + +  // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow! +  return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS); +} + +bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, +                                         bool IsSigned, bool NoWrap) { +  if (NoWrap) return false; + +  unsigned BitWidth = getTypeSizeInBits(RHS->getType()); +  const SCEV *One = getOne(Stride->getType()); + +  if (IsSigned) { +    APInt MinRHS = getSignedRangeMin(RHS); +    APInt MinValue = APInt::getSignedMinValue(BitWidth); +    APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One)); + +    // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow! +    return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS); +  } + +  APInt MinRHS = getUnsignedRangeMin(RHS); +  APInt MinValue = APInt::getMinValue(BitWidth); +  APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One)); + +  // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow! +  return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS); +} + +const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, +                                            bool Equality) { +  const SCEV *One = getOne(Step->getType()); +  Delta = Equality ? 
getAddExpr(Delta, Step) +                   : getAddExpr(Delta, getMinusSCEV(Step, One)); +  return getUDivExpr(Delta, Step); +} + +const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, +                                                    const SCEV *Stride, +                                                    const SCEV *End, +                                                    unsigned BitWidth, +                                                    bool IsSigned) { + +  assert(!isKnownNonPositive(Stride) && +         "Stride is expected strictly positive!"); +  // Calculate the maximum backedge count based on the range of values +  // permitted by Start, End, and Stride. +  const SCEV *MaxBECount; +  APInt MinStart = +      IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start); + +  APInt StrideForMaxBECount = +      IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride); + +  // We already know that the stride is positive, so we paper over conservatism +  // in our range computation by forcing StrideForMaxBECount to be at least one. +  // In theory this is unnecessary, but we expect MaxBECount to be a +  // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there +  // is nothing to constant fold it to). +  APInt One(BitWidth, 1, IsSigned); +  StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount); + +  APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth) +                            : APInt::getMaxValue(BitWidth); +  APInt Limit = MaxValue - (StrideForMaxBECount - 1); + +  // Although End can be a MAX expression we estimate MaxEnd considering only +  // the case End = RHS of the loop termination condition. This is safe because +  // in the other case (End - Start) is zero, leading to a zero maximum backedge +  // taken count. +  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit) +                          : APIntOps::umin(getUnsignedRangeMax(End), Limit); + +  MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */, +                              getConstant(StrideForMaxBECount) /* Step */, +                              false /* Equality */); + +  return MaxBECount; +} + +ScalarEvolution::ExitLimit +ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, +                                  const Loop *L, bool IsSigned, +                                  bool ControlsExit, bool AllowPredicates) { +  SmallPtrSet<const SCEVPredicate *, 4> Predicates; + +  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); +  bool PredicatedIV = false; + +  if (!IV && AllowPredicates) { +    // Try to make this an AddRec using runtime tests, in the first X +    // iterations of this loop, where X is the SCEV expression found by the +    // algorithm below. +    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); +    PredicatedIV = true; +  } + +  // Avoid weird loops +  if (!IV || IV->getLoop() != L || !IV->isAffine()) +    return getCouldNotCompute(); + +  bool NoWrap = ControlsExit && +                IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); + +  const SCEV *Stride = IV->getStepRecurrence(*this); + +  bool PositiveStride = isKnownPositive(Stride); + +  // Avoid negative or zero stride values. +  if (!PositiveStride) { +    // We can compute the correct backedge taken count for loops with unknown +    // strides if we can prove that the loop is not an infinite loop with side +    // effects. 
Here's the loop structure we are trying to handle - +    // +    // i = start +    // do { +    //   A[i] = i; +    //   i += s; +    // } while (i < end); +    // +    // The backedge taken count for such loops is evaluated as - +    // (max(end, start + stride) - start - 1) /u stride +    // +    // The additional preconditions that we need to check to prove correctness +    // of the above formula is as follows - +    // +    // a) IV is either nuw or nsw depending upon signedness (indicated by the +    //    NoWrap flag). +    // b) loop is single exit with no side effects. +    // +    // +    // Precondition a) implies that if the stride is negative, this is a single +    // trip loop. The backedge taken count formula reduces to zero in this case. +    // +    // Precondition b) implies that the unknown stride cannot be zero otherwise +    // we have UB. +    // +    // The positive stride case is the same as isKnownPositive(Stride) returning +    // true (original behavior of the function). +    // +    // We want to make sure that the stride is truly unknown as there are edge +    // cases where ScalarEvolution propagates no wrap flags to the +    // post-increment/decrement IV even though the increment/decrement operation +    // itself is wrapping. The computed backedge taken count may be wrong in +    // such cases. This is prevented by checking that the stride is not known to +    // be either positive or non-positive. For example, no wrap flags are +    // propagated to the post-increment IV of this loop with a trip count of 2 - +    // +    // unsigned char i; +    // for(i=127; i<128; i+=129) +    //   A[i] = i; +    // +    if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) || +        !loopHasNoSideEffects(L)) +      return getCouldNotCompute(); +  } else if (!Stride->isOne() && +             doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) +    // Avoid proven overflow cases: this will ensure that the backedge taken +    // count will not generate any unsigned overflow. Relaxed no-overflow +    // conditions exploit NoWrapFlags, allowing to optimize in presence of +    // undefined behaviors like the case of C language. +    return getCouldNotCompute(); + +  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT +                                      : ICmpInst::ICMP_ULT; +  const SCEV *Start = IV->getStart(); +  const SCEV *End = RHS; +  // When the RHS is not invariant, we do not know the end bound of the loop and +  // cannot calculate the ExactBECount needed by ExitLimit. However, we can +  // calculate the MaxBECount, given the start, stride and max value for the end +  // bound of the loop (RHS), and the fact that IV does not overflow (which is +  // checked above). +  if (!isLoopInvariant(RHS, L)) { +    const SCEV *MaxBECount = computeMaxBECountForLT( +        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); +    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount, +                     false /*MaxOrZero*/, Predicates); +  } +  // If the backedge is taken at least once, then it will be taken +  // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start +  // is the LHS value of the less-than comparison the first time it is evaluated +  // and End is the RHS. 
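  // Worked example (hypothetical values): for IV = {0,+,4} and an invariant
  // RHS = 10, the computation below gives
  // BECountIfBackedgeTaken = (10 - 0 + (4 - 1)) /u 4 = 3, matching the three
  // values 0, 4 and 8 that satisfy the less-than test.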
+  const SCEV *BECountIfBackedgeTaken = +    computeBECount(getMinusSCEV(End, Start), Stride, false); +  // If the loop entry is guarded by the result of the backedge test of the +  // first loop iteration, then we know the backedge will be taken at least +  // once and so the backedge taken count is as above. If not then we use the +  // expression (max(End,Start)-Start)/Stride to describe the backedge count, +  // as if the backedge is taken at least once max(End,Start) is End and so the +  // result is as above, and if not max(End,Start) is Start so we get a backedge +  // count of zero. +  const SCEV *BECount; +  if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) +    BECount = BECountIfBackedgeTaken; +  else { +    End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); +    BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); +  } + +  const SCEV *MaxBECount; +  bool MaxOrZero = false; +  if (isa<SCEVConstant>(BECount)) +    MaxBECount = BECount; +  else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) { +    // If we know exactly how many times the backedge will be taken if it's +    // taken at least once, then the backedge count will either be that or +    // zero. +    MaxBECount = BECountIfBackedgeTaken; +    MaxOrZero = true; +  } else { +    MaxBECount = computeMaxBECountForLT( +        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned); +  } + +  if (isa<SCEVCouldNotCompute>(MaxBECount) && +      !isa<SCEVCouldNotCompute>(BECount)) +    MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + +  return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates); +} + +ScalarEvolution::ExitLimit +ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, +                                     const Loop *L, bool IsSigned, +                                     bool ControlsExit, bool AllowPredicates) { +  SmallPtrSet<const SCEVPredicate *, 4> Predicates; +  // We handle only IV > Invariant +  if (!isLoopInvariant(RHS, L)) +    return getCouldNotCompute(); + +  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); +  if (!IV && AllowPredicates) +    // Try to make this an AddRec using runtime tests, in the first X +    // iterations of this loop, where X is the SCEV expression found by the +    // algorithm below. +    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); + +  // Avoid weird loops +  if (!IV || IV->getLoop() != L || !IV->isAffine()) +    return getCouldNotCompute(); + +  bool NoWrap = ControlsExit && +                IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); + +  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); + +  // Avoid negative or zero stride values +  if (!isKnownPositive(Stride)) +    return getCouldNotCompute(); + +  // Avoid proven overflow cases: this will ensure that the backedge taken count +  // will not generate any unsigned overflow. Relaxed no-overflow conditions +  // exploit NoWrapFlags, allowing to optimize in presence of undefined +  // behaviors like the case of C language. +  if (!Stride->isOne() && doesIVOverflowOnGT(RHS, Stride, IsSigned, NoWrap)) +    return getCouldNotCompute(); + +  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT +                                      : ICmpInst::ICMP_UGT; + +  const SCEV *Start = IV->getStart(); +  const SCEV *End = RHS; +  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) +    End = IsSigned ? 
getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); + +  const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); + +  APInt MaxStart = IsSigned ? getSignedRangeMax(Start) +                            : getUnsignedRangeMax(Start); + +  APInt MinStride = IsSigned ? getSignedRangeMin(Stride) +                             : getUnsignedRangeMin(Stride); + +  unsigned BitWidth = getTypeSizeInBits(LHS->getType()); +  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1) +                         : APInt::getMinValue(BitWidth) + (MinStride - 1); + +  // Although End can be a MIN expression we estimate MinEnd considering only +  // the case End = RHS. This is safe because in the other case (Start - End) +  // is zero, leading to a zero maximum backedge taken count. +  APInt MinEnd = +    IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit) +             : APIntOps::umax(getUnsignedRangeMin(RHS), Limit); + + +  const SCEV *MaxBECount = getCouldNotCompute(); +  if (isa<SCEVConstant>(BECount)) +    MaxBECount = BECount; +  else +    MaxBECount = computeBECount(getConstant(MaxStart - MinEnd), +                                getConstant(MinStride), false); + +  if (isa<SCEVCouldNotCompute>(MaxBECount)) +    MaxBECount = BECount; + +  return ExitLimit(BECount, MaxBECount, false, Predicates); +} + +const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, +                                                    ScalarEvolution &SE) const { +  if (Range.isFullSet())  // Infinite loop. +    return SE.getCouldNotCompute(); + +  // If the start is a non-zero constant, shift the range to simplify things. +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart())) +    if (!SC->getValue()->isZero()) { +      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end()); +      Operands[0] = SE.getZero(SC->getType()); +      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), +                                             getNoWrapFlags(FlagNW)); +      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted)) +        return ShiftedAddRec->getNumIterationsInRange( +            Range.subtract(SC->getAPInt()), SE); +      // This is strange and shouldn't happen. +      return SE.getCouldNotCompute(); +    } + +  // The only time we can solve this is when we have all constant indices. +  // Otherwise, we cannot determine the overflow conditions. +  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); })) +    return SE.getCouldNotCompute(); + +  // Okay at this point we know that all elements of the chrec are constants and +  // that the start element is zero. + +  // First check to see if the range contains zero.  If not, the first +  // iteration exits. +  unsigned BitWidth = SE.getTypeSizeInBits(getType()); +  if (!Range.contains(APInt(BitWidth, 0))) +    return SE.getZero(getType()); + +  if (isAffine()) { +    // If this is an affine expression then we have this situation: +    //   Solve {0,+,A} in Range  ===  Ax in Range + +    // We know that zero is in the range.  If A is positive then we know that +    // the upper value of the range must be the first possible exit value. +    // If A is negative then the lower of the range is the last possible loop +    // value.  Also note that we already checked for a full range. +    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt(); +    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower(); + +    // The exit value should be (End+A)/A. 
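+    //
+    // Worked example (illustrative; the concrete values are assumed): solving
+    // {0,+,5} in the range [0, 100) gives A = 5 and End = 99, so
+    // ExitVal = (99 + 5) /u 5 = 20.  Iteration 20 of the chrec evaluates to
+    // 100, which lies outside the range, while iteration 19 evaluates to 95,
+    // which is still inside, so 20 is the number of iterations in the range.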
+    APInt ExitVal = (End + A).udiv(A); +    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal); + +    // Evaluate at the exit value.  If we really did fall out of the valid +    // range, then we computed our trip count, otherwise wrap around or other +    // things must have happened. +    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE); +    if (Range.contains(Val->getValue())) +      return SE.getCouldNotCompute();  // Something strange happened + +    // Ensure that the previous value is in the range.  This is a sanity check. +    assert(Range.contains( +           EvaluateConstantChrecAtConstant(this, +           ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) && +           "Linear scev computation is off in a bad way!"); +    return SE.getConstant(ExitValue); +  } else if (isQuadratic()) { +    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the +    // quadratic equation to solve it.  To do this, we must frame our problem in +    // terms of figuring out when zero is crossed, instead of when +    // Range.getUpper() is crossed. +    SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end()); +    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); +    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap); + +    // Next, solve the constructed addrec +    if (auto Roots = +            SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) { +      const SCEVConstant *R1 = Roots->first; +      const SCEVConstant *R2 = Roots->second; +      // Pick the smallest positive root value. +      if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( +              ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { +        if (!CB->getZExtValue()) +          std::swap(R1, R2); // R1 is the minimum root now. + +        // Make sure the root is not off by one.  The returned iteration should +        // not be in the range, but the previous one should be.  When solving +        // for "X*X < 5", for example, we should not return a root of 2. +        ConstantInt *R1Val = +            EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); +        if (Range.contains(R1Val->getValue())) { +          // The next iteration must be out of the range... +          ConstantInt *NextVal = +              ConstantInt::get(SE.getContext(), R1->getAPInt() + 1); + +          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); +          if (!Range.contains(R1Val->getValue())) +            return SE.getConstant(NextVal); +          return SE.getCouldNotCompute(); // Something strange happened +        } + +        // If R1 was not in the range, then it is a good return value.  Make +        // sure that R1-1 WAS in the range though, just in case. 
+        ConstantInt *NextVal = +            ConstantInt::get(SE.getContext(), R1->getAPInt() - 1); +        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); +        if (Range.contains(R1Val->getValue())) +          return R1; +        return SE.getCouldNotCompute(); // Something strange happened +      } +    } +  } + +  return SE.getCouldNotCompute(); +} + +const SCEVAddRecExpr * +SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const { +  assert(getNumOperands() > 1 && "AddRec with zero step?"); +  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)), +  // but in this case we cannot guarantee that the value returned will be an +  // AddRec because SCEV does not have a fixed point where it stops +  // simplification: it is legal to return ({rec1} + {rec2}). For example, it +  // may happen if we reach arithmetic depth limit while simplifying. So we +  // construct the returned value explicitly. +  SmallVector<const SCEV *, 3> Ops; +  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and +  // (this + Step) is {A+B,+,B+C,+...,+,N}. +  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i) +    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1))); +  // We know that the last operand is not a constant zero (otherwise it would +  // have been popped out earlier). This guarantees us that if the result has +  // the same last operand, then it will also not be popped out, meaning that +  // the returned value will be an AddRec. +  const SCEV *Last = getOperand(getNumOperands() - 1); +  assert(!Last->isZero() && "Recurrency with zero step?"); +  Ops.push_back(Last); +  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(), +                                               SCEV::FlagAnyWrap)); +} + +// Return true when S contains at least an undef value. +static inline bool containsUndefs(const SCEV *S) { +  return SCEVExprContains(S, [](const SCEV *S) { +    if (const auto *SU = dyn_cast<SCEVUnknown>(S)) +      return isa<UndefValue>(SU->getValue()); +    else if (const auto *SC = dyn_cast<SCEVConstant>(S)) +      return isa<UndefValue>(SC->getValue()); +    return false; +  }); +} + +namespace { + +// Collect all steps of SCEV expressions. +struct SCEVCollectStrides { +  ScalarEvolution &SE; +  SmallVectorImpl<const SCEV *> &Strides; + +  SCEVCollectStrides(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &S) +      : SE(SE), Strides(S) {} + +  bool follow(const SCEV *S) { +    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) +      Strides.push_back(AR->getStepRecurrence(SE)); +    return true; +  } + +  bool isDone() const { return false; } +}; + +// Collect all SCEVUnknown and SCEVMulExpr expressions. +struct SCEVCollectTerms { +  SmallVectorImpl<const SCEV *> &Terms; + +  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {} + +  bool follow(const SCEV *S) { +    if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) || +        isa<SCEVSignExtendExpr>(S)) { +      if (!containsUndefs(S)) +        Terms.push_back(S); + +      // Stop recursion: once we collected a term, do not walk its operands. +      return false; +    } + +    // Keep looking. +    return true; +  } + +  bool isDone() const { return false; } +}; + +// Check if a SCEV contains an AddRecExpr. 
+struct SCEVHasAddRec { +  bool &ContainsAddRec; + +  SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) { +    ContainsAddRec = false; +  } + +  bool follow(const SCEV *S) { +    if (isa<SCEVAddRecExpr>(S)) { +      ContainsAddRec = true; + +      // Stop recursion: once we collected a term, do not walk its operands. +      return false; +    } + +    // Keep looking. +    return true; +  } + +  bool isDone() const { return false; } +}; + +// Find factors that are multiplied with an expression that (possibly as a +// subexpression) contains an AddRecExpr. In the expression: +// +//  8 * (100 +  %p * %q * (%a + {0, +, 1}_loop)) +// +// "%p * %q" are factors multiplied by the expression "(%a + {0, +, 1}_loop)" +// that contains the AddRec {0, +, 1}_loop. %p * %q are likely to be array size +// parameters as they form a product with an induction variable. +// +// This collector expects all array size parameters to be in the same MulExpr. +// It might be necessary to later add support for collecting parameters that are +// spread over different nested MulExpr. +struct SCEVCollectAddRecMultiplies { +  SmallVectorImpl<const SCEV *> &Terms; +  ScalarEvolution &SE; + +  SCEVCollectAddRecMultiplies(SmallVectorImpl<const SCEV *> &T, ScalarEvolution &SE) +      : Terms(T), SE(SE) {} + +  bool follow(const SCEV *S) { +    if (auto *Mul = dyn_cast<SCEVMulExpr>(S)) { +      bool HasAddRec = false; +      SmallVector<const SCEV *, 0> Operands; +      for (auto Op : Mul->operands()) { +        const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Op); +        if (Unknown && !isa<CallInst>(Unknown->getValue())) { +          Operands.push_back(Op); +        } else if (Unknown) { +          HasAddRec = true; +        } else { +          bool ContainsAddRec; +          SCEVHasAddRec ContiansAddRec(ContainsAddRec); +          visitAll(Op, ContiansAddRec); +          HasAddRec |= ContainsAddRec; +        } +      } +      if (Operands.size() == 0) +        return true; + +      if (!HasAddRec) +        return false; + +      Terms.push_back(SE.getMulExpr(Operands)); +      // Stop recursion: once we collected a term, do not walk its operands. +      return false; +    } + +    // Keep looking. +    return true; +  } + +  bool isDone() const { return false; } +}; + +} // end anonymous namespace + +/// Find parametric terms in this SCEVAddRecExpr. We first for parameters in +/// two places: +///   1) The strides of AddRec expressions. +///   2) Unknowns that are multiplied with AddRec expressions. +void ScalarEvolution::collectParametricTerms(const SCEV *Expr, +    SmallVectorImpl<const SCEV *> &Terms) { +  SmallVector<const SCEV *, 4> Strides; +  SCEVCollectStrides StrideCollector(*this, Strides); +  visitAll(Expr, StrideCollector); + +  LLVM_DEBUG({ +    dbgs() << "Strides:\n"; +    for (const SCEV *S : Strides) +      dbgs() << *S << "\n"; +  }); + +  for (const SCEV *S : Strides) { +    SCEVCollectTerms TermCollector(Terms); +    visitAll(S, TermCollector); +  } + +  LLVM_DEBUG({ +    dbgs() << "Terms:\n"; +    for (const SCEV *T : Terms) +      dbgs() << *T << "\n"; +  }); + +  SCEVCollectAddRecMultiplies MulCollector(Terms, *this); +  visitAll(Expr, MulCollector); +} + +static bool findArrayDimensionsRec(ScalarEvolution &SE, +                                   SmallVectorImpl<const SCEV *> &Terms, +                                   SmallVectorImpl<const SCEV *> &Sizes) { +  int Last = Terms.size() - 1; +  const SCEV *Step = Terms[Last]; + +  // End of recursion. 
+  if (Last == 0) { +    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Step)) { +      SmallVector<const SCEV *, 2> Qs; +      for (const SCEV *Op : M->operands()) +        if (!isa<SCEVConstant>(Op)) +          Qs.push_back(Op); + +      Step = SE.getMulExpr(Qs); +    } + +    Sizes.push_back(Step); +    return true; +  } + +  for (const SCEV *&Term : Terms) { +    // Normalize the terms before the next call to findArrayDimensionsRec. +    const SCEV *Q, *R; +    SCEVDivision::divide(SE, Term, Step, &Q, &R); + +    // Bail out when GCD does not evenly divide one of the terms. +    if (!R->isZero()) +      return false; + +    Term = Q; +  } + +  // Remove all SCEVConstants. +  Terms.erase( +      remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }), +      Terms.end()); + +  if (Terms.size() > 0) +    if (!findArrayDimensionsRec(SE, Terms, Sizes)) +      return false; + +  Sizes.push_back(Step); +  return true; +} + +// Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. +static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) { +  for (const SCEV *T : Terms) +    if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>)) +      return true; +  return false; +} + +// Return the number of product terms in S. +static inline int numberOfTerms(const SCEV *S) { +  if (const SCEVMulExpr *Expr = dyn_cast<SCEVMulExpr>(S)) +    return Expr->getNumOperands(); +  return 1; +} + +static const SCEV *removeConstantFactors(ScalarEvolution &SE, const SCEV *T) { +  if (isa<SCEVConstant>(T)) +    return nullptr; + +  if (isa<SCEVUnknown>(T)) +    return T; + +  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(T)) { +    SmallVector<const SCEV *, 2> Factors; +    for (const SCEV *Op : M->operands()) +      if (!isa<SCEVConstant>(Op)) +        Factors.push_back(Op); + +    return SE.getMulExpr(Factors); +  } + +  return T; +} + +/// Return the size of an element read or written by Inst. +const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { +  Type *Ty; +  if (StoreInst *Store = dyn_cast<StoreInst>(Inst)) +    Ty = Store->getValueOperand()->getType(); +  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) +    Ty = Load->getType(); +  else +    return nullptr; + +  Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty)); +  return getSizeOfExpr(ETy, Ty); +} + +void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, +                                          SmallVectorImpl<const SCEV *> &Sizes, +                                          const SCEV *ElementSize) { +  if (Terms.size() < 1 || !ElementSize) +    return; + +  // Early return when Terms do not contain parameters: we do not delinearize +  // non parametric SCEVs. +  if (!containsParameters(Terms)) +    return; + +  LLVM_DEBUG({ +    dbgs() << "Terms:\n"; +    for (const SCEV *T : Terms) +      dbgs() << *T << "\n"; +  }); + +  // Remove duplicates. +  array_pod_sort(Terms.begin(), Terms.end()); +  Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); + +  // Put larger terms first. +  llvm::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) { +    return numberOfTerms(LHS) > numberOfTerms(RHS); +  }); + +  // Try to divide all terms by the element size. If term is not divisible by +  // element size, proceed with the original term. 
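+  //
+  // Worked example (illustrative; the terms are assumed): for the A[i][j][k]
+  // access described in the comment on delinearize() below, Terms might hold
+  // {8 * %m * %o, 8 * %o} with an ElementSize of 8.  Dividing by the element
+  // size leaves {%m * %o, %o}; findArrayDimensionsRec then divides by the
+  // smallest term %o, leaving {%m}, and the final result is
+  // Sizes = {%m, %o, 8} once ElementSize is appended.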
+  for (const SCEV *&Term : Terms) { +    const SCEV *Q, *R; +    SCEVDivision::divide(*this, Term, ElementSize, &Q, &R); +    if (!Q->isZero()) +      Term = Q; +  } + +  SmallVector<const SCEV *, 4> NewTerms; + +  // Remove constant factors. +  for (const SCEV *T : Terms) +    if (const SCEV *NewT = removeConstantFactors(*this, T)) +      NewTerms.push_back(NewT); + +  LLVM_DEBUG({ +    dbgs() << "Terms after sorting:\n"; +    for (const SCEV *T : NewTerms) +      dbgs() << *T << "\n"; +  }); + +  if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) { +    Sizes.clear(); +    return; +  } + +  // The last element to be pushed into Sizes is the size of an element. +  Sizes.push_back(ElementSize); + +  LLVM_DEBUG({ +    dbgs() << "Sizes:\n"; +    for (const SCEV *S : Sizes) +      dbgs() << *S << "\n"; +  }); +} + +void ScalarEvolution::computeAccessFunctions( +    const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts, +    SmallVectorImpl<const SCEV *> &Sizes) { +  // Early exit in case this SCEV is not an affine multivariate function. +  if (Sizes.empty()) +    return; + +  if (auto *AR = dyn_cast<SCEVAddRecExpr>(Expr)) +    if (!AR->isAffine()) +      return; + +  const SCEV *Res = Expr; +  int Last = Sizes.size() - 1; +  for (int i = Last; i >= 0; i--) { +    const SCEV *Q, *R; +    SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R); + +    LLVM_DEBUG({ +      dbgs() << "Res: " << *Res << "\n"; +      dbgs() << "Sizes[i]: " << *Sizes[i] << "\n"; +      dbgs() << "Res divided by Sizes[i]:\n"; +      dbgs() << "Quotient: " << *Q << "\n"; +      dbgs() << "Remainder: " << *R << "\n"; +    }); + +    Res = Q; + +    // Do not record the last subscript corresponding to the size of elements in +    // the array. +    if (i == Last) { + +      // Bail out if the remainder is too complex. +      if (isa<SCEVAddRecExpr>(R)) { +        Subscripts.clear(); +        Sizes.clear(); +        return; +      } + +      continue; +    } + +    // Record the access function for the current subscript. +    Subscripts.push_back(R); +  } + +  // Also push in last position the remainder of the last division: it will be +  // the access function of the innermost dimension. +  Subscripts.push_back(Res); + +  std::reverse(Subscripts.begin(), Subscripts.end()); + +  LLVM_DEBUG({ +    dbgs() << "Subscripts:\n"; +    for (const SCEV *S : Subscripts) +      dbgs() << *S << "\n"; +  }); +} + +/// Splits the SCEV into two vectors of SCEVs representing the subscripts and +/// sizes of an array access. Returns the remainder of the delinearization that +/// is the offset start of the array.  The SCEV->delinearize algorithm computes +/// the multiples of SCEV coefficients: that is a pattern matching of sub +/// expressions in the stride and base of a SCEV corresponding to the +/// computation of a GCD (greatest common divisor) of base and stride.  When +/// SCEV->delinearize fails, it returns the SCEV unchanged. 
+/// +/// For example: when analyzing the memory access A[i][j][k] in this loop nest +/// +///  void foo(long n, long m, long o, double A[n][m][o]) { +/// +///    for (long i = 0; i < n; i++) +///      for (long j = 0; j < m; j++) +///        for (long k = 0; k < o; k++) +///          A[i][j][k] = 1.0; +///  } +/// +/// the delinearization input is the following AddRec SCEV: +/// +///  AddRec: {{{%A,+,(8 * %m * %o)}<%for.i>,+,(8 * %o)}<%for.j>,+,8}<%for.k> +/// +/// From this SCEV, we are able to say that the base offset of the access is %A +/// because it appears as an offset that does not divide any of the strides in +/// the loops: +/// +///  CHECK: Base offset: %A +/// +/// and then SCEV->delinearize determines the size of some of the dimensions of +/// the array as these are the multiples by which the strides are happening: +/// +///  CHECK: ArrayDecl[UnknownSize][%m][%o] with elements of sizeof(double) bytes. +/// +/// Note that the outermost dimension remains of UnknownSize because there are +/// no strides that would help identifying the size of the last dimension: when +/// the array has been statically allocated, one could compute the size of that +/// dimension by dividing the overall size of the array by the size of the known +/// dimensions: %m * %o * 8. +/// +/// Finally delinearize provides the access functions for the array reference +/// that does correspond to A[i][j][k] of the above C testcase: +/// +///  CHECK: ArrayRef[{0,+,1}<%for.i>][{0,+,1}<%for.j>][{0,+,1}<%for.k>] +/// +/// The testcases are checking the output of a function pass: +/// DelinearizationPass that walks through all loads and stores of a function +/// asking for the SCEV of the memory access with respect to all enclosing +/// loops, calling SCEV->delinearize on that and printing the results. +void ScalarEvolution::delinearize(const SCEV *Expr, +                                 SmallVectorImpl<const SCEV *> &Subscripts, +                                 SmallVectorImpl<const SCEV *> &Sizes, +                                 const SCEV *ElementSize) { +  // First step: collect parametric terms. +  SmallVector<const SCEV *, 4> Terms; +  collectParametricTerms(Expr, Terms); + +  if (Terms.empty()) +    return; + +  // Second step: find subscript sizes. +  findArrayDimensions(Terms, Sizes, ElementSize); + +  if (Sizes.empty()) +    return; + +  // Third step: compute the access functions for each subscript. +  computeAccessFunctions(Expr, Subscripts, Sizes); + +  if (Subscripts.empty()) +    return; + +  LLVM_DEBUG({ +    dbgs() << "succeeded to delinearize " << *Expr << "\n"; +    dbgs() << "ArrayDecl[UnknownSize]"; +    for (const SCEV *S : Sizes) +      dbgs() << "[" << *S << "]"; + +    dbgs() << "\nArrayRef"; +    for (const SCEV *S : Subscripts) +      dbgs() << "[" << *S << "]"; +    dbgs() << "\n"; +  }); +} + +//===----------------------------------------------------------------------===// +//                   SCEVCallbackVH Class Implementation +//===----------------------------------------------------------------------===// + +void ScalarEvolution::SCEVCallbackVH::deleted() { +  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); +  if (PHINode *PN = dyn_cast<PHINode>(getValPtr())) +    SE->ConstantEvolutionLoopExitValue.erase(PN); +  SE->eraseValueFromMap(getValPtr()); +  // this now dangles! 
+} + +void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { +  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); + +  // Forget all the expressions associated with users of the old value, +  // so that future queries will recompute the expressions using the new +  // value. +  Value *Old = getValPtr(); +  SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end()); +  SmallPtrSet<User *, 8> Visited; +  while (!Worklist.empty()) { +    User *U = Worklist.pop_back_val(); +    // Deleting the Old value will cause this to dangle. Postpone +    // that until everything else is done. +    if (U == Old) +      continue; +    if (!Visited.insert(U).second) +      continue; +    if (PHINode *PN = dyn_cast<PHINode>(U)) +      SE->ConstantEvolutionLoopExitValue.erase(PN); +    SE->eraseValueFromMap(U); +    Worklist.insert(Worklist.end(), U->user_begin(), U->user_end()); +  } +  // Delete the Old value. +  if (PHINode *PN = dyn_cast<PHINode>(Old)) +    SE->ConstantEvolutionLoopExitValue.erase(PN); +  SE->eraseValueFromMap(Old); +  // this now dangles! +} + +ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se) +  : CallbackVH(V), SE(se) {} + +//===----------------------------------------------------------------------===// +//                   ScalarEvolution Class Implementation +//===----------------------------------------------------------------------===// + +ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, +                                 AssumptionCache &AC, DominatorTree &DT, +                                 LoopInfo &LI) +    : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), +      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64), +      LoopDispositions(64), BlockDispositions(64) { +  // To use guards for proving predicates, we need to scan every instruction in +  // relevant basic blocks, and not just terminators.  Doing this is a waste of +  // time if the IR does not actually contain any calls to +  // @llvm.experimental.guard, so do a quick check and remember this beforehand. +  // +  // This pessimizes the case where a pass that preserves ScalarEvolution wants +  // to _add_ guards to the module when there weren't any before, and wants +  // ScalarEvolution to optimize based on those guards.  For now we prefer to be +  // efficient in lieu of being smart in that rather obscure case. 
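+  //
+  // Illustrative IR for the pattern being detected (assumed, abridged):
+  //
+  //   declare void @llvm.experimental.guard(i1, ...)
+  //   ...
+  //   call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]
+  //
+  // HasGuards is set only when such a declaration exists and has at least one
+  // use.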
+ +  auto *GuardDecl = F.getParent()->getFunction( +      Intrinsic::getName(Intrinsic::experimental_guard)); +  HasGuards = GuardDecl && !GuardDecl->use_empty(); +} + +ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) +    : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), +      LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), +      ValueExprMap(std::move(Arg.ValueExprMap)), +      PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), +      PendingPhiRanges(std::move(Arg.PendingPhiRanges)), +      PendingMerges(std::move(Arg.PendingMerges)), +      MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), +      BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), +      PredicatedBackedgeTakenCounts( +          std::move(Arg.PredicatedBackedgeTakenCounts)), +      ConstantEvolutionLoopExitValue( +          std::move(Arg.ConstantEvolutionLoopExitValue)), +      ValuesAtScopes(std::move(Arg.ValuesAtScopes)), +      LoopDispositions(std::move(Arg.LoopDispositions)), +      LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)), +      BlockDispositions(std::move(Arg.BlockDispositions)), +      UnsignedRanges(std::move(Arg.UnsignedRanges)), +      SignedRanges(std::move(Arg.SignedRanges)), +      UniqueSCEVs(std::move(Arg.UniqueSCEVs)), +      UniquePreds(std::move(Arg.UniquePreds)), +      SCEVAllocator(std::move(Arg.SCEVAllocator)), +      LoopUsers(std::move(Arg.LoopUsers)), +      PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)), +      FirstUnknown(Arg.FirstUnknown) { +  Arg.FirstUnknown = nullptr; +} + +ScalarEvolution::~ScalarEvolution() { +  // Iterate through all the SCEVUnknown instances and call their +  // destructors, so that they release their references to their values. +  for (SCEVUnknown *U = FirstUnknown; U;) { +    SCEVUnknown *Tmp = U; +    U = U->Next; +    Tmp->~SCEVUnknown(); +  } +  FirstUnknown = nullptr; + +  ExprValueMap.clear(); +  ValueExprMap.clear(); +  HasRecMap.clear(); + +  // Free any extra memory created for ExitNotTakenInfo in the unlikely event +  // that a loop had multiple computable exits. +  for (auto &BTCI : BackedgeTakenCounts) +    BTCI.second.clear(); +  for (auto &BTCI : PredicatedBackedgeTakenCounts) +    BTCI.second.clear(); + +  assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); +  assert(PendingPhiRanges.empty() && "getRangeRef garbage"); +  assert(PendingMerges.empty() && "isImpliedViaMerge garbage"); +  assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); +  assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!"); +} + +bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { +  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L)); +} + +static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, +                          const Loop *L) { +  // Print all inner loops first +  for (Loop *I : *L) +    PrintLoopInfo(OS, SE, I); + +  OS << "Loop "; +  L->getHeader()->printAsOperand(OS, /*PrintType=*/false); +  OS << ": "; + +  SmallVector<BasicBlock *, 8> ExitBlocks; +  L->getExitBlocks(ExitBlocks); +  if (ExitBlocks.size() != 1) +    OS << "<multiple exits> "; + +  if (SE->hasLoopInvariantBackedgeTakenCount(L)) { +    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L); +  } else { +    OS << "Unpredictable backedge-taken count. 
"; +  } + +  OS << "\n" +        "Loop "; +  L->getHeader()->printAsOperand(OS, /*PrintType=*/false); +  OS << ": "; + +  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) { +    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); +    if (SE->isBackedgeTakenCountMaxOrZero(L)) +      OS << ", actual taken count either this or zero."; +  } else { +    OS << "Unpredictable max backedge-taken count. "; +  } + +  OS << "\n" +        "Loop "; +  L->getHeader()->printAsOperand(OS, /*PrintType=*/false); +  OS << ": "; + +  SCEVUnionPredicate Pred; +  auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred); +  if (!isa<SCEVCouldNotCompute>(PBT)) { +    OS << "Predicated backedge-taken count is " << *PBT << "\n"; +    OS << " Predicates:\n"; +    Pred.print(OS, 4); +  } else { +    OS << "Unpredictable predicated backedge-taken count. "; +  } +  OS << "\n"; + +  if (SE->hasLoopInvariantBackedgeTakenCount(L)) { +    OS << "Loop "; +    L->getHeader()->printAsOperand(OS, /*PrintType=*/false); +    OS << ": "; +    OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n"; +  } +} + +static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { +  switch (LD) { +  case ScalarEvolution::LoopVariant: +    return "Variant"; +  case ScalarEvolution::LoopInvariant: +    return "Invariant"; +  case ScalarEvolution::LoopComputable: +    return "Computable"; +  } +  llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!"); +} + +void ScalarEvolution::print(raw_ostream &OS) const { +  // ScalarEvolution's implementation of the print method is to print +  // out SCEV values of all instructions that are interesting. Doing +  // this potentially causes it to create new SCEV objects though, +  // which technically conflicts with the const qualifier. This isn't +  // observable from outside the class though, so casting away the +  // const isn't dangerous. 
+  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); + +  OS << "Classifying expressions for: "; +  F.printAsOperand(OS, /*PrintType=*/false); +  OS << "\n"; +  for (Instruction &I : instructions(F)) +    if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) { +      OS << I << '\n'; +      OS << "  -->  "; +      const SCEV *SV = SE.getSCEV(&I); +      SV->print(OS); +      if (!isa<SCEVCouldNotCompute>(SV)) { +        OS << " U: "; +        SE.getUnsignedRange(SV).print(OS); +        OS << " S: "; +        SE.getSignedRange(SV).print(OS); +      } + +      const Loop *L = LI.getLoopFor(I.getParent()); + +      const SCEV *AtUse = SE.getSCEVAtScope(SV, L); +      if (AtUse != SV) { +        OS << "  -->  "; +        AtUse->print(OS); +        if (!isa<SCEVCouldNotCompute>(AtUse)) { +          OS << " U: "; +          SE.getUnsignedRange(AtUse).print(OS); +          OS << " S: "; +          SE.getSignedRange(AtUse).print(OS); +        } +      } + +      if (L) { +        OS << "\t\t" "Exits: "; +        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop()); +        if (!SE.isLoopInvariant(ExitValue, L)) { +          OS << "<<Unknown>>"; +        } else { +          OS << *ExitValue; +        } + +        bool First = true; +        for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) { +          if (First) { +            OS << "\t\t" "LoopDispositions: { "; +            First = false; +          } else { +            OS << ", "; +          } + +          Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false); +          OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter)); +        } + +        for (auto *InnerL : depth_first(L)) { +          if (InnerL == L) +            continue; +          if (First) { +            OS << "\t\t" "LoopDispositions: { "; +            First = false; +          } else { +            OS << ", "; +          } + +          InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false); +          OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL)); +        } + +        OS << " }"; +      } + +      OS << "\n"; +    } + +  OS << "Determining loop execution counts for: "; +  F.printAsOperand(OS, /*PrintType=*/false); +  OS << "\n"; +  for (Loop *I : LI) +    PrintLoopInfo(OS, &SE, I); +} + +ScalarEvolution::LoopDisposition +ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { +  auto &Values = LoopDispositions[S]; +  for (auto &V : Values) { +    if (V.getPointer() == L) +      return V.getInt(); +  } +  Values.emplace_back(L, LoopVariant); +  LoopDisposition D = computeLoopDisposition(S, L); +  auto &Values2 = LoopDispositions[S]; +  for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { +    if (V.getPointer() == L) { +      V.setInt(D); +      break; +    } +  } +  return D; +} + +ScalarEvolution::LoopDisposition +ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { +  switch (static_cast<SCEVTypes>(S->getSCEVType())) { +  case scConstant: +    return LoopInvariant; +  case scTruncate: +  case scZeroExtend: +  case scSignExtend: +    return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L); +  case scAddRecExpr: { +    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); + +    // If L is the addrec's loop, it's computable. +    if (AR->getLoop() == L) +      return LoopComputable; + +    // Add recurrences are never invariant in the function-body (null loop). 
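+    //
+    // Illustrative example (the loop nest is assumed): for {0,+,1}<%inner>
+    // with %inner nested inside %outer, the disposition w.r.t. %inner is
+    // LoopComputable and w.r.t. %outer it is LoopVariant, while
+    // {0,+,1}<%outer> is LoopInvariant w.r.t. %inner because the outer
+    // induction variable does not change within a single %inner iteration.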
+    if (!L) +      return LoopVariant; + +    // Everything that is not defined at loop entry is variant. +    if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader())) +      return LoopVariant; +    assert(!L->contains(AR->getLoop()) && "Containing loop's header does not" +           " dominate the contained loop's header?"); + +    // This recurrence is invariant w.r.t. L if AR's loop contains L. +    if (AR->getLoop()->contains(L)) +      return LoopInvariant; + +    // This recurrence is variant w.r.t. L if any of its operands +    // are variant. +    for (auto *Op : AR->operands()) +      if (!isLoopInvariant(Op, L)) +        return LoopVariant; + +    // Otherwise it's loop-invariant. +    return LoopInvariant; +  } +  case scAddExpr: +  case scMulExpr: +  case scUMaxExpr: +  case scSMaxExpr: { +    bool HasVarying = false; +    for (auto *Op : cast<SCEVNAryExpr>(S)->operands()) { +      LoopDisposition D = getLoopDisposition(Op, L); +      if (D == LoopVariant) +        return LoopVariant; +      if (D == LoopComputable) +        HasVarying = true; +    } +    return HasVarying ? LoopComputable : LoopInvariant; +  } +  case scUDivExpr: { +    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); +    LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L); +    if (LD == LoopVariant) +      return LoopVariant; +    LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L); +    if (RD == LoopVariant) +      return LoopVariant; +    return (LD == LoopInvariant && RD == LoopInvariant) ? +           LoopInvariant : LoopComputable; +  } +  case scUnknown: +    // All non-instruction values are loop invariant.  All instructions are loop +    // invariant if they are not contained in the specified loop. +    // Instructions are never considered invariant in the function body +    // (null loop) because they are defined within the "loop". +    if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) +      return (L && !L->contains(I)) ? 
LoopInvariant : LoopVariant; +    return LoopInvariant; +  case scCouldNotCompute: +    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); +  } +  llvm_unreachable("Unknown SCEV kind!"); +} + +bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) { +  return getLoopDisposition(S, L) == LoopInvariant; +} + +bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) { +  return getLoopDisposition(S, L) == LoopComputable; +} + +ScalarEvolution::BlockDisposition +ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { +  auto &Values = BlockDispositions[S]; +  for (auto &V : Values) { +    if (V.getPointer() == BB) +      return V.getInt(); +  } +  Values.emplace_back(BB, DoesNotDominateBlock); +  BlockDisposition D = computeBlockDisposition(S, BB); +  auto &Values2 = BlockDispositions[S]; +  for (auto &V : make_range(Values2.rbegin(), Values2.rend())) { +    if (V.getPointer() == BB) { +      V.setInt(D); +      break; +    } +  } +  return D; +} + +ScalarEvolution::BlockDisposition +ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { +  switch (static_cast<SCEVTypes>(S->getSCEVType())) { +  case scConstant: +    return ProperlyDominatesBlock; +  case scTruncate: +  case scZeroExtend: +  case scSignExtend: +    return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB); +  case scAddRecExpr: { +    // This uses a "dominates" query instead of "properly dominates" query +    // to test for proper dominance too, because the instruction which +    // produces the addrec's value is a PHI, and a PHI effectively properly +    // dominates its entire containing block. +    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S); +    if (!DT.dominates(AR->getLoop()->getHeader(), BB)) +      return DoesNotDominateBlock; + +    // Fall through into SCEVNAryExpr handling. +    LLVM_FALLTHROUGH; +  } +  case scAddExpr: +  case scMulExpr: +  case scUMaxExpr: +  case scSMaxExpr: { +    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S); +    bool Proper = true; +    for (const SCEV *NAryOp : NAry->operands()) { +      BlockDisposition D = getBlockDisposition(NAryOp, BB); +      if (D == DoesNotDominateBlock) +        return DoesNotDominateBlock; +      if (D == DominatesBlock) +        Proper = false; +    } +    return Proper ? ProperlyDominatesBlock : DominatesBlock; +  } +  case scUDivExpr: { +    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S); +    const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS(); +    BlockDisposition LD = getBlockDisposition(LHS, BB); +    if (LD == DoesNotDominateBlock) +      return DoesNotDominateBlock; +    BlockDisposition RD = getBlockDisposition(RHS, BB); +    if (RD == DoesNotDominateBlock) +      return DoesNotDominateBlock; +    return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ? 
+      ProperlyDominatesBlock : DominatesBlock; +  } +  case scUnknown: +    if (Instruction *I = +          dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) { +      if (I->getParent() == BB) +        return DominatesBlock; +      if (DT.properlyDominates(I->getParent(), BB)) +        return ProperlyDominatesBlock; +      return DoesNotDominateBlock; +    } +    return ProperlyDominatesBlock; +  case scCouldNotCompute: +    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); +  } +  llvm_unreachable("Unknown SCEV kind!"); +} + +bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) { +  return getBlockDisposition(S, BB) >= DominatesBlock; +} + +bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) { +  return getBlockDisposition(S, BB) == ProperlyDominatesBlock; +} + +bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { +  return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; }); +} + +bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const { +  auto IsS = [&](const SCEV *X) { return S == X; }; +  auto ContainsS = [&](const SCEV *X) { +    return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS); +  }; +  return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken); +} + +void +ScalarEvolution::forgetMemoizedResults(const SCEV *S) { +  ValuesAtScopes.erase(S); +  LoopDispositions.erase(S); +  BlockDispositions.erase(S); +  UnsignedRanges.erase(S); +  SignedRanges.erase(S); +  ExprValueMap.erase(S); +  HasRecMap.erase(S); +  MinTrailingZerosCache.erase(S); + +  for (auto I = PredicatedSCEVRewrites.begin(); +       I != PredicatedSCEVRewrites.end();) { +    std::pair<const SCEV *, const Loop *> Entry = I->first; +    if (Entry.first == S) +      PredicatedSCEVRewrites.erase(I++); +    else +      ++I; +  } + +  auto RemoveSCEVFromBackedgeMap = +      [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { +        for (auto I = Map.begin(), E = Map.end(); I != E;) { +          BackedgeTakenInfo &BEInfo = I->second; +          if (BEInfo.hasOperand(S, this)) { +            BEInfo.clear(); +            Map.erase(I++); +          } else +            ++I; +        } +      }; + +  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); +  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); +} + +void +ScalarEvolution::getUsedLoops(const SCEV *S, +                              SmallPtrSetImpl<const Loop *> &LoopsUsed) { +  struct FindUsedLoops { +    FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed) +        : LoopsUsed(LoopsUsed) {} +    SmallPtrSetImpl<const Loop *> &LoopsUsed; +    bool follow(const SCEV *S) { +      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) +        LoopsUsed.insert(AR->getLoop()); +      return true; +    } + +    bool isDone() const { return false; } +  }; + +  FindUsedLoops F(LoopsUsed); +  SCEVTraversal<FindUsedLoops>(F).visitAll(S); +} + +void ScalarEvolution::addToLoopUseLists(const SCEV *S) { +  SmallPtrSet<const Loop *, 8> LoopsUsed; +  getUsedLoops(S, LoopsUsed); +  for (auto *L : LoopsUsed) +    LoopUsers[L].push_back(S); +} + +void ScalarEvolution::verify() const { +  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this); +  ScalarEvolution SE2(F, TLI, AC, DT, LI); + +  SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end()); + +  // Map's SCEV expressions from one ScalarEvolution "universe" to another. 
+  struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> { +    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {} + +    const SCEV *visitConstant(const SCEVConstant *Constant) { +      return SE.getConstant(Constant->getAPInt()); +    } + +    const SCEV *visitUnknown(const SCEVUnknown *Expr) { +      return SE.getUnknown(Expr->getValue()); +    } + +    const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { +      return SE.getCouldNotCompute(); +    } +  }; + +  SCEVMapper SCM(SE2); + +  while (!LoopStack.empty()) { +    auto *L = LoopStack.pop_back_val(); +    LoopStack.insert(LoopStack.end(), L->begin(), L->end()); + +    auto *CurBECount = SCM.visit( +        const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L)); +    auto *NewBECount = SE2.getBackedgeTakenCount(L); + +    if (CurBECount == SE2.getCouldNotCompute() || +        NewBECount == SE2.getCouldNotCompute()) { +      // NB! This situation is legal, but is very suspicious -- whatever pass +      // change the loop to make a trip count go from could not compute to +      // computable or vice-versa *should have* invalidated SCEV.  However, we +      // choose not to assert here (for now) since we don't want false +      // positives. +      continue; +    } + +    if (containsUndefs(CurBECount) || containsUndefs(NewBECount)) { +      // SCEV treats "undef" as an unknown but consistent value (i.e. it does +      // not propagate undef aggressively).  This means we can (and do) fail +      // verification in cases where a transform makes the trip count of a loop +      // go from "undef" to "undef+1" (say).  The transform is fine, since in +      // both cases the loop iterates "undef" times, but SCEV thinks we +      // increased the trip count of the loop by 1 incorrectly. +      continue; +    } + +    if (SE.getTypeSizeInBits(CurBECount->getType()) > +        SE.getTypeSizeInBits(NewBECount->getType())) +      NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType()); +    else if (SE.getTypeSizeInBits(CurBECount->getType()) < +             SE.getTypeSizeInBits(NewBECount->getType())) +      CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType()); + +    auto *ConstantDelta = +        dyn_cast<SCEVConstant>(SE2.getMinusSCEV(CurBECount, NewBECount)); + +    if (ConstantDelta && ConstantDelta->getAPInt() != 0) { +      dbgs() << "Trip Count Changed!\n"; +      dbgs() << "Old: " << *CurBECount << "\n"; +      dbgs() << "New: " << *NewBECount << "\n"; +      dbgs() << "Delta: " << *ConstantDelta << "\n"; +      std::abort(); +    } +  } +} + +bool ScalarEvolution::invalidate( +    Function &F, const PreservedAnalyses &PA, +    FunctionAnalysisManager::Invalidator &Inv) { +  // Invalidate the ScalarEvolution object whenever it isn't preserved or one +  // of its dependencies is invalidated. 
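+  //
+  // For example (illustrative), a transformation pass that wants to keep this
+  // ScalarEvolution alive returns a PreservedAnalyses set on which it has
+  // called PA.preserve<ScalarEvolutionAnalysis>() and which also preserves
+  // AssumptionAnalysis, DominatorTreeAnalysis and LoopAnalysis.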
+  auto PAC = PA.getChecker<ScalarEvolutionAnalysis>(); +  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) || +         Inv.invalidate<AssumptionAnalysis>(F, PA) || +         Inv.invalidate<DominatorTreeAnalysis>(F, PA) || +         Inv.invalidate<LoopAnalysis>(F, PA); +} + +AnalysisKey ScalarEvolutionAnalysis::Key; + +ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, +                                             FunctionAnalysisManager &AM) { +  return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), +                         AM.getResult<AssumptionAnalysis>(F), +                         AM.getResult<DominatorTreeAnalysis>(F), +                         AM.getResult<LoopAnalysis>(F)); +} + +PreservedAnalyses +ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { +  AM.getResult<ScalarEvolutionAnalysis>(F).print(OS); +  return PreservedAnalyses::all(); +} + +INITIALIZE_PASS_BEGIN(ScalarEvolutionWrapperPass, "scalar-evolution", +                      "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution", +                    "Scalar Evolution Analysis", false, true) + +char ScalarEvolutionWrapperPass::ID = 0; + +ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) { +  initializeScalarEvolutionWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) { +  SE.reset(new ScalarEvolution( +      F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), +      getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), +      getAnalysis<DominatorTreeWrapperPass>().getDomTree(), +      getAnalysis<LoopInfoWrapperPass>().getLoopInfo())); +  return false; +} + +void ScalarEvolutionWrapperPass::releaseMemory() { SE.reset(); } + +void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const { +  SE->print(OS); +} + +void ScalarEvolutionWrapperPass::verifyAnalysis() const { +  if (!VerifySCEV) +    return; + +  SE->verify(); +} + +void ScalarEvolutionWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequiredTransitive<AssumptionCacheTracker>(); +  AU.addRequiredTransitive<LoopInfoWrapperPass>(); +  AU.addRequiredTransitive<DominatorTreeWrapperPass>(); +  AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); +} + +const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, +                                                        const SCEV *RHS) { +  FoldingSetNodeID ID; +  assert(LHS->getType() == RHS->getType() && +         "Type mismatch between LHS and RHS"); +  // Unique this node based on the arguments +  ID.AddInteger(SCEVPredicate::P_Equal); +  ID.AddPointer(LHS); +  ID.AddPointer(RHS); +  void *IP = nullptr; +  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) +    return S; +  SCEVEqualPredicate *Eq = new (SCEVAllocator) +      SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); +  UniquePreds.InsertNode(Eq, IP); +  return Eq; +} + +const SCEVPredicate *ScalarEvolution::getWrapPredicate( +    const SCEVAddRecExpr *AR, +    SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { +  FoldingSetNodeID ID; +  // Unique this node based on the arguments +  ID.AddInteger(SCEVPredicate::P_Wrap); +  
ID.AddPointer(AR); +  ID.AddInteger(AddedFlags); +  void *IP = nullptr; +  if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) +    return S; +  auto *OF = new (SCEVAllocator) +      SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); +  UniquePreds.InsertNode(OF, IP); +  return OF; +} + +namespace { + +class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { +public: + +  /// Rewrites \p S in the context of a loop L and the SCEV predication +  /// infrastructure. +  /// +  /// If \p Pred is non-null, the SCEV expression is rewritten to respect the +  /// equivalences present in \p Pred. +  /// +  /// If \p NewPreds is non-null, rewrite is free to add further predicates to +  /// \p NewPreds such that the result will be an AddRecExpr. +  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, +                             SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, +                             SCEVUnionPredicate *Pred) { +    SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); +    return Rewriter.visit(S); +  } + +  const SCEV *visitUnknown(const SCEVUnknown *Expr) { +    if (Pred) { +      auto ExprPreds = Pred->getPredicatesForExpr(Expr); +      for (auto *Pred : ExprPreds) +        if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) +          if (IPred->getLHS() == Expr) +            return IPred->getRHS(); +    } +    return convertToAddRecWithPreds(Expr); +  } + +  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { +    const SCEV *Operand = visit(Expr->getOperand()); +    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); +    if (AR && AR->getLoop() == L && AR->isAffine()) { +      // This couldn't be folded because the operand didn't have the nuw +      // flag. Add the nusw flag as an assumption that we could make. +      const SCEV *Step = AR->getStepRecurrence(SE); +      Type *Ty = Expr->getType(); +      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) +        return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), +                                SE.getSignExtendExpr(Step, Ty), L, +                                AR->getNoWrapFlags()); +    } +    return SE.getZeroExtendExpr(Operand, Expr->getType()); +  } + +  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { +    const SCEV *Operand = visit(Expr->getOperand()); +    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); +    if (AR && AR->getLoop() == L && AR->isAffine()) { +      // This couldn't be folded because the operand didn't have the nsw +      // flag. Add the nssw flag as an assumption that we could make. +      const SCEV *Step = AR->getStepRecurrence(SE); +      Type *Ty = Expr->getType(); +      if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) +        return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), +                                SE.getSignExtendExpr(Step, Ty), L, +                                AR->getNoWrapFlags()); +    } +    return SE.getSignExtendExpr(Operand, Expr->getType()); +  } + +private: +  explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, +                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, +                        SCEVUnionPredicate *Pred) +      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} + +  bool addOverflowAssumption(const SCEVPredicate *P) { +    if (!NewPreds) { +      // Check if we've already made this assumption. 
+      return Pred && Pred->implies(P); +    } +    NewPreds->insert(P); +    return true; +  } + +  bool addOverflowAssumption(const SCEVAddRecExpr *AR, +                             SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { +    auto *A = SE.getWrapPredicate(AR, AddedFlags); +    return addOverflowAssumption(A); +  } + +  // If \p Expr represents a PHINode, we try to see if it can be represented +  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible +  // to add this predicate as a runtime overflow check, we return the AddRec. +  // If \p Expr does not meet these conditions (is not a PHI node, or we +  // couldn't create an AddRec for it, or couldn't add the predicate), we just +  // return \p Expr. +  const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) { +    if (!isa<PHINode>(Expr->getValue())) +      return Expr; +    Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> +    PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr); +    if (!PredicatedRewrite) +      return Expr; +    for (auto *P : PredicatedRewrite->second){ +      // Wrap predicates from outer loops are not supported. +      if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) { +        auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr()); +        if (L != AR->getLoop()) +          return Expr; +      } +      if (!addOverflowAssumption(P)) +        return Expr; +    } +    return PredicatedRewrite->first; +  } + +  SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; +  SCEVUnionPredicate *Pred; +  const Loop *L; +}; + +} // end anonymous namespace + +const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, +                                                   SCEVUnionPredicate &Preds) { +  return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); +} + +const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( +    const SCEV *S, const Loop *L, +    SmallPtrSetImpl<const SCEVPredicate *> &Preds) { +  SmallPtrSet<const SCEVPredicate *, 4> TransformPreds; +  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); +  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S); + +  if (!AddRec) +    return nullptr; + +  // Since the transformation was successful, we can now transfer the SCEV +  // predicates. 
+  for (auto *P : TransformPreds) +    Preds.insert(P); + +  return AddRec; +} + +/// SCEV predicates +SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, +                             SCEVPredicateKind Kind) +    : FastID(ID), Kind(Kind) {} + +SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, +                                       const SCEV *LHS, const SCEV *RHS) +    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) { +  assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match"); +  assert(LHS != RHS && "LHS and RHS are the same SCEV"); +} + +bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { +  const auto *Op = dyn_cast<SCEVEqualPredicate>(N); + +  if (!Op) +    return false; + +  return Op->LHS == LHS && Op->RHS == RHS; +} + +bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } + +const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } + +void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { +  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; +} + +SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, +                                     const SCEVAddRecExpr *AR, +                                     IncrementWrapFlags Flags) +    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {} + +const SCEV *SCEVWrapPredicate::getExpr() const { return AR; } + +bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { +  const auto *Op = dyn_cast<SCEVWrapPredicate>(N); + +  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; +} + +bool SCEVWrapPredicate::isAlwaysTrue() const { +  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); +  IncrementWrapFlags IFlags = Flags; + +  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) +    IFlags = clearFlags(IFlags, IncrementNSSW); + +  return IFlags == IncrementAnyWrap; +} + +void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const { +  OS.indent(Depth) << *getExpr() << " Added Flags: "; +  if (SCEVWrapPredicate::IncrementNUSW & getFlags()) +    OS << "<nusw>"; +  if (SCEVWrapPredicate::IncrementNSSW & getFlags()) +    OS << "<nssw>"; +  OS << "\n"; +} + +SCEVWrapPredicate::IncrementWrapFlags +SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR, +                                   ScalarEvolution &SE) { +  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap; +  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); + +  // We can safely transfer the NSW flag as NSSW. +  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) +    ImpliedFlags = IncrementNSSW; + +  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { +    // If the increment is positive, the SCEV NUW flag will also imply the +    // WrapPredicate NUSW flag. +    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) +      if (Step->getValue()->getValue().isNonNegative()) +        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); +  } + +  return ImpliedFlags; +} + +/// Union predicates don't get cached so create a dummy set ID for it. 
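The flag bookkeeping in SCEVWrapPredicate::getImpliedFlags and isAlwaysTrue above reduces to a small bit-set lattice: NSW always implies NSSW, NUW implies NUSW only when the constant step is known non-negative, and a wrap predicate is trivially true once every flag it would add is already implied. The fragment below is only an illustration, using its own miniature enums rather than the LLVM ones, and replays that decision table.

#include <cstdio>

// A stripped-down model of the wrap-flag sets used above. "Static" flags are
// what ScalarEvolution already proved; "increment" flags are what a runtime
// predicate would have to guarantee.
enum StaticFlags    { FlagAnyWrap = 0, FlagNUW = 1 << 0, FlagNSW = 1 << 1 };
enum IncrementFlags { IncrementAnyWrap = 0, IncrementNUSW = 1 << 0,
                      IncrementNSSW = 1 << 1 };

// Mirrors the shape of getImpliedFlags: NSW implies NSSW unconditionally, and
// NUW implies NUSW only when the (constant) step is non-negative.
unsigned impliedIncrementFlags(unsigned Static, bool StepNonNegative) {
  unsigned Implied = IncrementAnyWrap;
  if (Static & FlagNSW)
    Implied |= IncrementNSSW;
  if ((Static & FlagNUW) && StepNonNegative)
    Implied |= IncrementNUSW;
  return Implied;
}

int main() {
  // A predicate is trivially true when every flag it would add is already
  // implied by the static flags (compare SCEVWrapPredicate::isAlwaysTrue).
  unsigned Wanted = IncrementNUSW | IncrementNSSW;
  unsigned Implied = impliedIncrementFlags(FlagNUW | FlagNSW,
                                           /*StepNonNegative=*/true);
  unsigned StillNeeded = Wanted & ~Implied;
  std::printf("flags still needing a runtime check: %u\n", StillNeeded); // 0
  return 0;
}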
+SCEVUnionPredicate::SCEVUnionPredicate() +    : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} + +bool SCEVUnionPredicate::isAlwaysTrue() const { +  return all_of(Preds, +                [](const SCEVPredicate *I) { return I->isAlwaysTrue(); }); +} + +ArrayRef<const SCEVPredicate *> +SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { +  auto I = SCEVToPreds.find(Expr); +  if (I == SCEVToPreds.end()) +    return ArrayRef<const SCEVPredicate *>(); +  return I->second; +} + +bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { +  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) +    return all_of(Set->Preds, +                  [this](const SCEVPredicate *I) { return this->implies(I); }); + +  auto ScevPredsIt = SCEVToPreds.find(N->getExpr()); +  if (ScevPredsIt == SCEVToPreds.end()) +    return false; +  auto &SCEVPreds = ScevPredsIt->second; + +  return any_of(SCEVPreds, +                [N](const SCEVPredicate *I) { return I->implies(N); }); +} + +const SCEV *SCEVUnionPredicate::getExpr() const { return nullptr; } + +void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { +  for (auto Pred : Preds) +    Pred->print(OS, Depth); +} + +void SCEVUnionPredicate::add(const SCEVPredicate *N) { +  if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) { +    for (auto Pred : Set->Preds) +      add(Pred); +    return; +  } + +  if (implies(N)) +    return; + +  const SCEV *Key = N->getExpr(); +  assert(Key && "Only SCEVUnionPredicate doesn't have an " +                " associated expression!"); + +  SCEVToPreds[Key].push_back(N); +  Preds.push_back(N); +} + +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, +                                                     Loop &L) +    : SE(SE), L(L) {} + +const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { +  const SCEV *Expr = SE.getSCEV(V); +  RewriteEntry &Entry = RewriteMap[Expr]; + +  // If we already have an entry and the version matches, return it. +  if (Entry.second && Generation == Entry.first) +    return Entry.second; + +  // We found an entry but it's stale. Rewrite the stale entry +  // according to the current predicate. +  if (Entry.second) +    Expr = Entry.second; + +  const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); +  Entry = {Generation, NewSCEV}; + +  return NewSCEV; +} + +const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { +  if (!BackedgeCount) { +    SCEVUnionPredicate BackedgePred; +    BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred); +    addPredicate(BackedgePred); +  } +  return BackedgeCount; +} + +void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { +  if (Preds.implies(&Pred)) +    return; +  Preds.add(&Pred); +  updateGeneration(); +} + +const SCEVUnionPredicate &PredicatedScalarEvolution::getUnionPredicate() const { +  return Preds; +} + +void PredicatedScalarEvolution::updateGeneration() { +  // If the generation number wrapped recompute everything. +  if (++Generation == 0) { +    for (auto &II : RewriteMap) { +      const SCEV *Rewritten = II.second.second; +      II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; +    } +  } +} + +void PredicatedScalarEvolution::setNoOverflow( +    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { +  const SCEV *Expr = getSCEV(V); +  const auto *AR = cast<SCEVAddRecExpr>(Expr); + +  auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); + +  // Clear the statically implied flags. 
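PredicatedScalarEvolution::getSCEV above caches each rewritten expression together with the generation in which it was computed, and addPredicate bumps the generation so stale entries are re-rewritten on the next lookup (updateGeneration also refreshes everything eagerly if the counter ever wraps). The toy cache below is an illustration only, with a string transformation standing in for SCEV rewriting, and shows the same versioned-memoization pattern.

#include <cstdio>
#include <map>
#include <string>

// Stand-in for "rewrite Expr under the current set of predicates".
static std::string rewrite(const std::string &Expr, unsigned Gen) {
  return Expr + "@gen" + std::to_string(Gen);
}

struct VersionedCache {
  unsigned Generation = 0;
  // Value is (generation it was computed in, cached result).
  std::map<std::string, std::pair<unsigned, std::string>> Map;

  const std::string &get(const std::string &Expr) {
    auto &Entry = Map[Expr];
    // Fresh entry for the current generation: reuse it.
    if (!Entry.second.empty() && Entry.first == Generation)
      return Entry.second;
    // Missing or stale: recompute and stamp with the current generation.
    Entry = {Generation, rewrite(Expr, Generation)};
    return Entry.second;
  }

  // Called whenever a new predicate is added; stale entries are then lazily
  // refreshed by get().
  void addPredicate() { ++Generation; }
};

int main() {
  VersionedCache C;
  std::printf("%s\n", C.get("ptr").c_str());  // ptr@gen0
  C.addPredicate();
  std::printf("%s\n", C.get("ptr").c_str());  // ptr@gen1 (stale entry refreshed)
  return 0;
}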
+  Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); +  addPredicate(*SE.getWrapPredicate(AR, Flags)); + +  auto II = FlagsMap.insert({V, Flags}); +  if (!II.second) +    II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); +} + +bool PredicatedScalarEvolution::hasNoOverflow( +    Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { +  const SCEV *Expr = getSCEV(V); +  const auto *AR = cast<SCEVAddRecExpr>(Expr); + +  Flags = SCEVWrapPredicate::clearFlags( +      Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); + +  auto II = FlagsMap.find(V); + +  if (II != FlagsMap.end()) +    Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); + +  return Flags == SCEVWrapPredicate::IncrementAnyWrap; +} + +const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { +  const SCEV *Expr = this->getSCEV(V); +  SmallPtrSet<const SCEVPredicate *, 4> NewPreds; +  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); + +  if (!New) +    return nullptr; + +  for (auto *P : NewPreds) +    Preds.add(P); + +  updateGeneration(); +  RewriteMap[SE.getSCEV(V)] = {Generation, New}; +  return New; +} + +PredicatedScalarEvolution::PredicatedScalarEvolution( +    const PredicatedScalarEvolution &Init) +    : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), +      Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { +  for (const auto &I : Init.FlagsMap) +    FlagsMap.insert(I); +} + +void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { +  // For each block. +  for (auto *BB : L.getBlocks()) +    for (auto &I : *BB) { +      if (!SE.isSCEVable(I.getType())) +        continue; + +      auto *Expr = SE.getSCEV(&I); +      auto II = RewriteMap.find(Expr); + +      if (II == RewriteMap.end()) +        continue; + +      // Don't print things that are not interesting. +      if (II->second.second == Expr) +        continue; + +      OS.indent(Depth) << "[PSE]" << I << ":\n"; +      OS.indent(Depth + 2) << *Expr << "\n"; +      OS.indent(Depth + 2) << "--> " << *II->second.second << "\n"; +    } +} + +// Match the mathematical pattern A - (A / B) * B, where A and B can be +// arbitrary expressions. +// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is +// 4, A / B becomes X / 8). +bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, +                                const SCEV *&RHS) { +  const auto *Add = dyn_cast<SCEVAddExpr>(Expr); +  if (Add == nullptr || Add->getNumOperands() != 2) +    return false; + +  const SCEV *A = Add->getOperand(1); +  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0)); + +  if (Mul == nullptr) +    return false; + +  const auto MatchURemWithDivisor = [&](const SCEV *B) { +    // (SomeExpr + (-(SomeExpr / B) * B)). +    if (Expr == getURemExpr(A, B)) { +      LHS = A; +      RHS = B; +      return true; +    } +    return false; +  }; + +  // (SomeExpr + (-1 * (SomeExpr / B) * B)). +  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0))) +    return MatchURemWithDivisor(Mul->getOperand(1)) || +           MatchURemWithDivisor(Mul->getOperand(2)); + +  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)). 
+  if (Mul->getNumOperands() == 2) +    return MatchURemWithDivisor(Mul->getOperand(1)) || +           MatchURemWithDivisor(Mul->getOperand(0)) || +           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) || +           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0))); +  return false; +} diff --git a/contrib/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp new file mode 100644 index 000000000000..7bea994121c8 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp @@ -0,0 +1,143 @@ +//===- ScalarEvolutionAliasAnalysis.cpp - SCEV-based Alias Analysis -------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the ScalarEvolutionAliasAnalysis pass, which implements a +// simple alias analysis implemented in terms of ScalarEvolution queries. +// +// This differs from traditional loop dependence analysis in that it tests +// for dependencies within a single iteration of a loop, rather than +// dependencies between different iterations. +// +// ScalarEvolution has a more complete understanding of pointer arithmetic +// than BasicAliasAnalysis' collection of ad-hoc analyses. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +using namespace llvm; + +AliasResult SCEVAAResult::alias(const MemoryLocation &LocA, +                                const MemoryLocation &LocB) { +  // If either of the memory references is empty, it doesn't matter what the +  // pointer values are. This allows the code below to ignore this special +  // case. +  if (LocA.Size == 0 || LocB.Size == 0) +    return NoAlias; + +  // This is SCEVAAResult. Get the SCEVs! +  const SCEV *AS = SE.getSCEV(const_cast<Value *>(LocA.Ptr)); +  const SCEV *BS = SE.getSCEV(const_cast<Value *>(LocB.Ptr)); + +  // If they evaluate to the same expression, it's a MustAlias. +  if (AS == BS) +    return MustAlias; + +  // If something is known about the difference between the two addresses, +  // see if it's enough to prove a NoAlias. +  if (SE.getEffectiveSCEVType(AS->getType()) == +      SE.getEffectiveSCEVType(BS->getType())) { +    unsigned BitWidth = SE.getTypeSizeInBits(AS->getType()); +    APInt ASizeInt(BitWidth, LocA.Size); +    APInt BSizeInt(BitWidth, LocB.Size); + +    // Compute the difference between the two pointers. +    const SCEV *BA = SE.getMinusSCEV(BS, AS); + +    // Test whether the difference is known to be great enough that memory of +    // the given sizes don't overlap. This assumes that ASizeInt and BSizeInt +    // are non-zero, which is special-cased above. +    if (ASizeInt.ule(SE.getUnsignedRange(BA).getUnsignedMin()) && +        (-BSizeInt).uge(SE.getUnsignedRange(BA).getUnsignedMax())) +      return NoAlias; + +    // Folding the subtraction while preserving range information can be tricky +    // (because of INT_MIN, etc.); if the prior test failed, swap AS and BS +    // and try again to see if things fold better that way. + +    // Compute the difference between the two pointers. +    const SCEV *AB = SE.getMinusSCEV(AS, BS); + +    // Test whether the difference is known to be great enough that memory of +    // the given sizes don't overlap. 
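matchURem above recognizes an unsigned remainder that has already been lowered to A - (A / B) * B. The few standalone lines below check that identity with ordinary unsigned arithmetic and also reproduce the folded shape mentioned in its comment: with A = x / 2 and B = 4, the inner division the matcher has to recognize appears as x / 8.

#include <cassert>
#include <cstdint>

int main() {
  // The pattern the matcher looks for is exactly the urem expansion.
  for (uint32_t a = 0; a < 1000; ++a)
    for (uint32_t b = 1; b < 50; ++b)
      assert(a - (a / b) * b == a % b);

  // Folded form from the comment: with A = x/2 and B = 4, the division A / B
  // shows up as x / 8.
  uint32_t x = 12345;
  uint32_t A = x / 2, B = 4;
  assert(A / B == x / 8);
  assert(A - (A / B) * B == A % B);
  return 0;
}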
This assumes that ASizeInt and BSizeInt +    // are non-zero, which is special-cased above. +    if (BSizeInt.ule(SE.getUnsignedRange(AB).getUnsignedMin()) && +        (-ASizeInt).uge(SE.getUnsignedRange(AB).getUnsignedMax())) +      return NoAlias; +  } + +  // If ScalarEvolution can find an underlying object, form a new query. +  // The correctness of this depends on ScalarEvolution not recognizing +  // inttoptr and ptrtoint operators. +  Value *AO = GetBaseValue(AS); +  Value *BO = GetBaseValue(BS); +  if ((AO && AO != LocA.Ptr) || (BO && BO != LocB.Ptr)) +    if (alias(MemoryLocation(AO ? AO : LocA.Ptr, +                             AO ? +MemoryLocation::UnknownSize : LocA.Size, +                             AO ? AAMDNodes() : LocA.AATags), +              MemoryLocation(BO ? BO : LocB.Ptr, +                             BO ? +MemoryLocation::UnknownSize : LocB.Size, +                             BO ? AAMDNodes() : LocB.AATags)) == NoAlias) +      return NoAlias; + +  // Forward the query to the next analysis. +  return AAResultBase::alias(LocA, LocB); +} + +/// Given an expression, try to find a base value. +/// +/// Returns null if none was found. +Value *SCEVAAResult::GetBaseValue(const SCEV *S) { +  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { +    // In an addrec, assume that the base will be in the start, rather +    // than the step. +    return GetBaseValue(AR->getStart()); +  } else if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) { +    // If there's a pointer operand, it'll be sorted at the end of the list. +    const SCEV *Last = A->getOperand(A->getNumOperands() - 1); +    if (Last->getType()->isPointerTy()) +      return GetBaseValue(Last); +  } else if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { +    // This is a leaf node. +    return U->getValue(); +  } +  // No Identified object found. +  return nullptr; +} + +AnalysisKey SCEVAA::Key; + +SCEVAAResult SCEVAA::run(Function &F, FunctionAnalysisManager &AM) { +  return SCEVAAResult(AM.getResult<ScalarEvolutionAnalysis>(F)); +} + +char SCEVAAWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(SCEVAAWrapperPass, "scev-aa", +                      "ScalarEvolution-based Alias Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(SCEVAAWrapperPass, "scev-aa", +                    "ScalarEvolution-based Alias Analysis", false, true) + +FunctionPass *llvm::createSCEVAAWrapperPass() { +  return new SCEVAAWrapperPass(); +} + +SCEVAAWrapperPass::SCEVAAWrapperPass() : FunctionPass(ID) { +  initializeSCEVAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool SCEVAAWrapperPass::runOnFunction(Function &F) { +  Result.reset( +      new SCEVAAResult(getAnalysis<ScalarEvolutionWrapperPass>().getSE())); +  return false; +} + +void SCEVAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +  AU.addRequired<ScalarEvolutionWrapperPass>(); +} diff --git a/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp b/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp new file mode 100644 index 000000000000..8f89389c4b5d --- /dev/null +++ b/contrib/llvm/lib/Analysis/ScalarEvolutionExpander.cpp @@ -0,0 +1,2344 @@ +//===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis ------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
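The NoAlias test in SCEVAAResult::alias above reasons about the unsigned range of the pointer difference modulo 2^BitWidth: the accesses cannot overlap if the first access's size fits below the smallest possible difference and the second access's size fits above the largest one. The standalone 8-bit model below is an illustration rather than the analysis itself (the real code compares against a whole SCEV value range, not a single known difference); it brute-forces the condition against a direct overlap check.

#include <cassert>
#include <cstdint>

// 8-bit model of the NoAlias test above: with d = (b - a) mod 256, the two
// accesses cannot overlap if  sizeA <= d  and  d <= 256 - sizeB
// (the second inequality is the "(-BSizeInt).uge(...)" form in the code).
static bool provablyDisjoint(uint8_t d, unsigned sizeA, unsigned sizeB) {
  return sizeA <= d && d <= 256u - sizeB;
}

// Ground truth: do the byte sets [a, a+sizeA) and [b, b+sizeB), taken mod 256,
// share any address?
static bool bytesOverlap(uint8_t a, unsigned sizeA, uint8_t b, unsigned sizeB) {
  for (unsigned i = 0; i < sizeA; ++i)
    for (unsigned j = 0; j < sizeB; ++j)
      if (uint8_t(a + i) == uint8_t(b + j))
        return true;
  return false;
}

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b)
      for (unsigned sa = 1; sa <= 4; ++sa)
        for (unsigned sb = 1; sb <= 4; ++sb)
          if (provablyDisjoint(uint8_t(b - a), sa, sb))
            assert(!bytesOverlap(uint8_t(a), sa, uint8_t(b), sb));
  return 0;
}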
+// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation of the scalar evolution expander, +// which is used to generate the code corresponding to a given scalar evolution +// expression. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace PatternMatch; + +/// ReuseOrCreateCast - Arrange for there to be a cast of V to Ty at IP, +/// reusing an existing cast if a suitable one exists, moving an existing +/// cast if a suitable one exists but isn't in the right place, or +/// creating a new one. +Value *SCEVExpander::ReuseOrCreateCast(Value *V, Type *Ty, +                                       Instruction::CastOps Op, +                                       BasicBlock::iterator IP) { +  // This function must be called with the builder having a valid insertion +  // point. It doesn't need to be the actual IP where the uses of the returned +  // cast will be added, but it must dominate such IP. +  // We use this precondition to produce a cast that will dominate all its +  // uses. In particular, this is crucial for the case where the builder's +  // insertion point *is* the point where we were asked to put the cast. +  // Since we don't know the builder's insertion point is actually +  // where the uses will be added (only that it dominates it), we are +  // not allowed to move it. +  BasicBlock::iterator BIP = Builder.GetInsertPoint(); + +  Instruction *Ret = nullptr; + +  // Check to see if there is already a cast! +  for (User *U : V->users()) +    if (U->getType() == Ty) +      if (CastInst *CI = dyn_cast<CastInst>(U)) +        if (CI->getOpcode() == Op) { +          // If the cast isn't where we want it, create a new cast at IP. +          // Likewise, do not reuse a cast at BIP because it must dominate +          // instructions that might be inserted before BIP. +          if (BasicBlock::iterator(CI) != IP || BIP == IP) { +            // Create a new cast, and leave the old cast in place in case +            // it is being used as an insert point. Clear its operand +            // so that it doesn't hold anything live. +            Ret = CastInst::Create(Op, V, Ty, "", &*IP); +            Ret->takeName(CI); +            CI->replaceAllUsesWith(Ret); +            CI->setOperand(0, UndefValue::get(V->getType())); +            break; +          } +          Ret = CI; +          break; +        } + +  // Create a new cast. +  if (!Ret) +    Ret = CastInst::Create(Op, V, Ty, V->getName(), &*IP); + +  // We assert at the end of the function since IP might point to an +  // instruction with different dominance properties than a cast +  // (an invoke for example) and not dominate BIP (but the cast does). 
+  assert(SE.DT.dominates(Ret, &*BIP)); + +  rememberInstruction(Ret); +  return Ret; +} + +static BasicBlock::iterator findInsertPointAfter(Instruction *I, +                                                 BasicBlock *MustDominate) { +  BasicBlock::iterator IP = ++I->getIterator(); +  if (auto *II = dyn_cast<InvokeInst>(I)) +    IP = II->getNormalDest()->begin(); + +  while (isa<PHINode>(IP)) +    ++IP; + +  if (isa<FuncletPadInst>(IP) || isa<LandingPadInst>(IP)) { +    ++IP; +  } else if (isa<CatchSwitchInst>(IP)) { +    IP = MustDominate->getFirstInsertionPt(); +  } else { +    assert(!IP->isEHPad() && "unexpected eh pad!"); +  } + +  return IP; +} + +/// InsertNoopCastOfTo - Insert a cast of V to the specified type, +/// which must be possible with a noop cast, doing what we can to share +/// the casts. +Value *SCEVExpander::InsertNoopCastOfTo(Value *V, Type *Ty) { +  Instruction::CastOps Op = CastInst::getCastOpcode(V, false, Ty, false); +  assert((Op == Instruction::BitCast || +          Op == Instruction::PtrToInt || +          Op == Instruction::IntToPtr) && +         "InsertNoopCastOfTo cannot perform non-noop casts!"); +  assert(SE.getTypeSizeInBits(V->getType()) == SE.getTypeSizeInBits(Ty) && +         "InsertNoopCastOfTo cannot change sizes!"); + +  // Short-circuit unnecessary bitcasts. +  if (Op == Instruction::BitCast) { +    if (V->getType() == Ty) +      return V; +    if (CastInst *CI = dyn_cast<CastInst>(V)) { +      if (CI->getOperand(0)->getType() == Ty) +        return CI->getOperand(0); +    } +  } +  // Short-circuit unnecessary inttoptr<->ptrtoint casts. +  if ((Op == Instruction::PtrToInt || Op == Instruction::IntToPtr) && +      SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(V->getType())) { +    if (CastInst *CI = dyn_cast<CastInst>(V)) +      if ((CI->getOpcode() == Instruction::PtrToInt || +           CI->getOpcode() == Instruction::IntToPtr) && +          SE.getTypeSizeInBits(CI->getType()) == +          SE.getTypeSizeInBits(CI->getOperand(0)->getType())) +        return CI->getOperand(0); +    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) +      if ((CE->getOpcode() == Instruction::PtrToInt || +           CE->getOpcode() == Instruction::IntToPtr) && +          SE.getTypeSizeInBits(CE->getType()) == +          SE.getTypeSizeInBits(CE->getOperand(0)->getType())) +        return CE->getOperand(0); +  } + +  // Fold a cast of a constant. +  if (Constant *C = dyn_cast<Constant>(V)) +    return ConstantExpr::getCast(Op, C, Ty); + +  // Cast the argument at the beginning of the entry block, after +  // any bitcasts of other arguments. +  if (Argument *A = dyn_cast<Argument>(V)) { +    BasicBlock::iterator IP = A->getParent()->getEntryBlock().begin(); +    while ((isa<BitCastInst>(IP) && +            isa<Argument>(cast<BitCastInst>(IP)->getOperand(0)) && +            cast<BitCastInst>(IP)->getOperand(0) != A) || +           isa<DbgInfoIntrinsic>(IP)) +      ++IP; +    return ReuseOrCreateCast(A, Ty, Op, IP); +  } + +  // Cast the instruction immediately after the instruction. +  Instruction *I = cast<Instruction>(V); +  BasicBlock::iterator IP = findInsertPointAfter(I, Builder.GetInsertBlock()); +  return ReuseOrCreateCast(I, Ty, Op, IP); +} + +/// InsertBinop - Insert the specified binary operator, doing a small amount +/// of work to avoid inserting an obviously redundant operation. +Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, +                                 Value *LHS, Value *RHS) { +  // Fold a binop with constant operands. 
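The short-circuits in InsertNoopCastOfTo above rely on ptrtoint/inttoptr (and bitcast) pairs between same-sized types being round trips, so chains of such casts can be peeled rather than materialized. The fragment below is only a source-level analogue of that cancellation, expressed with reinterpret_cast instead of the LLVM IR casts.

#include <cassert>
#include <cstdint>

int main() {
  int X = 42;
  int *P = &X;

  // "ptrtoint" followed by "inttoptr" at the same bit width gives back the
  // original pointer, so emitting both casts would be a no-op chain.
  static_assert(sizeof(std::uintptr_t) >= sizeof(int *),
                "uintptr_t must be able to hold a pointer");
  std::uintptr_t Raw = reinterpret_cast<std::uintptr_t>(P); // ptrtoint
  int *Back = reinterpret_cast<int *>(Raw);                 // inttoptr

  assert(Back == P && *Back == 42);
  return 0;
}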
+  if (Constant *CLHS = dyn_cast<Constant>(LHS)) +    if (Constant *CRHS = dyn_cast<Constant>(RHS)) +      return ConstantExpr::get(Opcode, CLHS, CRHS); + +  // Do a quick scan to see if we have this binop nearby.  If so, reuse it. +  unsigned ScanLimit = 6; +  BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); +  // Scanning starts from the last instruction before the insertion point. +  BasicBlock::iterator IP = Builder.GetInsertPoint(); +  if (IP != BlockBegin) { +    --IP; +    for (; ScanLimit; --IP, --ScanLimit) { +      // Don't count dbg.value against the ScanLimit, to avoid perturbing the +      // generated code. +      if (isa<DbgInfoIntrinsic>(IP)) +        ScanLimit++; + +      // Conservatively, do not use any instruction which has any of wrap/exact +      // flags installed. +      // TODO: Instead of simply disable poison instructions we can be clever +      //       here and match SCEV to this instruction. +      auto canGeneratePoison = [](Instruction *I) { +        if (isa<OverflowingBinaryOperator>(I) && +            (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())) +          return true; +        if (isa<PossiblyExactOperator>(I) && I->isExact()) +          return true; +        return false; +      }; +      if (IP->getOpcode() == (unsigned)Opcode && IP->getOperand(0) == LHS && +          IP->getOperand(1) == RHS && !canGeneratePoison(&*IP)) +        return &*IP; +      if (IP == BlockBegin) break; +    } +  } + +  // Save the original insertion point so we can restore it when we're done. +  DebugLoc Loc = Builder.GetInsertPoint()->getDebugLoc(); +  SCEVInsertPointGuard Guard(Builder, this); + +  // Move the insertion point out of as many loops as we can. +  while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { +    if (!L->isLoopInvariant(LHS) || !L->isLoopInvariant(RHS)) break; +    BasicBlock *Preheader = L->getLoopPreheader(); +    if (!Preheader) break; + +    // Ok, move up a level. +    Builder.SetInsertPoint(Preheader->getTerminator()); +  } + +  // If we haven't found this binop, insert it. +  Instruction *BO = cast<Instruction>(Builder.CreateBinOp(Opcode, LHS, RHS)); +  BO->setDebugLoc(Loc); +  rememberInstruction(BO); + +  return BO; +} + +/// FactorOutConstant - Test if S is divisible by Factor, using signed +/// division. If so, update S with Factor divided out and return true. +/// S need not be evenly divisible if a reasonable remainder can be +/// computed. +/// TODO: When ScalarEvolution gets a SCEVSDivExpr, this can be made +/// unnecessary; in its place, just signed-divide Ops[i] by the scale and +/// check to see if the divide was folded. +static bool FactorOutConstant(const SCEV *&S, const SCEV *&Remainder, +                              const SCEV *Factor, ScalarEvolution &SE, +                              const DataLayout &DL) { +  // Everything is divisible by one. +  if (Factor->isOne()) +    return true; + +  // x/x == 1. +  if (S == Factor) { +    S = SE.getConstant(S->getType(), 1); +    return true; +  } + +  // For a Constant, check for a multiple of the given factor. +  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { +    // 0/x == 0. +    if (C->isZero()) +      return true; +    // Check for divisibility. +    if (const SCEVConstant *FC = dyn_cast<SCEVConstant>(Factor)) { +      ConstantInt *CI = +          ConstantInt::get(SE.getContext(), C->getAPInt().sdiv(FC->getAPInt())); +      // If the quotient is zero and the remainder is non-zero, reject +      // the value at this scale. 
It will be considered for subsequent +      // smaller scales. +      if (!CI->isZero()) { +        const SCEV *Div = SE.getConstant(CI); +        S = Div; +        Remainder = SE.getAddExpr( +            Remainder, SE.getConstant(C->getAPInt().srem(FC->getAPInt()))); +        return true; +      } +    } +  } + +  // In a Mul, check if there is a constant operand which is a multiple +  // of the given factor. +  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) { +    // Size is known, check if there is a constant operand which is a multiple +    // of the given factor. If so, we can factor it. +    const SCEVConstant *FC = cast<SCEVConstant>(Factor); +    if (const SCEVConstant *C = dyn_cast<SCEVConstant>(M->getOperand(0))) +      if (!C->getAPInt().srem(FC->getAPInt())) { +        SmallVector<const SCEV *, 4> NewMulOps(M->op_begin(), M->op_end()); +        NewMulOps[0] = SE.getConstant(C->getAPInt().sdiv(FC->getAPInt())); +        S = SE.getMulExpr(NewMulOps); +        return true; +      } +  } + +  // In an AddRec, check if both start and step are divisible. +  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) { +    const SCEV *Step = A->getStepRecurrence(SE); +    const SCEV *StepRem = SE.getConstant(Step->getType(), 0); +    if (!FactorOutConstant(Step, StepRem, Factor, SE, DL)) +      return false; +    if (!StepRem->isZero()) +      return false; +    const SCEV *Start = A->getStart(); +    if (!FactorOutConstant(Start, Remainder, Factor, SE, DL)) +      return false; +    S = SE.getAddRecExpr(Start, Step, A->getLoop(), +                         A->getNoWrapFlags(SCEV::FlagNW)); +    return true; +  } + +  return false; +} + +/// SimplifyAddOperands - Sort and simplify a list of add operands. NumAddRecs +/// is the number of SCEVAddRecExprs present, which are kept at the end of +/// the list. +/// +static void SimplifyAddOperands(SmallVectorImpl<const SCEV *> &Ops, +                                Type *Ty, +                                ScalarEvolution &SE) { +  unsigned NumAddRecs = 0; +  for (unsigned i = Ops.size(); i > 0 && isa<SCEVAddRecExpr>(Ops[i-1]); --i) +    ++NumAddRecs; +  // Group Ops into non-addrecs and addrecs. +  SmallVector<const SCEV *, 8> NoAddRecs(Ops.begin(), Ops.end() - NumAddRecs); +  SmallVector<const SCEV *, 8> AddRecs(Ops.end() - NumAddRecs, Ops.end()); +  // Let ScalarEvolution sort and simplify the non-addrecs list. +  const SCEV *Sum = NoAddRecs.empty() ? +                    SE.getConstant(Ty, 0) : +                    SE.getAddExpr(NoAddRecs); +  // If it returned an add, use the operands. Otherwise it simplified +  // the sum into a single value, so just use that. +  Ops.clear(); +  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Sum)) +    Ops.append(Add->op_begin(), Add->op_end()); +  else if (!Sum->isZero()) +    Ops.push_back(Sum); +  // Then append the addrecs. +  Ops.append(AddRecs.begin(), AddRecs.end()); +} + +/// SplitAddRecs - Flatten a list of add operands, moving addrec start values +/// out to the top level. For example, convert {a + b,+,c} to a, b, {0,+,d}. +/// This helps expose more opportunities for folding parts of the expressions +/// into GEP indices. +/// +static void SplitAddRecs(SmallVectorImpl<const SCEV *> &Ops, +                         Type *Ty, +                         ScalarEvolution &SE) { +  // Find the addrecs. 
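For the plain-constant case, FactorOutConstant above amounts to one signed division: keep the quotient, add the leftover into the running remainder, and reject the current scale when the quotient would be zero so that a smaller scale can be tried later. The helper below mirrors just that arithmetic on int64_t; it is a simplified stand-in, not the SCEV version, which also handles multiplies and add recurrences.

#include <cassert>
#include <cstdint>

// Constant-only analogue of FactorOutConstant: try to express S as
// Q * Factor (+ Remainder).  Returns false when the factor does not "fit",
// i.e. the quotient would be zero while a remainder is left over.
static bool factorOutConstant(int64_t &S, int64_t &Remainder, int64_t Factor) {
  if (Factor == 1)
    return true;               // everything is divisible by one
  if (S == 0)
    return true;               // 0 / x == 0, nothing to do
  int64_t Q = S / Factor;      // signed division, like APInt::sdiv
  int64_t R = S % Factor;      // like APInt::srem
  if (Q == 0)
    return false;              // reject at this scale; try a smaller one later
  S = Q;
  Remainder += R;
  return true;
}

int main() {
  // 14 bytes against an element size of 4: index 3, with 2 bytes left over.
  int64_t S = 14, Rem = 0;
  assert(factorOutConstant(S, Rem, 4) && S == 3 && Rem == 2);

  // 3 bytes against an element size of 8: the quotient would be 0, so reject.
  int64_t S2 = 3, Rem2 = 0;
  assert(!factorOutConstant(S2, Rem2, 8));
  return 0;
}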
+  SmallVector<const SCEV *, 8> AddRecs; +  for (unsigned i = 0, e = Ops.size(); i != e; ++i) +    while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Ops[i])) { +      const SCEV *Start = A->getStart(); +      if (Start->isZero()) break; +      const SCEV *Zero = SE.getConstant(Ty, 0); +      AddRecs.push_back(SE.getAddRecExpr(Zero, +                                         A->getStepRecurrence(SE), +                                         A->getLoop(), +                                         A->getNoWrapFlags(SCEV::FlagNW))); +      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Start)) { +        Ops[i] = Zero; +        Ops.append(Add->op_begin(), Add->op_end()); +        e += Add->getNumOperands(); +      } else { +        Ops[i] = Start; +      } +    } +  if (!AddRecs.empty()) { +    // Add the addrecs onto the end of the list. +    Ops.append(AddRecs.begin(), AddRecs.end()); +    // Resort the operand list, moving any constants to the front. +    SimplifyAddOperands(Ops, Ty, SE); +  } +} + +/// expandAddToGEP - Expand an addition expression with a pointer type into +/// a GEP instead of using ptrtoint+arithmetic+inttoptr. This helps +/// BasicAliasAnalysis and other passes analyze the result. See the rules +/// for getelementptr vs. inttoptr in +/// http://llvm.org/docs/LangRef.html#pointeraliasing +/// for details. +/// +/// Design note: The correctness of using getelementptr here depends on +/// ScalarEvolution not recognizing inttoptr and ptrtoint operators, as +/// they may introduce pointer arithmetic which may not be safely converted +/// into getelementptr. +/// +/// Design note: It might seem desirable for this function to be more +/// loop-aware. If some of the indices are loop-invariant while others +/// aren't, it might seem desirable to emit multiple GEPs, keeping the +/// loop-invariant portions of the overall computation outside the loop. +/// However, there are a few reasons this is not done here. Hoisting simple +/// arithmetic is a low-level optimization that often isn't very +/// important until late in the optimization process. In fact, passes +/// like InstructionCombining will combine GEPs, even if it means +/// pushing loop-invariant computation down into loops, so even if the +/// GEPs were split here, the work would quickly be undone. The +/// LoopStrengthReduction pass, which is usually run quite late (and +/// after the last InstructionCombining pass), takes care of hoisting +/// loop-invariant portions of expressions, after considering what +/// can be folded using target addressing modes. +/// +Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, +                                    const SCEV *const *op_end, +                                    PointerType *PTy, +                                    Type *Ty, +                                    Value *V) { +  Type *OriginalElTy = PTy->getElementType(); +  Type *ElTy = OriginalElTy; +  SmallVector<Value *, 4> GepIndices; +  SmallVector<const SCEV *, 8> Ops(op_begin, op_end); +  bool AnyNonZeroIndices = false; + +  // Split AddRecs up into parts as either of the parts may be usable +  // without the other. +  SplitAddRecs(Ops, Ty, SE); + +  Type *IntPtrTy = DL.getIntPtrType(PTy); + +  // Descend down the pointer's type and attempt to convert the other +  // operands into GEP indices, at each level. 
The first index in a GEP +  // indexes into the array implied by the pointer operand; the rest of +  // the indices index into the element or field type selected by the +  // preceding index. +  for (;;) { +    // If the scale size is not 0, attempt to factor out a scale for +    // array indexing. +    SmallVector<const SCEV *, 8> ScaledOps; +    if (ElTy->isSized()) { +      const SCEV *ElSize = SE.getSizeOfExpr(IntPtrTy, ElTy); +      if (!ElSize->isZero()) { +        SmallVector<const SCEV *, 8> NewOps; +        for (const SCEV *Op : Ops) { +          const SCEV *Remainder = SE.getConstant(Ty, 0); +          if (FactorOutConstant(Op, Remainder, ElSize, SE, DL)) { +            // Op now has ElSize factored out. +            ScaledOps.push_back(Op); +            if (!Remainder->isZero()) +              NewOps.push_back(Remainder); +            AnyNonZeroIndices = true; +          } else { +            // The operand was not divisible, so add it to the list of operands +            // we'll scan next iteration. +            NewOps.push_back(Op); +          } +        } +        // If we made any changes, update Ops. +        if (!ScaledOps.empty()) { +          Ops = NewOps; +          SimplifyAddOperands(Ops, Ty, SE); +        } +      } +    } + +    // Record the scaled array index for this level of the type. If +    // we didn't find any operands that could be factored, tentatively +    // assume that element zero was selected (since the zero offset +    // would obviously be folded away). +    Value *Scaled = ScaledOps.empty() ? +                    Constant::getNullValue(Ty) : +                    expandCodeFor(SE.getAddExpr(ScaledOps), Ty); +    GepIndices.push_back(Scaled); + +    // Collect struct field index operands. +    while (StructType *STy = dyn_cast<StructType>(ElTy)) { +      bool FoundFieldNo = false; +      // An empty struct has no fields. +      if (STy->getNumElements() == 0) break; +      // Field offsets are known. See if a constant offset falls within any of +      // the struct fields. +      if (Ops.empty()) +        break; +      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[0])) +        if (SE.getTypeSizeInBits(C->getType()) <= 64) { +          const StructLayout &SL = *DL.getStructLayout(STy); +          uint64_t FullOffset = C->getValue()->getZExtValue(); +          if (FullOffset < SL.getSizeInBytes()) { +            unsigned ElIdx = SL.getElementContainingOffset(FullOffset); +            GepIndices.push_back( +                ConstantInt::get(Type::getInt32Ty(Ty->getContext()), ElIdx)); +            ElTy = STy->getTypeAtIndex(ElIdx); +            Ops[0] = +                SE.getConstant(Ty, FullOffset - SL.getElementOffset(ElIdx)); +            AnyNonZeroIndices = true; +            FoundFieldNo = true; +          } +        } +      // If no struct field offsets were found, tentatively assume that +      // field zero was selected (since the zero offset would obviously +      // be folded away). +      if (!FoundFieldNo) { +        ElTy = STy->getTypeAtIndex(0u); +        GepIndices.push_back( +          Constant::getNullValue(Type::getInt32Ty(Ty->getContext()))); +      } +    } + +    if (ArrayType *ATy = dyn_cast<ArrayType>(ElTy)) +      ElTy = ATy->getElementType(); +    else +      break; +  } + +  // If none of the operands were convertible to proper GEP indices, cast +  // the base to i8* and do an ugly getelementptr with that. It's still +  // better than ptrtoint+arithmetic+inttoptr at least. 
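The loop above turns a flat offset into getelementptr indices by repeatedly factoring out the element size and, for structs, selecting the field that contains a constant offset. The standalone fragment below performs that decomposition by hand for one hypothetical layout (struct S { double D; int A[4]; }); the type and the particular offset are assumptions chosen purely for illustration.

#include <cassert>
#include <cstddef>
#include <cstdint>

// A hypothetical aggregate, standing in for the pointee type PTy above.
struct S {
  double D;
  int A[4];
};

int main() {
  S Buf[8] = {};
  char *Base = reinterpret_cast<char *>(Buf);

  // A flat byte offset, as it would appear after ptrtoint-style arithmetic.
  std::size_t Offset = 1 * sizeof(S) + offsetof(S, A) + 3 * sizeof(int);

  // Mimic the index derivation: factor out sizeof(S) for the leading array
  // index, pick the struct field containing the remaining offset, then factor
  // out the array element size.
  std::size_t Idx0 = Offset / sizeof(S);
  std::size_t Rem = Offset % sizeof(S);
  assert(Rem >= offsetof(S, A));              // falls inside field A
  std::size_t Idx2 = (Rem - offsetof(S, A)) / sizeof(int);

  // The "pretty GEP" form &Buf[Idx0].A[Idx2] addresses the same byte as the
  // raw base-plus-byte-offset form.
  assert(reinterpret_cast<char *>(&Buf[Idx0].A[Idx2]) == Base + Offset);
  return 0;
}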
+  if (!AnyNonZeroIndices) { +    // Cast the base to i8*. +    V = InsertNoopCastOfTo(V, +       Type::getInt8PtrTy(Ty->getContext(), PTy->getAddressSpace())); + +    assert(!isa<Instruction>(V) || +           SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint())); + +    // Expand the operands for a plain byte offset. +    Value *Idx = expandCodeFor(SE.getAddExpr(Ops), Ty); + +    // Fold a GEP with constant operands. +    if (Constant *CLHS = dyn_cast<Constant>(V)) +      if (Constant *CRHS = dyn_cast<Constant>(Idx)) +        return ConstantExpr::getGetElementPtr(Type::getInt8Ty(Ty->getContext()), +                                              CLHS, CRHS); + +    // Do a quick scan to see if we have this GEP nearby.  If so, reuse it. +    unsigned ScanLimit = 6; +    BasicBlock::iterator BlockBegin = Builder.GetInsertBlock()->begin(); +    // Scanning starts from the last instruction before the insertion point. +    BasicBlock::iterator IP = Builder.GetInsertPoint(); +    if (IP != BlockBegin) { +      --IP; +      for (; ScanLimit; --IP, --ScanLimit) { +        // Don't count dbg.value against the ScanLimit, to avoid perturbing the +        // generated code. +        if (isa<DbgInfoIntrinsic>(IP)) +          ScanLimit++; +        if (IP->getOpcode() == Instruction::GetElementPtr && +            IP->getOperand(0) == V && IP->getOperand(1) == Idx) +          return &*IP; +        if (IP == BlockBegin) break; +      } +    } + +    // Save the original insertion point so we can restore it when we're done. +    SCEVInsertPointGuard Guard(Builder, this); + +    // Move the insertion point out of as many loops as we can. +    while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { +      if (!L->isLoopInvariant(V) || !L->isLoopInvariant(Idx)) break; +      BasicBlock *Preheader = L->getLoopPreheader(); +      if (!Preheader) break; + +      // Ok, move up a level. +      Builder.SetInsertPoint(Preheader->getTerminator()); +    } + +    // Emit a GEP. +    Value *GEP = Builder.CreateGEP(Builder.getInt8Ty(), V, Idx, "uglygep"); +    rememberInstruction(GEP); + +    return GEP; +  } + +  { +    SCEVInsertPointGuard Guard(Builder, this); + +    // Move the insertion point out of as many loops as we can. +    while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { +      if (!L->isLoopInvariant(V)) break; + +      bool AnyIndexNotLoopInvariant = any_of( +          GepIndices, [L](Value *Op) { return !L->isLoopInvariant(Op); }); + +      if (AnyIndexNotLoopInvariant) +        break; + +      BasicBlock *Preheader = L->getLoopPreheader(); +      if (!Preheader) break; + +      // Ok, move up a level. +      Builder.SetInsertPoint(Preheader->getTerminator()); +    } + +    // Insert a pretty getelementptr. Note that this GEP is not marked inbounds, +    // because ScalarEvolution may have changed the address arithmetic to +    // compute a value which is beyond the end of the allocated object. 
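Both InsertBinop earlier in this file and the byte-offset GEP path above avoid emitting a duplicate instruction by scanning a handful of instructions backwards from the insertion point for an identical one. The toy version below shows the same bounded backward scan over a plain vector, with strings standing in for instructions; the refinements in the real code (skipping debug intrinsics, refusing to reuse instructions with poison-generating flags) are left out.

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Look backwards from InsertPos for an entry equal to Wanted, giving up after
// ScanLimit steps; return its index or -1.  This mirrors the "do a quick scan
// to see if we have this op nearby, if so reuse it" loops.
static int findNearbyReusable(const std::vector<std::string> &Block,
                              std::size_t InsertPos, const std::string &Wanted,
                              unsigned ScanLimit = 6) {
  std::size_t I = InsertPos;
  for (unsigned Steps = 0; I > 0 && Steps < ScanLimit; ++Steps) {
    --I;
    if (Block[I] == Wanted)
      return static_cast<int>(I);
  }
  return -1;
}

int main() {
  std::vector<std::string> Block = {"load a", "add a b", "store c", "mul a b"};
  // "add a b" is within the scan window, so it would be reused...
  std::printf("%d\n", findNearbyReusable(Block, Block.size(), "add a b")); // 1
  // ...while something not found nearby triggers a fresh insert.
  std::printf("%d\n", findNearbyReusable(Block, Block.size(), "sub a b")); // -1
  return 0;
}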
+    Value *Casted = V; +    if (V->getType() != PTy) +      Casted = InsertNoopCastOfTo(Casted, PTy); +    Value *GEP = Builder.CreateGEP(OriginalElTy, Casted, GepIndices, "scevgep"); +    Ops.push_back(SE.getUnknown(GEP)); +    rememberInstruction(GEP); +  } + +  return expand(SE.getAddExpr(Ops)); +} + +Value *SCEVExpander::expandAddToGEP(const SCEV *Op, PointerType *PTy, Type *Ty, +                                    Value *V) { +  const SCEV *const Ops[1] = {Op}; +  return expandAddToGEP(Ops, Ops + 1, PTy, Ty, V); +} + +/// PickMostRelevantLoop - Given two loops pick the one that's most relevant for +/// SCEV expansion. If they are nested, this is the most nested. If they are +/// neighboring, pick the later. +static const Loop *PickMostRelevantLoop(const Loop *A, const Loop *B, +                                        DominatorTree &DT) { +  if (!A) return B; +  if (!B) return A; +  if (A->contains(B)) return B; +  if (B->contains(A)) return A; +  if (DT.dominates(A->getHeader(), B->getHeader())) return B; +  if (DT.dominates(B->getHeader(), A->getHeader())) return A; +  return A; // Arbitrarily break the tie. +} + +/// getRelevantLoop - Get the most relevant loop associated with the given +/// expression, according to PickMostRelevantLoop. +const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { +  // Test whether we've already computed the most relevant loop for this SCEV. +  auto Pair = RelevantLoops.insert(std::make_pair(S, nullptr)); +  if (!Pair.second) +    return Pair.first->second; + +  if (isa<SCEVConstant>(S)) +    // A constant has no relevant loops. +    return nullptr; +  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { +    if (const Instruction *I = dyn_cast<Instruction>(U->getValue())) +      return Pair.first->second = SE.LI.getLoopFor(I->getParent()); +    // A non-instruction has no relevant loops. +    return nullptr; +  } +  if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) { +    const Loop *L = nullptr; +    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) +      L = AR->getLoop(); +    for (const SCEV *Op : N->operands()) +      L = PickMostRelevantLoop(L, getRelevantLoop(Op), SE.DT); +    return RelevantLoops[N] = L; +  } +  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S)) { +    const Loop *Result = getRelevantLoop(C->getOperand()); +    return RelevantLoops[C] = Result; +  } +  if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { +    const Loop *Result = PickMostRelevantLoop( +        getRelevantLoop(D->getLHS()), getRelevantLoop(D->getRHS()), SE.DT); +    return RelevantLoops[D] = Result; +  } +  llvm_unreachable("Unexpected SCEV type!"); +} + +namespace { + +/// LoopCompare - Compare loops by PickMostRelevantLoop. +class LoopCompare { +  DominatorTree &DT; +public: +  explicit LoopCompare(DominatorTree &dt) : DT(dt) {} + +  bool operator()(std::pair<const Loop *, const SCEV *> LHS, +                  std::pair<const Loop *, const SCEV *> RHS) const { +    // Keep pointer operands sorted at the end. +    if (LHS.second->getType()->isPointerTy() != +        RHS.second->getType()->isPointerTy()) +      return LHS.second->getType()->isPointerTy(); + +    // Compare loops with PickMostRelevantLoop. +    if (LHS.first != RHS.first) +      return PickMostRelevantLoop(LHS.first, RHS.first, DT) != LHS.first; + +    // If one operand is a non-constant negative and the other is not, +    // put the non-constant negative on the right so that a sub can +    // be used instead of a negate and add. 
+    if (LHS.second->isNonConstantNegative()) { +      if (!RHS.second->isNonConstantNegative()) +        return false; +    } else if (RHS.second->isNonConstantNegative()) +      return true; + +    // Otherwise they are equivalent according to this comparison. +    return false; +  } +}; + +} + +Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) { +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); + +  // Collect all the add operands in a loop, along with their associated loops. +  // Iterate in reverse so that constants are emitted last, all else equal, and +  // so that pointer operands are inserted first, which the code below relies on +  // to form more involved GEPs. +  SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops; +  for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(S->op_end()), +       E(S->op_begin()); I != E; ++I) +    OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); + +  // Sort by loop. Use a stable sort so that constants follow non-constants and +  // pointer operands precede non-pointer operands. +  std::stable_sort(OpsAndLoops.begin(), OpsAndLoops.end(), LoopCompare(SE.DT)); + +  // Emit instructions to add all the operands. Hoist as much as possible +  // out of loops, and form meaningful getelementptrs where possible. +  Value *Sum = nullptr; +  for (auto I = OpsAndLoops.begin(), E = OpsAndLoops.end(); I != E;) { +    const Loop *CurLoop = I->first; +    const SCEV *Op = I->second; +    if (!Sum) { +      // This is the first operand. Just expand it. +      Sum = expand(Op); +      ++I; +    } else if (PointerType *PTy = dyn_cast<PointerType>(Sum->getType())) { +      // The running sum expression is a pointer. Try to form a getelementptr +      // at this level with that as the base. +      SmallVector<const SCEV *, 4> NewOps; +      for (; I != E && I->first == CurLoop; ++I) { +        // If the operand is SCEVUnknown and not instructions, peek through +        // it, to enable more of it to be folded into the GEP. +        const SCEV *X = I->second; +        if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(X)) +          if (!isa<Instruction>(U->getValue())) +            X = SE.getSCEV(U->getValue()); +        NewOps.push_back(X); +      } +      Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, Sum); +    } else if (PointerType *PTy = dyn_cast<PointerType>(Op->getType())) { +      // The running sum is an integer, and there's a pointer at this level. +      // Try to form a getelementptr. If the running sum is instructions, +      // use a SCEVUnknown to avoid re-analyzing them. +      SmallVector<const SCEV *, 4> NewOps; +      NewOps.push_back(isa<Instruction>(Sum) ? SE.getUnknown(Sum) : +                                               SE.getSCEV(Sum)); +      for (++I; I != E && I->first == CurLoop; ++I) +        NewOps.push_back(I->second); +      Sum = expandAddToGEP(NewOps.begin(), NewOps.end(), PTy, Ty, expand(Op)); +    } else if (Op->isNonConstantNegative()) { +      // Instead of doing a negate and add, just do a subtract. +      Value *W = expandCodeFor(SE.getNegativeSCEV(Op), Ty); +      Sum = InsertNoopCastOfTo(Sum, Ty); +      Sum = InsertBinop(Instruction::Sub, Sum, W); +      ++I; +    } else { +      // A simple add. +      Value *W = expandCodeFor(Op, Ty); +      Sum = InsertNoopCastOfTo(Sum, Ty); +      // Canonicalize a constant to the RHS. 
+      if (isa<Constant>(Sum)) std::swap(Sum, W); +      Sum = InsertBinop(Instruction::Add, Sum, W); +      ++I; +    } +  } + +  return Sum; +} + +Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) { +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); + +  // Collect all the mul operands in a loop, along with their associated loops. +  // Iterate in reverse so that constants are emitted last, all else equal. +  SmallVector<std::pair<const Loop *, const SCEV *>, 8> OpsAndLoops; +  for (std::reverse_iterator<SCEVMulExpr::op_iterator> I(S->op_end()), +       E(S->op_begin()); I != E; ++I) +    OpsAndLoops.push_back(std::make_pair(getRelevantLoop(*I), *I)); + +  // Sort by loop. Use a stable sort so that constants follow non-constants. +  std::stable_sort(OpsAndLoops.begin(), OpsAndLoops.end(), LoopCompare(SE.DT)); + +  // Emit instructions to mul all the operands. Hoist as much as possible +  // out of loops. +  Value *Prod = nullptr; +  auto I = OpsAndLoops.begin(); + +  // Expand the calculation of X pow N in the following manner: +  // Let N = P1 + P2 + ... + PK, where all P are powers of 2. Then: +  // X pow N = (X pow P1) * (X pow P2) * ... * (X pow PK). +  const auto ExpandOpBinPowN = [this, &I, &OpsAndLoops, &Ty]() { +    auto E = I; +    // Calculate how many times the same operand from the same loop is included +    // into this power. +    uint64_t Exponent = 0; +    const uint64_t MaxExponent = UINT64_MAX >> 1; +    // No one sane will ever try to calculate such huge exponents, but if we +    // need this, we stop on UINT64_MAX / 2 because we need to exit the loop +    // below when the power of 2 exceeds our Exponent, and we want it to be +    // 1u << 31 at most to not deal with unsigned overflow. +    while (E != OpsAndLoops.end() && *I == *E && Exponent != MaxExponent) { +      ++Exponent; +      ++E; +    } +    assert(Exponent > 0 && "Trying to calculate a zeroth exponent of operand?"); + +    // Calculate powers with exponents 1, 2, 4, 8 etc. and include those of them +    // that are needed into the result. +    Value *P = expandCodeFor(I->second, Ty); +    Value *Result = nullptr; +    if (Exponent & 1) +      Result = P; +    for (uint64_t BinExp = 2; BinExp <= Exponent; BinExp <<= 1) { +      P = InsertBinop(Instruction::Mul, P, P); +      if (Exponent & BinExp) +        Result = Result ? InsertBinop(Instruction::Mul, Result, P) : P; +    } + +    I = E; +    assert(Result && "Nothing was expanded?"); +    return Result; +  }; + +  while (I != OpsAndLoops.end()) { +    if (!Prod) { +      // This is the first operand. Just expand it. +      Prod = ExpandOpBinPowN(); +    } else if (I->second->isAllOnesValue()) { +      // Instead of doing a multiply by negative one, just do a negate. +      Prod = InsertNoopCastOfTo(Prod, Ty); +      Prod = InsertBinop(Instruction::Sub, Constant::getNullValue(Ty), Prod); +      ++I; +    } else { +      // A simple mul. +      Value *W = ExpandOpBinPowN(); +      Prod = InsertNoopCastOfTo(Prod, Ty); +      // Canonicalize a constant to the RHS. +      if (isa<Constant>(Prod)) std::swap(Prod, W); +      const APInt *RHS; +      if (match(W, m_Power2(RHS))) { +        // Canonicalize Prod*(1<<C) to Prod<<C. 
+        assert(!Ty->isVectorTy() && "vector types are not SCEVable"); +        Prod = InsertBinop(Instruction::Shl, Prod, +                           ConstantInt::get(Ty, RHS->logBase2())); +      } else { +        Prod = InsertBinop(Instruction::Mul, Prod, W); +      } +    } +  } + +  return Prod; +} + +Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) { +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); + +  Value *LHS = expandCodeFor(S->getLHS(), Ty); +  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) { +    const APInt &RHS = SC->getAPInt(); +    if (RHS.isPowerOf2()) +      return InsertBinop(Instruction::LShr, LHS, +                         ConstantInt::get(Ty, RHS.logBase2())); +  } + +  Value *RHS = expandCodeFor(S->getRHS(), Ty); +  return InsertBinop(Instruction::UDiv, LHS, RHS); +} + +/// Move parts of Base into Rest to leave Base with the minimal +/// expression that provides a pointer operand suitable for a +/// GEP expansion. +static void ExposePointerBase(const SCEV *&Base, const SCEV *&Rest, +                              ScalarEvolution &SE) { +  while (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(Base)) { +    Base = A->getStart(); +    Rest = SE.getAddExpr(Rest, +                         SE.getAddRecExpr(SE.getConstant(A->getType(), 0), +                                          A->getStepRecurrence(SE), +                                          A->getLoop(), +                                          A->getNoWrapFlags(SCEV::FlagNW))); +  } +  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(Base)) { +    Base = A->getOperand(A->getNumOperands()-1); +    SmallVector<const SCEV *, 8> NewAddOps(A->op_begin(), A->op_end()); +    NewAddOps.back() = Rest; +    Rest = SE.getAddExpr(NewAddOps); +    ExposePointerBase(Base, Rest, SE); +  } +} + +/// Determine if this is a well-behaved chain of instructions leading back to +/// the PHI. If so, it may be reused by expanded expressions. +bool SCEVExpander::isNormalAddRecExprPHI(PHINode *PN, Instruction *IncV, +                                         const Loop *L) { +  if (IncV->getNumOperands() == 0 || isa<PHINode>(IncV) || +      (isa<CastInst>(IncV) && !isa<BitCastInst>(IncV))) +    return false; +  // If any of the operands don't dominate the insert position, bail. +  // Addrec operands are always loop-invariant, so this can only happen +  // if there are instructions which haven't been hoisted. +  if (L == IVIncInsertLoop) { +    for (User::op_iterator OI = IncV->op_begin()+1, +           OE = IncV->op_end(); OI != OE; ++OI) +      if (Instruction *OInst = dyn_cast<Instruction>(OI)) +        if (!SE.DT.dominates(OInst, IVIncInsertPos)) +          return false; +  } +  // Advance to the next instruction. +  IncV = dyn_cast<Instruction>(IncV->getOperand(0)); +  if (!IncV) +    return false; + +  if (IncV->mayHaveSideEffects()) +    return false; + +  if (IncV == PN) +    return true; + +  return isNormalAddRecExprPHI(PN, IncV, L); +} + +/// getIVIncOperand returns an induction variable increment's induction +/// variable operand. +/// +/// If allowScale is set, any type of GEP is allowed as long as the nonIV +/// operands dominate InsertPos. +/// +/// If allowScale is not set, ensure that a GEP increment conforms to one of the +/// simple patterns generated by getAddRecExprPHILiterally and +/// expandAddtoGEP. If the pattern isn't recognized, return NULL. 
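ExpandOpBinPowN above emits X to the power N as a product of repeated squarings driven by the binary representation of N. The standalone function below is the same scheme on plain integers, performing the multiplications the expander would emit as Mul instructions; the expander additionally seeds the result with the first required power instead of multiplying by one, but the arithmetic is identical.

#include <cassert>
#include <cstdint>

// Exponentiation by squaring: X**N is the product of X**(2^i) over the set
// bits of N, so it needs one squaring per bit plus one multiply per set bit.
static uint64_t binPow(uint64_t X, uint64_t N) {
  uint64_t Result = 1;
  uint64_t P = X;               // X**1, X**2, X**4, ... as we square
  while (N) {
    if (N & 1)
      Result *= P;              // include this power in the running product
    P *= P;
    N >>= 1;
  }
  return Result;
}

int main() {
  assert(binPow(3, 5) == 243);      // 3**5, built from 3**1 and 3**4
  assert(binPow(2, 10) == 1024);
  assert(binPow(7, 0) == 1);
  return 0;
}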
+Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV, +                                           Instruction *InsertPos, +                                           bool allowScale) { +  if (IncV == InsertPos) +    return nullptr; + +  switch (IncV->getOpcode()) { +  default: +    return nullptr; +  // Check for a simple Add/Sub or GEP of a loop invariant step. +  case Instruction::Add: +  case Instruction::Sub: { +    Instruction *OInst = dyn_cast<Instruction>(IncV->getOperand(1)); +    if (!OInst || SE.DT.dominates(OInst, InsertPos)) +      return dyn_cast<Instruction>(IncV->getOperand(0)); +    return nullptr; +  } +  case Instruction::BitCast: +    return dyn_cast<Instruction>(IncV->getOperand(0)); +  case Instruction::GetElementPtr: +    for (auto I = IncV->op_begin() + 1, E = IncV->op_end(); I != E; ++I) { +      if (isa<Constant>(*I)) +        continue; +      if (Instruction *OInst = dyn_cast<Instruction>(*I)) { +        if (!SE.DT.dominates(OInst, InsertPos)) +          return nullptr; +      } +      if (allowScale) { +        // allow any kind of GEP as long as it can be hoisted. +        continue; +      } +      // This must be a pointer addition of constants (pretty), which is already +      // handled, or some number of address-size elements (ugly). Ugly geps +      // have 2 operands. i1* is used by the expander to represent an +      // address-size element. +      if (IncV->getNumOperands() != 2) +        return nullptr; +      unsigned AS = cast<PointerType>(IncV->getType())->getAddressSpace(); +      if (IncV->getType() != Type::getInt1PtrTy(SE.getContext(), AS) +          && IncV->getType() != Type::getInt8PtrTy(SE.getContext(), AS)) +        return nullptr; +      break; +    } +    return dyn_cast<Instruction>(IncV->getOperand(0)); +  } +} + +/// If the insert point of the current builder or any of the builders on the +/// stack of saved builders has 'I' as its insert point, update it to point to +/// the instruction after 'I'.  This is intended to be used when the instruction +/// 'I' is being moved.  If this fixup is not done and 'I' is moved to a +/// different block, the inconsistent insert point (with a mismatched +/// Instruction and Block) can lead to an instruction being inserted in a block +/// other than its parent. +void SCEVExpander::fixupInsertPoints(Instruction *I) { +  BasicBlock::iterator It(*I); +  BasicBlock::iterator NewInsertPt = std::next(It); +  if (Builder.GetInsertPoint() == It) +    Builder.SetInsertPoint(&*NewInsertPt); +  for (auto *InsertPtGuard : InsertPointGuards) +    if (InsertPtGuard->GetInsertPoint() == It) +      InsertPtGuard->SetInsertPoint(NewInsertPt); +} + +/// hoistStep - Attempt to hoist a simple IV increment above InsertPos to make +/// it available to other uses in this loop. Recursively hoist any operands, +/// until we reach a value that dominates InsertPos. +bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) { +  if (SE.DT.dominates(IncV, InsertPos)) +      return true; + +  // InsertPos must itself dominate IncV so that IncV's new position satisfies +  // its existing users. +  if (isa<PHINode>(InsertPos) || +      !SE.DT.dominates(InsertPos->getParent(), IncV->getParent())) +    return false; + +  if (!SE.LI.movementPreservesLCSSAForm(IncV, InsertPos)) +    return false; + +  // Check that the chain of IV operands leading back to Phi can be hoisted. 
+  SmallVector<Instruction*, 4> IVIncs; +  for(;;) { +    Instruction *Oper = getIVIncOperand(IncV, InsertPos, /*allowScale*/true); +    if (!Oper) +      return false; +    // IncV is safe to hoist. +    IVIncs.push_back(IncV); +    IncV = Oper; +    if (SE.DT.dominates(IncV, InsertPos)) +      break; +  } +  for (auto I = IVIncs.rbegin(), E = IVIncs.rend(); I != E; ++I) { +    fixupInsertPoints(*I); +    (*I)->moveBefore(InsertPos); +  } +  return true; +} + +/// Determine if this cyclic phi is in a form that would have been generated by +/// LSR. We don't care if the phi was actually expanded in this pass, as long +/// as it is in a low-cost form, for example, no implied multiplication. This +/// should match any patterns generated by getAddRecExprPHILiterally and +/// expandAddtoGEP. +bool SCEVExpander::isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, +                                           const Loop *L) { +  for(Instruction *IVOper = IncV; +      (IVOper = getIVIncOperand(IVOper, L->getLoopPreheader()->getTerminator(), +                                /*allowScale=*/false));) { +    if (IVOper == PN) +      return true; +  } +  return false; +} + +/// expandIVInc - Expand an IV increment at Builder's current InsertPos. +/// Typically this is the LatchBlock terminator or IVIncInsertPos, but we may +/// need to materialize IV increments elsewhere to handle difficult situations. +Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, +                                 Type *ExpandTy, Type *IntTy, +                                 bool useSubtract) { +  Value *IncV; +  // If the PHI is a pointer, use a GEP, otherwise use an add or sub. +  if (ExpandTy->isPointerTy()) { +    PointerType *GEPPtrTy = cast<PointerType>(ExpandTy); +    // If the step isn't constant, don't use an implicitly scaled GEP, because +    // that would require a multiply inside the loop. +    if (!isa<ConstantInt>(StepV)) +      GEPPtrTy = PointerType::get(Type::getInt1Ty(SE.getContext()), +                                  GEPPtrTy->getAddressSpace()); +    IncV = expandAddToGEP(SE.getSCEV(StepV), GEPPtrTy, IntTy, PN); +    if (IncV->getType() != PN->getType()) { +      IncV = Builder.CreateBitCast(IncV, PN->getType()); +      rememberInstruction(IncV); +    } +  } else { +    IncV = useSubtract ? +      Builder.CreateSub(PN, StepV, Twine(IVName) + ".iv.next") : +      Builder.CreateAdd(PN, StepV, Twine(IVName) + ".iv.next"); +    rememberInstruction(IncV); +  } +  return IncV; +} + +/// Hoist the addrec instruction chain rooted in the loop phi above the +/// position. This routine assumes that this is possible (has been checked). +void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist, +                                  Instruction *Pos, PHINode *LoopPhi) { +  do { +    if (DT->dominates(InstToHoist, Pos)) +      break; +    // Make sure the increment is where we want it. But don't move it +    // down past a potential existing post-inc user. +    fixupInsertPoints(InstToHoist); +    InstToHoist->moveBefore(Pos); +    Pos = InstToHoist; +    InstToHoist = cast<Instruction>(InstToHoist->getOperand(0)); +  } while (InstToHoist != LoopPhi); +} + +/// Check whether we can cheaply express the requested SCEV in terms of +/// the available PHI SCEV by truncation and/or inversion of the step. 
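+///
+/// Two hedged examples: a requested i32 {0,+,1} can reuse an existing i64
+/// {0,+,1} phi after truncation, and a requested {100,+,-1} can be rebuilt
+/// from a {0,+,1} phi as 100 - phi, because
+/// Requested->getStart() + (-Requested) == {0,+,1}.  InvertStep (together
+/// with the truncation type recorded by the caller) describes which of the
+/// two rewrites is needed.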
+static bool canBeCheaplyTransformed(ScalarEvolution &SE, +                                    const SCEVAddRecExpr *Phi, +                                    const SCEVAddRecExpr *Requested, +                                    bool &InvertStep) { +  Type *PhiTy = SE.getEffectiveSCEVType(Phi->getType()); +  Type *RequestedTy = SE.getEffectiveSCEVType(Requested->getType()); + +  if (RequestedTy->getIntegerBitWidth() > PhiTy->getIntegerBitWidth()) +    return false; + +  // Try truncate it if necessary. +  Phi = dyn_cast<SCEVAddRecExpr>(SE.getTruncateOrNoop(Phi, RequestedTy)); +  if (!Phi) +    return false; + +  // Check whether truncation will help. +  if (Phi == Requested) { +    InvertStep = false; +    return true; +  } + +  // Check whether inverting will help: {R,+,-1} == R - {0,+,1}. +  if (SE.getAddExpr(Requested->getStart(), +                    SE.getNegativeSCEV(Requested)) == Phi) { +    InvertStep = true; +    return true; +  } + +  return false; +} + +static bool IsIncrementNSW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) { +  if (!isa<IntegerType>(AR->getType())) +    return false; + +  unsigned BitWidth = cast<IntegerType>(AR->getType())->getBitWidth(); +  Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2); +  const SCEV *Step = AR->getStepRecurrence(SE); +  const SCEV *OpAfterExtend = SE.getAddExpr(SE.getSignExtendExpr(Step, WideTy), +                                            SE.getSignExtendExpr(AR, WideTy)); +  const SCEV *ExtendAfterOp = +    SE.getSignExtendExpr(SE.getAddExpr(AR, Step), WideTy); +  return ExtendAfterOp == OpAfterExtend; +} + +static bool IsIncrementNUW(ScalarEvolution &SE, const SCEVAddRecExpr *AR) { +  if (!isa<IntegerType>(AR->getType())) +    return false; + +  unsigned BitWidth = cast<IntegerType>(AR->getType())->getBitWidth(); +  Type *WideTy = IntegerType::get(AR->getType()->getContext(), BitWidth * 2); +  const SCEV *Step = AR->getStepRecurrence(SE); +  const SCEV *OpAfterExtend = SE.getAddExpr(SE.getZeroExtendExpr(Step, WideTy), +                                            SE.getZeroExtendExpr(AR, WideTy)); +  const SCEV *ExtendAfterOp = +    SE.getZeroExtendExpr(SE.getAddExpr(AR, Step), WideTy); +  return ExtendAfterOp == OpAfterExtend; +} + +/// getAddRecExprPHILiterally - Helper for expandAddRecExprLiterally. Expand +/// the base addrec, which is the addrec without any non-loop-dominating +/// values, and return the PHI. +PHINode * +SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, +                                        const Loop *L, +                                        Type *ExpandTy, +                                        Type *IntTy, +                                        Type *&TruncTy, +                                        bool &InvertStep) { +  assert((!IVIncInsertLoop||IVIncInsertPos) && "Uninitialized insert position"); + +  // Reuse a previously-inserted PHI, if present. +  BasicBlock *LatchBlock = L->getLoopLatch(); +  if (LatchBlock) { +    PHINode *AddRecPhiMatch = nullptr; +    Instruction *IncV = nullptr; +    TruncTy = nullptr; +    InvertStep = false; + +    // Only try partially matching scevs that need truncation and/or +    // step-inversion if we know this loop is outside the current loop. 
+    bool TryNonMatchingSCEV = +        IVIncInsertLoop && +        SE.DT.properlyDominates(LatchBlock, IVIncInsertLoop->getHeader()); + +    for (PHINode &PN : L->getHeader()->phis()) { +      if (!SE.isSCEVable(PN.getType())) +        continue; + +      const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); +      if (!PhiSCEV) +        continue; + +      bool IsMatchingSCEV = PhiSCEV == Normalized; +      // We only handle truncation and inversion of phi recurrences for the +      // expanded expression if the expanded expression's loop dominates the +      // loop we insert to. Check now, so we can bail out early. +      if (!IsMatchingSCEV && !TryNonMatchingSCEV) +          continue; + +      // TODO: this possibly can be reworked to avoid this cast at all. +      Instruction *TempIncV = +          dyn_cast<Instruction>(PN.getIncomingValueForBlock(LatchBlock)); +      if (!TempIncV) +        continue; + +      // Check whether we can reuse this PHI node. +      if (LSRMode) { +        if (!isExpandedAddRecExprPHI(&PN, TempIncV, L)) +          continue; +        if (L == IVIncInsertLoop && !hoistIVInc(TempIncV, IVIncInsertPos)) +          continue; +      } else { +        if (!isNormalAddRecExprPHI(&PN, TempIncV, L)) +          continue; +      } + +      // Stop if we have found an exact match SCEV. +      if (IsMatchingSCEV) { +        IncV = TempIncV; +        TruncTy = nullptr; +        InvertStep = false; +        AddRecPhiMatch = &PN; +        break; +      } + +      // Try whether the phi can be translated into the requested form +      // (truncated and/or offset by a constant). +      if ((!TruncTy || InvertStep) && +          canBeCheaplyTransformed(SE, PhiSCEV, Normalized, InvertStep)) { +        // Record the phi node. But don't stop we might find an exact match +        // later. +        AddRecPhiMatch = &PN; +        IncV = TempIncV; +        TruncTy = SE.getEffectiveSCEVType(Normalized->getType()); +      } +    } + +    if (AddRecPhiMatch) { +      // Potentially, move the increment. We have made sure in +      // isExpandedAddRecExprPHI or hoistIVInc that this is possible. +      if (L == IVIncInsertLoop) +        hoistBeforePos(&SE.DT, IncV, IVIncInsertPos, AddRecPhiMatch); + +      // Ok, the add recurrence looks usable. +      // Remember this PHI, even in post-inc mode. +      InsertedValues.insert(AddRecPhiMatch); +      // Remember the increment. +      rememberInstruction(IncV); +      return AddRecPhiMatch; +    } +  } + +  // Save the original insertion point so we can restore it when we're done. +  SCEVInsertPointGuard Guard(Builder, this); + +  // Another AddRec may need to be recursively expanded below. For example, if +  // this AddRec is quadratic, the StepV may itself be an AddRec in this +  // loop. Remove this loop from the PostIncLoops set before expanding such +  // AddRecs. Otherwise, we cannot find a valid position for the step +  // (i.e. StepV can never dominate its loop header).  Ideally, we could do +  // SavedIncLoops.swap(PostIncLoops), but we generally have a single element, +  // so it's not worth implementing SmallPtrSet::swap. +  PostIncLoopSet SavedPostIncLoops = PostIncLoops; +  PostIncLoops.clear(); + +  // Expand code for the start value into the loop preheader. 
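+  // Hedged sketch of the overall shape produced from here on for an addrec
+  // {%start,+,%step}<%loop> (names hypothetical; pointer-typed IVs use a GEP
+  // instead of the add):
+  //
+  //   preheader:
+  //     ; %start materialized here
+  //   header:
+  //     %x.iv = phi [ %start, %preheader ], [ %x.iv.next, %latch ]
+  //     ; %step expanded so that it dominates the header
+  //   latch:
+  //     %x.iv.next = add %x.iv, %step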
+  assert(L->getLoopPreheader() && +         "Can't expand add recurrences without a loop preheader!"); +  Value *StartV = expandCodeFor(Normalized->getStart(), ExpandTy, +                                L->getLoopPreheader()->getTerminator()); + +  // StartV must have been be inserted into L's preheader to dominate the new +  // phi. +  assert(!isa<Instruction>(StartV) || +         SE.DT.properlyDominates(cast<Instruction>(StartV)->getParent(), +                                 L->getHeader())); + +  // Expand code for the step value. Do this before creating the PHI so that PHI +  // reuse code doesn't see an incomplete PHI. +  const SCEV *Step = Normalized->getStepRecurrence(SE); +  // If the stride is negative, insert a sub instead of an add for the increment +  // (unless it's a constant, because subtracts of constants are canonicalized +  // to adds). +  bool useSubtract = !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); +  if (useSubtract) +    Step = SE.getNegativeSCEV(Step); +  // Expand the step somewhere that dominates the loop header. +  Value *StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); + +  // The no-wrap behavior proved by IsIncrement(NUW|NSW) is only applicable if +  // we actually do emit an addition.  It does not apply if we emit a +  // subtraction. +  bool IncrementIsNUW = !useSubtract && IsIncrementNUW(SE, Normalized); +  bool IncrementIsNSW = !useSubtract && IsIncrementNSW(SE, Normalized); + +  // Create the PHI. +  BasicBlock *Header = L->getHeader(); +  Builder.SetInsertPoint(Header, Header->begin()); +  pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); +  PHINode *PN = Builder.CreatePHI(ExpandTy, std::distance(HPB, HPE), +                                  Twine(IVName) + ".iv"); +  rememberInstruction(PN); + +  // Create the step instructions and populate the PHI. +  for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { +    BasicBlock *Pred = *HPI; + +    // Add a start value. +    if (!L->contains(Pred)) { +      PN->addIncoming(StartV, Pred); +      continue; +    } + +    // Create a step value and add it to the PHI. +    // If IVIncInsertLoop is non-null and equal to the addrec's loop, insert the +    // instructions at IVIncInsertPos. +    Instruction *InsertPos = L == IVIncInsertLoop ? +      IVIncInsertPos : Pred->getTerminator(); +    Builder.SetInsertPoint(InsertPos); +    Value *IncV = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); + +    if (isa<OverflowingBinaryOperator>(IncV)) { +      if (IncrementIsNUW) +        cast<BinaryOperator>(IncV)->setHasNoUnsignedWrap(); +      if (IncrementIsNSW) +        cast<BinaryOperator>(IncV)->setHasNoSignedWrap(); +    } +    PN->addIncoming(IncV, Pred); +  } + +  // After expanding subexpressions, restore the PostIncLoops set so the caller +  // can ensure that IVIncrement dominates the current uses. +  PostIncLoops = SavedPostIncLoops; + +  // Remember this PHI, even in post-inc mode. +  InsertedValues.insert(PN); + +  return PN; +} + +Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { +  Type *STy = S->getType(); +  Type *IntTy = SE.getEffectiveSCEVType(STy); +  const Loop *L = S->getLoop(); + +  // Determine a normalized form of this expression, which is the expression +  // before any post-inc adjustment is made. 
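+  // For a loop in PostIncLoops this means (hedged, affine case) rewriting
+  // {S,+,X} as {S - X,+,X}: the post-incremented value of the normalized
+  // recurrence at each iteration is exactly the original {S,+,X}, so
+  // post-inc users can be served by the incremented phi.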
+  const SCEVAddRecExpr *Normalized = S; +  if (PostIncLoops.count(L)) { +    PostIncLoopSet Loops; +    Loops.insert(L); +    Normalized = cast<SCEVAddRecExpr>(normalizeForPostIncUse(S, Loops, SE)); +  } + +  // Strip off any non-loop-dominating component from the addrec start. +  const SCEV *Start = Normalized->getStart(); +  const SCEV *PostLoopOffset = nullptr; +  if (!SE.properlyDominates(Start, L->getHeader())) { +    PostLoopOffset = Start; +    Start = SE.getConstant(Normalized->getType(), 0); +    Normalized = cast<SCEVAddRecExpr>( +      SE.getAddRecExpr(Start, Normalized->getStepRecurrence(SE), +                       Normalized->getLoop(), +                       Normalized->getNoWrapFlags(SCEV::FlagNW))); +  } + +  // Strip off any non-loop-dominating component from the addrec step. +  const SCEV *Step = Normalized->getStepRecurrence(SE); +  const SCEV *PostLoopScale = nullptr; +  if (!SE.dominates(Step, L->getHeader())) { +    PostLoopScale = Step; +    Step = SE.getConstant(Normalized->getType(), 1); +    if (!Start->isZero()) { +        // The normalization below assumes that Start is constant zero, so if +        // it isn't re-associate Start to PostLoopOffset. +        assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?"); +        PostLoopOffset = Start; +        Start = SE.getConstant(Normalized->getType(), 0); +    } +    Normalized = +      cast<SCEVAddRecExpr>(SE.getAddRecExpr( +                             Start, Step, Normalized->getLoop(), +                             Normalized->getNoWrapFlags(SCEV::FlagNW))); +  } + +  // Expand the core addrec. If we need post-loop scaling, force it to +  // expand to an integer type to avoid the need for additional casting. +  Type *ExpandTy = PostLoopScale ? IntTy : STy; +  // We can't use a pointer type for the addrec if the pointer type is +  // non-integral. +  Type *AddRecPHIExpandTy = +      DL.isNonIntegralPointerType(STy) ? Normalized->getType() : ExpandTy; + +  // In some cases, we decide to reuse an existing phi node but need to truncate +  // it and/or invert the step. +  Type *TruncTy = nullptr; +  bool InvertStep = false; +  PHINode *PN = getAddRecExprPHILiterally(Normalized, L, AddRecPHIExpandTy, +                                          IntTy, TruncTy, InvertStep); + +  // Accommodate post-inc mode, if necessary. +  Value *Result; +  if (!PostIncLoops.count(L)) +    Result = PN; +  else { +    // In PostInc mode, use the post-incremented value. +    BasicBlock *LatchBlock = L->getLoopLatch(); +    assert(LatchBlock && "PostInc mode requires a unique loop latch!"); +    Result = PN->getIncomingValueForBlock(LatchBlock); + +    // For an expansion to use the postinc form, the client must call +    // expandCodeFor with an InsertPoint that is either outside the PostIncLoop +    // or dominated by IVIncInsertPos. +    if (isa<Instruction>(Result) && +        !SE.DT.dominates(cast<Instruction>(Result), +                         &*Builder.GetInsertPoint())) { +      // The induction variable's postinc expansion does not dominate this use. +      // IVUsers tries to prevent this case, so it is rare. However, it can +      // happen when an IVUser outside the loop is not dominated by the latch +      // block. Adjusting IVIncInsertPos before expansion begins cannot handle +      // all cases. Consider a phi outside whose operand is replaced during +      // expansion with the value of the postinc user. 
Without fundamentally +      // changing the way postinc users are tracked, the only remedy is +      // inserting an extra IV increment. StepV might fold into PostLoopOffset, +      // but hopefully expandCodeFor handles that. +      bool useSubtract = +        !ExpandTy->isPointerTy() && Step->isNonConstantNegative(); +      if (useSubtract) +        Step = SE.getNegativeSCEV(Step); +      Value *StepV; +      { +        // Expand the step somewhere that dominates the loop header. +        SCEVInsertPointGuard Guard(Builder, this); +        StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); +      } +      Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); +    } +  } + +  // We have decided to reuse an induction variable of a dominating loop. Apply +  // truncation and/or inversion of the step. +  if (TruncTy) { +    Type *ResTy = Result->getType(); +    // Normalize the result type. +    if (ResTy != SE.getEffectiveSCEVType(ResTy)) +      Result = InsertNoopCastOfTo(Result, SE.getEffectiveSCEVType(ResTy)); +    // Truncate the result. +    if (TruncTy != Result->getType()) { +      Result = Builder.CreateTrunc(Result, TruncTy); +      rememberInstruction(Result); +    } +    // Invert the result. +    if (InvertStep) { +      Result = Builder.CreateSub(expandCodeFor(Normalized->getStart(), TruncTy), +                                 Result); +      rememberInstruction(Result); +    } +  } + +  // Re-apply any non-loop-dominating scale. +  if (PostLoopScale) { +    assert(S->isAffine() && "Can't linearly scale non-affine recurrences."); +    Result = InsertNoopCastOfTo(Result, IntTy); +    Result = Builder.CreateMul(Result, +                               expandCodeFor(PostLoopScale, IntTy)); +    rememberInstruction(Result); +  } + +  // Re-apply any non-loop-dominating offset. +  if (PostLoopOffset) { +    if (PointerType *PTy = dyn_cast<PointerType>(ExpandTy)) { +      if (Result->getType()->isIntegerTy()) { +        Value *Base = expandCodeFor(PostLoopOffset, ExpandTy); +        Result = expandAddToGEP(SE.getUnknown(Result), PTy, IntTy, Base); +      } else { +        Result = expandAddToGEP(PostLoopOffset, PTy, IntTy, Result); +      } +    } else { +      Result = InsertNoopCastOfTo(Result, IntTy); +      Result = Builder.CreateAdd(Result, +                                 expandCodeFor(PostLoopOffset, IntTy)); +      rememberInstruction(Result); +    } +  } + +  return Result; +} + +Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { +  if (!CanonicalMode) return expandAddRecExprLiterally(S); + +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); +  const Loop *L = S->getLoop(); + +  // First check for an existing canonical IV in a suitable type. +  PHINode *CanonicalIV = nullptr; +  if (PHINode *PN = L->getCanonicalInductionVariable()) +    if (SE.getTypeSizeInBits(PN->getType()) >= SE.getTypeSizeInBits(Ty)) +      CanonicalIV = PN; + +  // Rewrite an AddRec in terms of the canonical induction variable, if +  // its type is more narrow. 
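+  // E.g. (hedged): an i32 {5,+,3}<%loop> in a loop that already has an i64
+  // canonical IV is any-extended to an i64 {5,+,3}, expanded in terms of
+  // that IV, and the result truncated back to i32 after the expansion point.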
+  if (CanonicalIV && +      SE.getTypeSizeInBits(CanonicalIV->getType()) > +      SE.getTypeSizeInBits(Ty)) { +    SmallVector<const SCEV *, 4> NewOps(S->getNumOperands()); +    for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i) +      NewOps[i] = SE.getAnyExtendExpr(S->op_begin()[i], CanonicalIV->getType()); +    Value *V = expand(SE.getAddRecExpr(NewOps, S->getLoop(), +                                       S->getNoWrapFlags(SCEV::FlagNW))); +    BasicBlock::iterator NewInsertPt = +        findInsertPointAfter(cast<Instruction>(V), Builder.GetInsertBlock()); +    V = expandCodeFor(SE.getTruncateExpr(SE.getUnknown(V), Ty), nullptr, +                      &*NewInsertPt); +    return V; +  } + +  // {X,+,F} --> X + {0,+,F} +  if (!S->getStart()->isZero()) { +    SmallVector<const SCEV *, 4> NewOps(S->op_begin(), S->op_end()); +    NewOps[0] = SE.getConstant(Ty, 0); +    const SCEV *Rest = SE.getAddRecExpr(NewOps, L, +                                        S->getNoWrapFlags(SCEV::FlagNW)); + +    // Turn things like ptrtoint+arithmetic+inttoptr into GEP. See the +    // comments on expandAddToGEP for details. +    const SCEV *Base = S->getStart(); +    // Dig into the expression to find the pointer base for a GEP. +    const SCEV *ExposedRest = Rest; +    ExposePointerBase(Base, ExposedRest, SE); +    // If we found a pointer, expand the AddRec with a GEP. +    if (PointerType *PTy = dyn_cast<PointerType>(Base->getType())) { +      // Make sure the Base isn't something exotic, such as a multiplied +      // or divided pointer value. In those cases, the result type isn't +      // actually a pointer type. +      if (!isa<SCEVMulExpr>(Base) && !isa<SCEVUDivExpr>(Base)) { +        Value *StartV = expand(Base); +        assert(StartV->getType() == PTy && "Pointer type mismatch for GEP!"); +        return expandAddToGEP(ExposedRest, PTy, Ty, StartV); +      } +    } + +    // Just do a normal add. Pre-expand the operands to suppress folding. +    // +    // The LHS and RHS values are factored out of the expand call to make the +    // output independent of the argument evaluation order. +    const SCEV *AddExprLHS = SE.getUnknown(expand(S->getStart())); +    const SCEV *AddExprRHS = SE.getUnknown(expand(Rest)); +    return expand(SE.getAddExpr(AddExprLHS, AddExprRHS)); +  } + +  // If we don't yet have a canonical IV, create one. +  if (!CanonicalIV) { +    // Create and insert the PHI node for the induction variable in the +    // specified loop. +    BasicBlock *Header = L->getHeader(); +    pred_iterator HPB = pred_begin(Header), HPE = pred_end(Header); +    CanonicalIV = PHINode::Create(Ty, std::distance(HPB, HPE), "indvar", +                                  &Header->front()); +    rememberInstruction(CanonicalIV); + +    SmallSet<BasicBlock *, 4> PredSeen; +    Constant *One = ConstantInt::get(Ty, 1); +    for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { +      BasicBlock *HP = *HPI; +      if (!PredSeen.insert(HP).second) { +        // There must be an incoming value for each predecessor, even the +        // duplicates! +        CanonicalIV->addIncoming(CanonicalIV->getIncomingValueForBlock(HP), HP); +        continue; +      } + +      if (L->contains(HP)) { +        // Insert a unit add instruction right before the terminator +        // corresponding to the back-edge. 
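+        // (Hedged sketch) together with the phi created above this yields
+        //   header:   %indvar = phi [ 0, <preds outside L> ],
+        //                           [ %indvar.next, %backedge ]
+        //   backedge: %indvar.next = add %indvar, 1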
+        Instruction *Add = BinaryOperator::CreateAdd(CanonicalIV, One, +                                                     "indvar.next", +                                                     HP->getTerminator()); +        Add->setDebugLoc(HP->getTerminator()->getDebugLoc()); +        rememberInstruction(Add); +        CanonicalIV->addIncoming(Add, HP); +      } else { +        CanonicalIV->addIncoming(Constant::getNullValue(Ty), HP); +      } +    } +  } + +  // {0,+,1} --> Insert a canonical induction variable into the loop! +  if (S->isAffine() && S->getOperand(1)->isOne()) { +    assert(Ty == SE.getEffectiveSCEVType(CanonicalIV->getType()) && +           "IVs with types different from the canonical IV should " +           "already have been handled!"); +    return CanonicalIV; +  } + +  // {0,+,F} --> {0,+,1} * F + +  // If this is a simple linear addrec, emit it now as a special case. +  if (S->isAffine())    // {0,+,F} --> i*F +    return +      expand(SE.getTruncateOrNoop( +        SE.getMulExpr(SE.getUnknown(CanonicalIV), +                      SE.getNoopOrAnyExtend(S->getOperand(1), +                                            CanonicalIV->getType())), +        Ty)); + +  // If this is a chain of recurrences, turn it into a closed form, using the +  // folders, then expandCodeFor the closed form.  This allows the folders to +  // simplify the expression without having to build a bunch of special code +  // into this folder. +  const SCEV *IH = SE.getUnknown(CanonicalIV);   // Get I as a "symbolic" SCEV. + +  // Promote S up to the canonical IV type, if the cast is foldable. +  const SCEV *NewS = S; +  const SCEV *Ext = SE.getNoopOrAnyExtend(S, CanonicalIV->getType()); +  if (isa<SCEVAddRecExpr>(Ext)) +    NewS = Ext; + +  const SCEV *V = cast<SCEVAddRecExpr>(NewS)->evaluateAtIteration(IH, SE); +  //cerr << "Evaluated: " << *this << "\n     to: " << *V << "\n"; + +  // Truncate the result down to the original type, if needed. +  const SCEV *T = SE.getTruncateOrNoop(V, Ty); +  return expand(T); +} + +Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) { +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); +  Value *V = expandCodeFor(S->getOperand(), +                           SE.getEffectiveSCEVType(S->getOperand()->getType())); +  Value *I = Builder.CreateTrunc(V, Ty); +  rememberInstruction(I); +  return I; +} + +Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) { +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); +  Value *V = expandCodeFor(S->getOperand(), +                           SE.getEffectiveSCEVType(S->getOperand()->getType())); +  Value *I = Builder.CreateZExt(V, Ty); +  rememberInstruction(I); +  return I; +} + +Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) { +  Type *Ty = SE.getEffectiveSCEVType(S->getType()); +  Value *V = expandCodeFor(S->getOperand(), +                           SE.getEffectiveSCEVType(S->getOperand()->getType())); +  Value *I = Builder.CreateSExt(V, Ty); +  rememberInstruction(I); +  return I; +} + +Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) { +  Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); +  Type *Ty = LHS->getType(); +  for (int i = S->getNumOperands()-2; i >= 0; --i) { +    // In the case of mixed integer and pointer types, do the +    // rest of the comparisons as integer. 
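+    // Each iteration emits an icmp sgt plus a select, so (hedged example,
+    // names hypothetical) smax(%a, %b, %c) ends up as
+    //   %cmp1 = icmp sgt %c, %b
+    //   %max1 = select %cmp1, %c, %b
+    //   %cmp2 = icmp sgt %max1, %a
+    //   %smax = select %cmp2, %max1, %a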
+    if (S->getOperand(i)->getType() != Ty) { +      Ty = SE.getEffectiveSCEVType(Ty); +      LHS = InsertNoopCastOfTo(LHS, Ty); +    } +    Value *RHS = expandCodeFor(S->getOperand(i), Ty); +    Value *ICmp = Builder.CreateICmpSGT(LHS, RHS); +    rememberInstruction(ICmp); +    Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "smax"); +    rememberInstruction(Sel); +    LHS = Sel; +  } +  // In the case of mixed integer and pointer types, cast the +  // final result back to the pointer type. +  if (LHS->getType() != S->getType()) +    LHS = InsertNoopCastOfTo(LHS, S->getType()); +  return LHS; +} + +Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) { +  Value *LHS = expand(S->getOperand(S->getNumOperands()-1)); +  Type *Ty = LHS->getType(); +  for (int i = S->getNumOperands()-2; i >= 0; --i) { +    // In the case of mixed integer and pointer types, do the +    // rest of the comparisons as integer. +    if (S->getOperand(i)->getType() != Ty) { +      Ty = SE.getEffectiveSCEVType(Ty); +      LHS = InsertNoopCastOfTo(LHS, Ty); +    } +    Value *RHS = expandCodeFor(S->getOperand(i), Ty); +    Value *ICmp = Builder.CreateICmpUGT(LHS, RHS); +    rememberInstruction(ICmp); +    Value *Sel = Builder.CreateSelect(ICmp, LHS, RHS, "umax"); +    rememberInstruction(Sel); +    LHS = Sel; +  } +  // In the case of mixed integer and pointer types, cast the +  // final result back to the pointer type. +  if (LHS->getType() != S->getType()) +    LHS = InsertNoopCastOfTo(LHS, S->getType()); +  return LHS; +} + +Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty, +                                   Instruction *IP) { +  setInsertPoint(IP); +  return expandCodeFor(SH, Ty); +} + +Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) { +  // Expand the code for this SCEV. +  Value *V = expand(SH); +  if (Ty) { +    assert(SE.getTypeSizeInBits(Ty) == SE.getTypeSizeInBits(SH->getType()) && +           "non-trivial casts should be done with the SCEVs directly!"); +    V = InsertNoopCastOfTo(V, Ty); +  } +  return V; +} + +ScalarEvolution::ValueOffsetPair +SCEVExpander::FindValueInExprValueMap(const SCEV *S, +                                      const Instruction *InsertPt) { +  SetVector<ScalarEvolution::ValueOffsetPair> *Set = SE.getSCEVValues(S); +  // If the expansion is not in CanonicalMode, and the SCEV contains any +  // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally. +  if (CanonicalMode || !SE.containsAddRecurrence(S)) { +    // If S is scConstant, it may be worse to reuse an existing Value. +    if (S->getSCEVType() != scConstant && Set) { +      // Choose a Value from the set which dominates the insertPt. +      // insertPt should be inside the Value's parent loop so as not to break +      // the LCSSA form. +      for (auto const &VOPair : *Set) { +        Value *V = VOPair.first; +        ConstantInt *Offset = VOPair.second; +        Instruction *EntInst = nullptr; +        if (V && isa<Instruction>(V) && (EntInst = cast<Instruction>(V)) && +            S->getType() == V->getType() && +            EntInst->getFunction() == InsertPt->getFunction() && +            SE.DT.dominates(EntInst, InsertPt) && +            (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || +             SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) +          return {V, Offset}; +      } +    } +  } +  return {nullptr, nullptr}; +} + +// The expansion of SCEV will either reuse a previous Value in ExprValueMap, +// or expand the SCEV literally. 
Specifically, if the expansion is in LSRMode, +// and the SCEV contains any sub scAddRecExpr type SCEV, it will be expanded +// literally, to prevent LSR's transformed SCEV from being reverted. Otherwise, +// the expansion will try to reuse Value from ExprValueMap, and only when it +// fails, expand the SCEV literally. +Value *SCEVExpander::expand(const SCEV *S) { +  // Compute an insertion point for this SCEV object. Hoist the instructions +  // as far out in the loop nest as possible. +  Instruction *InsertPt = &*Builder.GetInsertPoint(); +  for (Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock());; +       L = L->getParentLoop()) +    if (SE.isLoopInvariant(S, L)) { +      if (!L) break; +      if (BasicBlock *Preheader = L->getLoopPreheader()) +        InsertPt = Preheader->getTerminator(); +      else { +        // LSR sets the insertion point for AddRec start/step values to the +        // block start to simplify value reuse, even though it's an invalid +        // position. SCEVExpander must correct for this in all cases. +        InsertPt = &*L->getHeader()->getFirstInsertionPt(); +      } +    } else { +      // We can move insertion point only if there is no div or rem operations +      // otherwise we are risky to move it over the check for zero denominator. +      auto SafeToHoist = [](const SCEV *S) { +        return !SCEVExprContains(S, [](const SCEV *S) { +                  if (const auto *D = dyn_cast<SCEVUDivExpr>(S)) { +                    if (const auto *SC = dyn_cast<SCEVConstant>(D->getRHS())) +                      // Division by non-zero constants can be hoisted. +                      return SC->getValue()->isZero(); +                    // All other divisions should not be moved as they may be +                    // divisions by zero and should be kept within the +                    // conditions of the surrounding loops that guard their +                    // execution (see PR35406). +                    return true; +                  } +                  return false; +                }); +      }; +      // If the SCEV is computable at this level, insert it into the header +      // after the PHIs (and after any other instructions that we've inserted +      // there) so that it is guaranteed to dominate any user inside the loop. +      if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L) && +          SafeToHoist(S)) +        InsertPt = &*L->getHeader()->getFirstInsertionPt(); +      while (InsertPt->getIterator() != Builder.GetInsertPoint() && +             (isInsertedInstruction(InsertPt) || +              isa<DbgInfoIntrinsic>(InsertPt))) { +        InsertPt = &*std::next(InsertPt->getIterator()); +      } +      break; +    } + +  // Check to see if we already expanded this here. +  auto I = InsertedExpressions.find(std::make_pair(S, InsertPt)); +  if (I != InsertedExpressions.end()) +    return I->second; + +  SCEVInsertPointGuard Guard(Builder, this); +  Builder.SetInsertPoint(InsertPt); + +  // Expand the expression into instructions. 
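+  // Reuse sketch (hedged): FindValueInExprValueMap may return a pair (%v, C)
+  // with SCEV(%v) == S + C; the code below then re-derives S as %v - C,
+  // either with an integer sub or, for pointer values, a GEP with index -C
+  // (going through an i8* bitcast when C is not a multiple of the element
+  // size).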
+  ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, InsertPt); +  Value *V = VO.first; + +  if (!V) +    V = visit(S); +  else if (VO.second) { +    if (PointerType *Vty = dyn_cast<PointerType>(V->getType())) { +      Type *Ety = Vty->getPointerElementType(); +      int64_t Offset = VO.second->getSExtValue(); +      int64_t ESize = SE.getTypeSizeInBits(Ety); +      if ((Offset * 8) % ESize == 0) { +        ConstantInt *Idx = +            ConstantInt::getSigned(VO.second->getType(), -(Offset * 8) / ESize); +        V = Builder.CreateGEP(Ety, V, Idx, "scevgep"); +      } else { +        ConstantInt *Idx = +            ConstantInt::getSigned(VO.second->getType(), -Offset); +        unsigned AS = Vty->getAddressSpace(); +        V = Builder.CreateBitCast(V, Type::getInt8PtrTy(SE.getContext(), AS)); +        V = Builder.CreateGEP(Type::getInt8Ty(SE.getContext()), V, Idx, +                              "uglygep"); +        V = Builder.CreateBitCast(V, Vty); +      } +    } else { +      V = Builder.CreateSub(V, VO.second); +    } +  } +  // Remember the expanded value for this SCEV at this location. +  // +  // This is independent of PostIncLoops. The mapped value simply materializes +  // the expression at this insertion point. If the mapped value happened to be +  // a postinc expansion, it could be reused by a non-postinc user, but only if +  // its insertion point was already at the head of the loop. +  InsertedExpressions[std::make_pair(S, InsertPt)] = V; +  return V; +} + +void SCEVExpander::rememberInstruction(Value *I) { +  if (!PostIncLoops.empty()) +    InsertedPostIncValues.insert(I); +  else +    InsertedValues.insert(I); +} + +/// getOrInsertCanonicalInductionVariable - This method returns the +/// canonical induction variable of the specified type for the specified +/// loop (inserting one if there is none).  A canonical induction variable +/// starts at zero and steps by one on each iteration. +PHINode * +SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L, +                                                    Type *Ty) { +  assert(Ty->isIntegerTy() && "Can only insert integer induction variables!"); + +  // Build a SCEV for {0,+,1}<L>. +  // Conservatively use FlagAnyWrap for now. +  const SCEV *H = SE.getAddRecExpr(SE.getConstant(Ty, 0), +                                   SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap); + +  // Emit code for it. +  SCEVInsertPointGuard Guard(Builder, this); +  PHINode *V = +      cast<PHINode>(expandCodeFor(H, nullptr, &L->getHeader()->front())); + +  return V; +} + +/// replaceCongruentIVs - Check for congruent phis in this loop header and +/// replace them with their most canonical representative. Return the number of +/// phis eliminated. +/// +/// This does not depend on any SCEVExpander state but should be used in +/// the same context that SCEVExpander is used. +unsigned +SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, +                                  SmallVectorImpl<WeakTrackingVH> &DeadInsts, +                                  const TargetTransformInfo *TTI) { +  // Find integer phis in order of increasing width. +  SmallVector<PHINode*, 8> Phis; +  for (PHINode &PN : L->getHeader()->phis()) +    Phis.push_back(&PN); + +  if (TTI) +    llvm::sort(Phis.begin(), Phis.end(), [](Value *LHS, Value *RHS) { +      // Put pointers at the back and make sure pointer < pointer = false. 
+      if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) +        return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy(); +      return RHS->getType()->getPrimitiveSizeInBits() < +             LHS->getType()->getPrimitiveSizeInBits(); +    }); + +  unsigned NumElim = 0; +  DenseMap<const SCEV *, PHINode *> ExprToIVMap; +  // Process phis from wide to narrow. Map wide phis to their truncation +  // so narrow phis can reuse them. +  for (PHINode *Phi : Phis) { +    auto SimplifyPHINode = [&](PHINode *PN) -> Value * { +      if (Value *V = SimplifyInstruction(PN, {DL, &SE.TLI, &SE.DT, &SE.AC})) +        return V; +      if (!SE.isSCEVable(PN->getType())) +        return nullptr; +      auto *Const = dyn_cast<SCEVConstant>(SE.getSCEV(PN)); +      if (!Const) +        return nullptr; +      return Const->getValue(); +    }; + +    // Fold constant phis. They may be congruent to other constant phis and +    // would confuse the logic below that expects proper IVs. +    if (Value *V = SimplifyPHINode(Phi)) { +      if (V->getType() != Phi->getType()) +        continue; +      Phi->replaceAllUsesWith(V); +      DeadInsts.emplace_back(Phi); +      ++NumElim; +      DEBUG_WITH_TYPE(DebugType, dbgs() +                      << "INDVARS: Eliminated constant iv: " << *Phi << '\n'); +      continue; +    } + +    if (!SE.isSCEVable(Phi->getType())) +      continue; + +    PHINode *&OrigPhiRef = ExprToIVMap[SE.getSCEV(Phi)]; +    if (!OrigPhiRef) { +      OrigPhiRef = Phi; +      if (Phi->getType()->isIntegerTy() && TTI && +          TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) { +        // This phi can be freely truncated to the narrowest phi type. Map the +        // truncated expression to it so it will be reused for narrow types. +        const SCEV *TruncExpr = +          SE.getTruncateExpr(SE.getSCEV(Phi), Phis.back()->getType()); +        ExprToIVMap[TruncExpr] = Phi; +      } +      continue; +    } + +    // Replacing a pointer phi with an integer phi or vice-versa doesn't make +    // sense. +    if (OrigPhiRef->getType()->isPointerTy() != Phi->getType()->isPointerTy()) +      continue; + +    if (BasicBlock *LatchBlock = L->getLoopLatch()) { +      Instruction *OrigInc = dyn_cast<Instruction>( +          OrigPhiRef->getIncomingValueForBlock(LatchBlock)); +      Instruction *IsomorphicInc = +          dyn_cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock)); + +      if (OrigInc && IsomorphicInc) { +        // If this phi has the same width but is more canonical, replace the +        // original with it. As part of the "more canonical" determination, +        // respect a prior decision to use an IV chain. +        if (OrigPhiRef->getType() == Phi->getType() && +            !(ChainedPhis.count(Phi) || +              isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) && +            (ChainedPhis.count(Phi) || +             isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { +          std::swap(OrigPhiRef, Phi); +          std::swap(OrigInc, IsomorphicInc); +        } +        // Replacing the congruent phi is sufficient because acyclic +        // redundancy elimination, CSE/GVN, should handle the +        // rest. However, once SCEV proves that a phi is congruent, +        // it's often the head of an IV user cycle that is isomorphic +        // with the original phi. It's worth eagerly cleaning up the +        // common case of a single IV increment so that DeleteDeadPHIs +        // can remove cycles that had postinc uses. 
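+        // E.g. (hedged): if %i and %j are both {0,+,1} phis with increments
+        // %i.next and %j.next, the code below replaces %j.next with %i.next
+        // (truncated or bitcast if the widths differ) and queues %j.next for
+        // deletion; %j itself is replaced further down.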
+        const SCEV *TruncExpr = +            SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); +        if (OrigInc != IsomorphicInc && +            TruncExpr == SE.getSCEV(IsomorphicInc) && +            SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) && +            hoistIVInc(OrigInc, IsomorphicInc)) { +          DEBUG_WITH_TYPE(DebugType, +                          dbgs() << "INDVARS: Eliminated congruent iv.inc: " +                                 << *IsomorphicInc << '\n'); +          Value *NewInc = OrigInc; +          if (OrigInc->getType() != IsomorphicInc->getType()) { +            Instruction *IP = nullptr; +            if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) +              IP = &*PN->getParent()->getFirstInsertionPt(); +            else +              IP = OrigInc->getNextNode(); + +            IRBuilder<> Builder(IP); +            Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); +            NewInc = Builder.CreateTruncOrBitCast( +                OrigInc, IsomorphicInc->getType(), IVName); +          } +          IsomorphicInc->replaceAllUsesWith(NewInc); +          DeadInsts.emplace_back(IsomorphicInc); +        } +      } +    } +    DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv: " +                                      << *Phi << '\n'); +    ++NumElim; +    Value *NewIV = OrigPhiRef; +    if (OrigPhiRef->getType() != Phi->getType()) { +      IRBuilder<> Builder(&*L->getHeader()->getFirstInsertionPt()); +      Builder.SetCurrentDebugLocation(Phi->getDebugLoc()); +      NewIV = Builder.CreateTruncOrBitCast(OrigPhiRef, Phi->getType(), IVName); +    } +    Phi->replaceAllUsesWith(NewIV); +    DeadInsts.emplace_back(Phi); +  } +  return NumElim; +} + +Value *SCEVExpander::getExactExistingExpansion(const SCEV *S, +                                               const Instruction *At, Loop *L) { +  Optional<ScalarEvolution::ValueOffsetPair> VO = +      getRelatedExistingExpansion(S, At, L); +  if (VO && VO.getValue().second == nullptr) +    return VO.getValue().first; +  return nullptr; +} + +Optional<ScalarEvolution::ValueOffsetPair> +SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At, +                                          Loop *L) { +  using namespace llvm::PatternMatch; + +  SmallVector<BasicBlock *, 4> ExitingBlocks; +  L->getExitingBlocks(ExitingBlocks); + +  // Look for suitable value in simple conditions at the loop exits. +  for (BasicBlock *BB : ExitingBlocks) { +    ICmpInst::Predicate Pred; +    Instruction *LHS, *RHS; +    BasicBlock *TrueBB, *FalseBB; + +    if (!match(BB->getTerminator(), +               m_Br(m_ICmp(Pred, m_Instruction(LHS), m_Instruction(RHS)), +                    TrueBB, FalseBB))) +      continue; + +    if (SE.getSCEV(LHS) == S && SE.DT.dominates(LHS, At)) +      return ScalarEvolution::ValueOffsetPair(LHS, nullptr); + +    if (SE.getSCEV(RHS) == S && SE.DT.dominates(RHS, At)) +      return ScalarEvolution::ValueOffsetPair(RHS, nullptr); +  } + +  // Use expand's logic which is used for reusing a previous Value in +  // ExprValueMap. +  ScalarEvolution::ValueOffsetPair VO = FindValueInExprValueMap(S, At); +  if (VO.first) +    return VO; + +  // There is potential to make this significantly smarter, but this simple +  // heuristic already gets some interesting cases. + +  // Can not find suitable value. 
+  return None; +} + +bool SCEVExpander::isHighCostExpansionHelper( +    const SCEV *S, Loop *L, const Instruction *At, +    SmallPtrSetImpl<const SCEV *> &Processed) { + +  // If we can find an existing value for this scev available at the point "At" +  // then consider the expression cheap. +  if (At && getRelatedExistingExpansion(S, At, L)) +    return false; + +  // Zero/One operand expressions +  switch (S->getSCEVType()) { +  case scUnknown: +  case scConstant: +    return false; +  case scTruncate: +    return isHighCostExpansionHelper(cast<SCEVTruncateExpr>(S)->getOperand(), +                                     L, At, Processed); +  case scZeroExtend: +    return isHighCostExpansionHelper(cast<SCEVZeroExtendExpr>(S)->getOperand(), +                                     L, At, Processed); +  case scSignExtend: +    return isHighCostExpansionHelper(cast<SCEVSignExtendExpr>(S)->getOperand(), +                                     L, At, Processed); +  } + +  if (!Processed.insert(S).second) +    return false; + +  if (auto *UDivExpr = dyn_cast<SCEVUDivExpr>(S)) { +    // If the divisor is a power of two and the SCEV type fits in a native +    // integer, consider the division cheap irrespective of whether it occurs in +    // the user code since it can be lowered into a right shift. +    if (auto *SC = dyn_cast<SCEVConstant>(UDivExpr->getRHS())) +      if (SC->getAPInt().isPowerOf2()) { +        const DataLayout &DL = +            L->getHeader()->getParent()->getParent()->getDataLayout(); +        unsigned Width = cast<IntegerType>(UDivExpr->getType())->getBitWidth(); +        return DL.isIllegalInteger(Width); +      } + +    // UDivExpr is very likely a UDiv that ScalarEvolution's HowFarToZero or +    // HowManyLessThans produced to compute a precise expression, rather than a +    // UDiv from the user's code. If we can't find a UDiv in the code with some +    // simple searching, assume the former consider UDivExpr expensive to +    // compute. +    BasicBlock *ExitingBB = L->getExitingBlock(); +    if (!ExitingBB) +      return true; + +    // At the beginning of this function we already tried to find existing value +    // for plain 'S'. Now try to lookup 'S + 1' since it is common pattern +    // involving division. This is just a simple search heuristic. +    if (!At) +      At = &ExitingBB->back(); +    if (!getRelatedExistingExpansion( +            SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), At, L)) +      return true; +  } + +  // HowManyLessThans uses a Max expression whenever the loop is not guarded by +  // the exit condition. +  if (isa<SCEVSMaxExpr>(S) || isa<SCEVUMaxExpr>(S)) +    return true; + +  // Recurse past nary expressions, which commonly occur in the +  // BackedgeTakenCount. They may already exist in program code, and if not, +  // they are not too expensive rematerialize. +  if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(S)) { +    for (auto *Op : NAry->operands()) +      if (isHighCostExpansionHelper(Op, L, At, Processed)) +        return true; +  } + +  // If we haven't recognized an expensive SCEV pattern, assume it's an +  // expression produced by program code. 
+  return false; +} + +Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, +                                            Instruction *IP) { +  assert(IP); +  switch (Pred->getKind()) { +  case SCEVPredicate::P_Union: +    return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP); +  case SCEVPredicate::P_Equal: +    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP); +  case SCEVPredicate::P_Wrap: { +    auto *AddRecPred = cast<SCEVWrapPredicate>(Pred); +    return expandWrapPredicate(AddRecPred, IP); +  } +  } +  llvm_unreachable("Unknown SCEV predicate type"); +} + +Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, +                                          Instruction *IP) { +  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP); +  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP); + +  Builder.SetInsertPoint(IP); +  auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); +  return I; +} + +Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR, +                                           Instruction *Loc, bool Signed) { +  assert(AR->isAffine() && "Cannot generate RT check for " +                           "non-affine expression"); + +  SCEVUnionPredicate Pred; +  const SCEV *ExitCount = +      SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred); + +  assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count"); + +  const SCEV *Step = AR->getStepRecurrence(SE); +  const SCEV *Start = AR->getStart(); + +  Type *ARTy = AR->getType(); +  unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType()); +  unsigned DstBits = SE.getTypeSizeInBits(ARTy); + +  // The expression {Start,+,Step} has nusw/nssw if +  //   Step < 0, Start - |Step| * Backedge <= Start +  //   Step >= 0, Start + |Step| * Backedge > Start +  // and |Step| * Backedge doesn't unsigned overflow. + +  IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits); +  Builder.SetInsertPoint(Loc); +  Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc); + +  IntegerType *Ty = +      IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(ARTy)); +  Type *ARExpandTy = DL.isNonIntegralPointerType(ARTy) ? ARTy : Ty; + +  Value *StepValue = expandCodeFor(Step, Ty, Loc); +  Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc); +  Value *StartValue = expandCodeFor(Start, ARExpandTy, Loc); + +  ConstantInt *Zero = +      ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits)); + +  Builder.SetInsertPoint(Loc); +  // Compute |Step| +  Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero); +  Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue); + +  // Get the backedge taken count and truncate or extended to the AR type. 
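+  // Overall check built from here on (hedged summary), with B the
+  // backedge-taken count and the compares signed or unsigned per `Signed`:
+  //   OfMul = |Step| * B overflows (unsigned)
+  //   Wrap  = Step < 0 ? (Start - |Step| * B > Start)
+  //                    : (Start + |Step| * B < Start)
+  //   Trunc = B does not fit in the AR's bit width and Step != 0
+  //           (checked only when the count type is wider than the AR type)
+  // The returned value is Wrap | Trunc | OfMul, i.e. true when the no-wrap
+  // predicate may be violated at run time.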
+  Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty); +  auto *MulF = Intrinsic::getDeclaration(Loc->getModule(), +                                         Intrinsic::umul_with_overflow, Ty); + +  // Compute |Step| * Backedge +  CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul"); +  Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result"); +  Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow"); + +  // Compute: +  //   Start + |Step| * Backedge < Start +  //   Start - |Step| * Backedge > Start +  Value *Add = nullptr, *Sub = nullptr; +  if (PointerType *ARPtrTy = dyn_cast<PointerType>(ARExpandTy)) { +    const SCEV *MulS = SE.getSCEV(MulV); +    const SCEV *NegMulS = SE.getNegativeSCEV(MulS); +    Add = Builder.CreateBitCast(expandAddToGEP(MulS, ARPtrTy, Ty, StartValue), +                                ARPtrTy); +    Sub = Builder.CreateBitCast( +        expandAddToGEP(NegMulS, ARPtrTy, Ty, StartValue), ARPtrTy); +  } else { +    Add = Builder.CreateAdd(StartValue, MulV); +    Sub = Builder.CreateSub(StartValue, MulV); +  } + +  Value *EndCompareGT = Builder.CreateICmp( +      Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue); + +  Value *EndCompareLT = Builder.CreateICmp( +      Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue); + +  // Select the answer based on the sign of Step. +  Value *EndCheck = +      Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT); + +  // If the backedge taken count type is larger than the AR type, +  // check that we don't drop any bits by truncating it. If we are +  // dropping bits, then we have overflow (unless the step is zero). +  if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { +    auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); +    auto *BackedgeCheck = +        Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, +                           ConstantInt::get(Loc->getContext(), MaxVal)); +    BackedgeCheck = Builder.CreateAnd( +        BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero)); + +    EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck); +  } + +  EndCheck = Builder.CreateOr(EndCheck, OfMul); +  return EndCheck; +} + +Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, +                                         Instruction *IP) { +  const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr()); +  Value *NSSWCheck = nullptr, *NUSWCheck = nullptr; + +  // Add a check for NUSW +  if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW) +    NUSWCheck = generateOverflowCheck(A, IP, false); + +  // Add a check for NSSW +  if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW) +    NSSWCheck = generateOverflowCheck(A, IP, true); + +  if (NUSWCheck && NSSWCheck) +    return Builder.CreateOr(NUSWCheck, NSSWCheck); + +  if (NUSWCheck) +    return NUSWCheck; + +  if (NSSWCheck) +    return NSSWCheck; + +  return ConstantInt::getFalse(IP->getContext()); +} + +Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, +                                          Instruction *IP) { +  auto *BoolType = IntegerType::get(IP->getContext(), 1); +  Value *Check = ConstantInt::getNullValue(BoolType); + +  // Loop over all checks in this set. 
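+  // (Hedged) the union check is simply the OR of all member checks: it
+  // starts from the constant false produced above and is true at run time
+  // iff any member predicate's check fires.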
+  for (auto Pred : Union->getPredicates()) { +    auto *NextCheck = expandCodeForPredicate(Pred, IP); +    Builder.SetInsertPoint(IP); +    Check = Builder.CreateOr(Check, NextCheck); +  } + +  return Check; +} + +namespace { +// Search for a SCEV subexpression that is not safe to expand.  Any expression +// that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely +// UDiv expressions. We don't know if the UDiv is derived from an IR divide +// instruction, but the important thing is that we prove the denominator is +// nonzero before expansion. +// +// IVUsers already checks that IV-derived expressions are safe. So this check is +// only needed when the expression includes some subexpression that is not IV +// derived. +// +// Currently, we only allow division by a nonzero constant here. If this is +// inadequate, we could easily allow division by SCEVUnknown by using +// ValueTracking to check isKnownNonZero(). +// +// We cannot generally expand recurrences unless the step dominates the loop +// header. The expander handles the special case of affine recurrences by +// scaling the recurrence outside the loop, but this technique isn't generally +// applicable. Expanding a nested recurrence outside a loop requires computing +// binomial coefficients. This could be done, but the recurrence has to be in a +// perfectly reduced form, which can't be guaranteed. +struct SCEVFindUnsafe { +  ScalarEvolution &SE; +  bool IsUnsafe; + +  SCEVFindUnsafe(ScalarEvolution &se): SE(se), IsUnsafe(false) {} + +  bool follow(const SCEV *S) { +    if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { +      const SCEVConstant *SC = dyn_cast<SCEVConstant>(D->getRHS()); +      if (!SC || SC->getValue()->isZero()) { +        IsUnsafe = true; +        return false; +      } +    } +    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { +      const SCEV *Step = AR->getStepRecurrence(SE); +      if (!AR->isAffine() && !SE.dominates(Step, AR->getLoop()->getHeader())) { +        IsUnsafe = true; +        return false; +      } +    } +    return true; +  } +  bool isDone() const { return IsUnsafe; } +}; +} + +namespace llvm { +bool isSafeToExpand(const SCEV *S, ScalarEvolution &SE) { +  SCEVFindUnsafe Search(SE); +  visitAll(S, Search); +  return !Search.IsUnsafe; +} + +bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint, +                      ScalarEvolution &SE) { +  return isSafeToExpand(S, SE) && SE.dominates(S, InsertionPoint->getParent()); +} +} diff --git a/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp new file mode 100644 index 000000000000..3740039b8f86 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -0,0 +1,118 @@ +//===- ScalarEvolutionNormalization.cpp - See below -----------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for working with "normalized" expressions. +// See the comments at the top of ScalarEvolutionNormalization.h for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ScalarEvolutionNormalization.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +using namespace llvm; + +/// TransformKind - Different types of transformations that +/// TransformForPostIncUse can do. +enum TransformKind { +  /// Normalize - Normalize according to the given loops. +  Normalize, +  /// Denormalize - Perform the inverse transform on the expression with the +  /// given loop set. +  Denormalize +}; + +namespace { +struct NormalizeDenormalizeRewriter +    : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> { +  const TransformKind Kind; + +  // NB! Pred is a function_ref.  Storing it here is okay only because +  // we're careful about the lifetime of NormalizeDenormalizeRewriter. +  const NormalizePredTy Pred; + +  NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, +                               ScalarEvolution &SE) +      : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind), +        Pred(Pred) {} +  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); +}; +} // namespace + +const SCEV * +NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { +  SmallVector<const SCEV *, 8> Operands; + +  transform(AR->operands(), std::back_inserter(Operands), +            [&](const SCEV *Op) { return visit(Op); }); + +  if (!Pred(AR)) +    return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); + +  // Normalization and denormalization are fancy names for decrementing and +  // incrementing a SCEV expression with respect to a set of loops.  Since +  // Pred(AR) has returned true, we know we need to normalize or denormalize AR +  // with respect to its loop. + +  if (Kind == Denormalize) { +    // Denormalization / "partial increment" is essentially the same as \c +    // SCEVAddRecExpr::getPostIncExpr.  Here we use an explicit loop to make the +    // symmetry with Normalization clear. +    for (int i = 0, e = Operands.size() - 1; i < e; i++) +      Operands[i] = SE.getAddExpr(Operands[i], Operands[i + 1]); +  } else { +    assert(Kind == Normalize && "Only two possibilities!"); + +    // Normalization / "partial decrement" is a bit more subtle.  Since +    // incrementing a SCEV expression (in general) changes the step of the SCEV +    // expression as well, we cannot use the step of the current expression. +    // Instead, we have to use the step of the very expression we're trying to +    // compute! +    // +    // We solve the issue by recursively building up the result, starting from +    // the "least significant" operand in the add recurrence: +    // +    // Base case: +    //   Single operand add recurrence.  It's its own normalization. +    // +    // N-operand case: +    //   {S_{N-1},+,S_{N-2},+,...,+,S_0} = S +    // +    //   Since the step recurrence of S is {S_{N-2},+,...,+,S_0}, we know its +    //   normalization by induction.  We subtract the normalized step +    //   recurrence from S_{N-1} to get the normalization of S. 
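+    // Worked example (hedged): for the affine case {S,+,X} the loop below
+    // produces {S - X,+,X}; running the denormalization loop above on that
+    // result gives {(S - X) + X,+,X} = {S,+,X} again.  For a three-operand
+    // recurrence {A,+,B,+,C}, normalization yields {A - (B - C),+,B - C,+,C}.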
+
+    for (int i = Operands.size() - 2; i >= 0; i--)
+      Operands[i] = SE.getMinusSCEV(Operands[i], Operands[i + 1]);
+  }
+
+  return SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
+}
+
+const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
+                                         const PostIncLoopSet &Loops,
+                                         ScalarEvolution &SE) {
+  auto Pred = [&](const SCEVAddRecExpr *AR) {
+    return Loops.count(AR->getLoop());
+  };
+  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
+}
+
+const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
+                                           ScalarEvolution &SE) {
+  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
+}
+
+const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
+                                           const PostIncLoopSet &Loops,
+                                           ScalarEvolution &SE) {
+  auto Pred = [&](const SCEVAddRecExpr *AR) {
+    return Loops.count(AR->getLoop());
+  };
+  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
+}
diff --git a/contrib/llvm/lib/Analysis/ScopedNoAliasAA.cpp b/contrib/llvm/lib/Analysis/ScopedNoAliasAA.cpp
new file mode 100644
index 000000000000..f12275aff387
--- /dev/null
+++ b/contrib/llvm/lib/Analysis/ScopedNoAliasAA.cpp
@@ -0,0 +1,211 @@
+//===- ScopedNoAliasAA.cpp - Scoped No-Alias Alias Analysis ---------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ScopedNoAlias alias-analysis pass, which implements
+// metadata-based scoped no-alias support.
+//
+// Alias-analysis scopes are defined by an id (which can be a string or some
+// other metadata node), a domain node, and an optional descriptive string.
+// A domain is defined by an id (which can be a string or some other metadata
+// node), and an optional descriptive string.
+//
+// !dom0 =   metadata !{ metadata !"domain of foo()" }
+// !scope1 = metadata !{ metadata !scope1, metadata !dom0, metadata !"scope 1" }
+// !scope2 = metadata !{ metadata !scope2, metadata !dom0, metadata !"scope 2" }
+//
+// Loads and stores can be tagged with an alias-analysis scope, and also with
+// a noalias tag for a specific scope:
+//
+// ... = load %ptr1, !alias.scope !{ !scope1 }
+// ... = load %ptr2, !alias.scope !{ !scope1, !scope2 }, !noalias !{ !scope1 }
+//
+// When evaluating an aliasing query, if one of the instructions is associated
+// with a set of noalias scopes in some domain that is a superset of the alias
+// scopes in that domain of some other instruction, then the two memory
+// accesses are assumed not to alias.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/ScopedNoAliasAA.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+// A handy option for disabling scoped no-alias functionality.
The same effect +// can also be achieved by stripping the associated metadata tags from IR, but +// this option is sometimes more convenient. +static cl::opt<bool> EnableScopedNoAlias("enable-scoped-noalias", +                                         cl::init(true), cl::Hidden); + +namespace { + +/// This is a simple wrapper around an MDNode which provides a higher-level +/// interface by hiding the details of how alias analysis information is encoded +/// in its operands. +class AliasScopeNode { +  const MDNode *Node = nullptr; + +public: +  AliasScopeNode() = default; +  explicit AliasScopeNode(const MDNode *N) : Node(N) {} + +  /// Get the MDNode for this AliasScopeNode. +  const MDNode *getNode() const { return Node; } + +  /// Get the MDNode for this AliasScopeNode's domain. +  const MDNode *getDomain() const { +    if (Node->getNumOperands() < 2) +      return nullptr; +    return dyn_cast_or_null<MDNode>(Node->getOperand(1)); +  } +}; + +} // end anonymous namespace + +AliasResult ScopedNoAliasAAResult::alias(const MemoryLocation &LocA, +                                         const MemoryLocation &LocB) { +  if (!EnableScopedNoAlias) +    return AAResultBase::alias(LocA, LocB); + +  // Get the attached MDNodes. +  const MDNode *AScopes = LocA.AATags.Scope, *BScopes = LocB.AATags.Scope; + +  const MDNode *ANoAlias = LocA.AATags.NoAlias, *BNoAlias = LocB.AATags.NoAlias; + +  if (!mayAliasInScopes(AScopes, BNoAlias)) +    return NoAlias; + +  if (!mayAliasInScopes(BScopes, ANoAlias)) +    return NoAlias; + +  // If they may alias, chain to the next AliasAnalysis. +  return AAResultBase::alias(LocA, LocB); +} + +ModRefInfo ScopedNoAliasAAResult::getModRefInfo(ImmutableCallSite CS, +                                                const MemoryLocation &Loc) { +  if (!EnableScopedNoAlias) +    return AAResultBase::getModRefInfo(CS, Loc); + +  if (!mayAliasInScopes(Loc.AATags.Scope, CS.getInstruction()->getMetadata( +                                              LLVMContext::MD_noalias))) +    return ModRefInfo::NoModRef; + +  if (!mayAliasInScopes( +          CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), +          Loc.AATags.NoAlias)) +    return ModRefInfo::NoModRef; + +  return AAResultBase::getModRefInfo(CS, Loc); +} + +ModRefInfo ScopedNoAliasAAResult::getModRefInfo(ImmutableCallSite CS1, +                                                ImmutableCallSite CS2) { +  if (!EnableScopedNoAlias) +    return AAResultBase::getModRefInfo(CS1, CS2); + +  if (!mayAliasInScopes( +          CS1.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), +          CS2.getInstruction()->getMetadata(LLVMContext::MD_noalias))) +    return ModRefInfo::NoModRef; + +  if (!mayAliasInScopes( +          CS2.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), +          CS1.getInstruction()->getMetadata(LLVMContext::MD_noalias))) +    return ModRefInfo::NoModRef; + +  return AAResultBase::getModRefInfo(CS1, CS2); +} + +static void collectMDInDomain(const MDNode *List, const MDNode *Domain, +                              SmallPtrSetImpl<const MDNode *> &Nodes) { +  for (const MDOperand &MDOp : List->operands()) +    if (const MDNode *MD = dyn_cast<MDNode>(MDOp)) +      if (AliasScopeNode(MD).getDomain() == Domain) +        Nodes.insert(MD); +} + +bool ScopedNoAliasAAResult::mayAliasInScopes(const MDNode *Scopes, +                                             const MDNode *NoAlias) const { +  if (!Scopes || !NoAlias) +    return true; + +  // Collect the set of scope domains 
relevant to the noalias scopes. +  SmallPtrSet<const MDNode *, 16> Domains; +  for (const MDOperand &MDOp : NoAlias->operands()) +    if (const MDNode *NAMD = dyn_cast<MDNode>(MDOp)) +      if (const MDNode *Domain = AliasScopeNode(NAMD).getDomain()) +        Domains.insert(Domain); + +  // We alias unless, for some domain, the set of noalias scopes in that domain +  // is a superset of the set of alias scopes in that domain. +  for (const MDNode *Domain : Domains) { +    SmallPtrSet<const MDNode *, 16> ScopeNodes; +    collectMDInDomain(Scopes, Domain, ScopeNodes); +    if (ScopeNodes.empty()) +      continue; + +    SmallPtrSet<const MDNode *, 16> NANodes; +    collectMDInDomain(NoAlias, Domain, NANodes); + +    // To not alias, all of the nodes in ScopeNodes must be in NANodes. +    bool FoundAll = true; +    for (const MDNode *SMD : ScopeNodes) +      if (!NANodes.count(SMD)) { +        FoundAll = false; +        break; +      } + +    if (FoundAll) +      return false; +  } + +  return true; +} + +AnalysisKey ScopedNoAliasAA::Key; + +ScopedNoAliasAAResult ScopedNoAliasAA::run(Function &F, +                                           FunctionAnalysisManager &AM) { +  return ScopedNoAliasAAResult(); +} + +char ScopedNoAliasAAWrapperPass::ID = 0; + +INITIALIZE_PASS(ScopedNoAliasAAWrapperPass, "scoped-noalias", +                "Scoped NoAlias Alias Analysis", false, true) + +ImmutablePass *llvm::createScopedNoAliasAAWrapperPass() { +  return new ScopedNoAliasAAWrapperPass(); +} + +ScopedNoAliasAAWrapperPass::ScopedNoAliasAAWrapperPass() : ImmutablePass(ID) { +  initializeScopedNoAliasAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ScopedNoAliasAAWrapperPass::doInitialization(Module &M) { +  Result.reset(new ScopedNoAliasAAResult()); +  return false; +} + +bool ScopedNoAliasAAWrapperPass::doFinalization(Module &M) { +  Result.reset(); +  return false; +} + +void ScopedNoAliasAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +} diff --git a/contrib/llvm/lib/Analysis/StratifiedSets.h b/contrib/llvm/lib/Analysis/StratifiedSets.h new file mode 100644 index 000000000000..2f20cd12506c --- /dev/null +++ b/contrib/llvm/lib/Analysis/StratifiedSets.h @@ -0,0 +1,597 @@ +//===- StratifiedSets.h - Abstract stratified sets implementation. --------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_STRATIFIEDSETS_H +#define LLVM_ADT_STRATIFIEDSETS_H + +#include "AliasAnalysisSummary.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include <bitset> +#include <cassert> +#include <cmath> +#include <type_traits> +#include <utility> +#include <vector> + +namespace llvm { +namespace cflaa { +/// An index into Stratified Sets. +typedef unsigned StratifiedIndex; +/// NOTE: ^ This can't be a short -- bootstrapping clang has a case where +/// ~1M sets exist. + +// Container of information related to a value in a StratifiedSet. +struct StratifiedInfo { +  StratifiedIndex Index; +  /// For field sensitivity, etc. we can tack fields on here. +}; + +/// A "link" between two StratifiedSets. +struct StratifiedLink { +  /// This is a value used to signify "does not exist" where the +  /// StratifiedIndex type is used. 
+  /// +  /// This is used instead of Optional<StratifiedIndex> because +  /// Optional<StratifiedIndex> would eat up a considerable amount of extra +  /// memory, after struct padding/alignment is taken into account. +  static const StratifiedIndex SetSentinel; + +  /// The index for the set "above" current +  StratifiedIndex Above; + +  /// The link for the set "below" current +  StratifiedIndex Below; + +  /// Attributes for these StratifiedSets. +  AliasAttrs Attrs; + +  StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {} + +  bool hasBelow() const { return Below != SetSentinel; } +  bool hasAbove() const { return Above != SetSentinel; } + +  void clearBelow() { Below = SetSentinel; } +  void clearAbove() { Above = SetSentinel; } +}; + +/// These are stratified sets, as described in "Fast algorithms for +/// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M +/// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets +/// of Value*s. If two Value*s are in the same set, or if both sets have +/// overlapping attributes, then the Value*s are said to alias. +/// +/// Sets may be related by position, meaning that one set may be considered as +/// above or below another. In CFL Alias Analysis, this gives us an indication +/// of how two variables are related; if the set of variable A is below a set +/// containing variable B, then at some point, a variable that has interacted +/// with B (or B itself) was either used in order to extract the variable A, or +/// was used as storage of variable A. +/// +/// Sets may also have attributes (as noted above). These attributes are +/// generally used for noting whether a variable in the set has interacted with +/// a variable whose origins we don't quite know (i.e. globals/arguments), or if +/// the variable may have had operations performed on it (modified in a function +/// call). All attributes that exist in a set A must exist in all sets marked as +/// below set A. +template <typename T> class StratifiedSets { +public: +  StratifiedSets() = default; +  StratifiedSets(StratifiedSets &&) = default; +  StratifiedSets &operator=(StratifiedSets &&) = default; + +  StratifiedSets(DenseMap<T, StratifiedInfo> Map, +                 std::vector<StratifiedLink> Links) +      : Values(std::move(Map)), Links(std::move(Links)) {} + +  Optional<StratifiedInfo> find(const T &Elem) const { +    auto Iter = Values.find(Elem); +    if (Iter == Values.end()) +      return None; +    return Iter->second; +  } + +  const StratifiedLink &getLink(StratifiedIndex Index) const { +    assert(inbounds(Index)); +    return Links[Index]; +  } + +private: +  DenseMap<T, StratifiedInfo> Values; +  std::vector<StratifiedLink> Links; + +  bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); } +}; + +/// Generic Builder class that produces StratifiedSets instances. +/// +/// The goal of this builder is to efficiently produce correct StratifiedSets +/// instances. To this end, we use a few tricks: +///   > Set chains (A method for linking sets together) +///   > Set remaps (A method for marking a set as an alias [irony?] of another) +/// +/// ==== Set chains ==== +/// This builder has a notion of some value A being above, below, or with some +/// other value B: +///   > The `A above B` relationship implies that there is a reference edge +///   going from A to B. Namely, it notes that A can store anything in B's set. +///   > The `A below B` relationship is the opposite of `A above B`. 
It implies +///   that there's a dereference edge going from A to B. +///   > The `A with B` relationship states that there's an assignment edge going +///   from A to B, and that A and B should be treated as equals. +/// +/// As an example, take the following code snippet: +/// +/// %a = alloca i32, align 4 +/// %ap = alloca i32*, align 8 +/// %app = alloca i32**, align 8 +/// store %a, %ap +/// store %ap, %app +/// %aw = getelementptr %ap, i32 0 +/// +/// Given this, the following relations exist: +///   - %a below %ap & %ap above %a +///   - %ap below %app & %app above %ap +///   - %aw with %ap & %ap with %aw +/// +/// These relations produce the following sets: +///   [{%a}, {%ap, %aw}, {%app}] +/// +/// ...Which state that the only MayAlias relationship in the above program is +/// between %ap and %aw. +/// +/// Because LLVM allows arbitrary casts, code like the following needs to be +/// supported: +///   %ip = alloca i64, align 8 +///   %ipp = alloca i64*, align 8 +///   %i = bitcast i64** ipp to i64 +///   store i64* %ip, i64** %ipp +///   store i64 %i, i64* %ip +/// +/// Which, because %ipp ends up *both* above and below %ip, is fun. +/// +/// This is solved by merging %i and %ipp into a single set (...which is the +/// only way to solve this, since their bit patterns are equivalent). Any sets +/// that ended up in between %i and %ipp at the time of merging (in this case, +/// the set containing %ip) also get conservatively merged into the set of %i +/// and %ipp. In short, the resulting StratifiedSet from the above code would be +/// {%ip, %ipp, %i}. +/// +/// ==== Set remaps ==== +/// More of an implementation detail than anything -- when merging sets, we need +/// to update the numbers of all of the elements mapped to those sets. Rather +/// than doing this at each merge, we note in the BuilderLink structure that a +/// remap has occurred, and use this information so we can defer renumbering set +/// elements until build time. +template <typename T> class StratifiedSetsBuilder { +  /// Represents a Stratified Set, with information about the Stratified +  /// Set above it, the set below it, and whether the current set has been +  /// remapped to another. 
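+  ///
+  /// For instance, if link 3 has been remapped to link 5, and link 5 in turn
+  /// to link 7, a later linksAt(3) follows the chain to 7 and rewrites both 3
+  /// and 5 to remap directly to 7 (see linksAt below).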
+  struct BuilderLink { +    const StratifiedIndex Number; + +    BuilderLink(StratifiedIndex N) : Number(N) { +      Remap = StratifiedLink::SetSentinel; +    } + +    bool hasAbove() const { +      assert(!isRemapped()); +      return Link.hasAbove(); +    } + +    bool hasBelow() const { +      assert(!isRemapped()); +      return Link.hasBelow(); +    } + +    void setBelow(StratifiedIndex I) { +      assert(!isRemapped()); +      Link.Below = I; +    } + +    void setAbove(StratifiedIndex I) { +      assert(!isRemapped()); +      Link.Above = I; +    } + +    void clearBelow() { +      assert(!isRemapped()); +      Link.clearBelow(); +    } + +    void clearAbove() { +      assert(!isRemapped()); +      Link.clearAbove(); +    } + +    StratifiedIndex getBelow() const { +      assert(!isRemapped()); +      assert(hasBelow()); +      return Link.Below; +    } + +    StratifiedIndex getAbove() const { +      assert(!isRemapped()); +      assert(hasAbove()); +      return Link.Above; +    } + +    AliasAttrs getAttrs() { +      assert(!isRemapped()); +      return Link.Attrs; +    } + +    void setAttrs(AliasAttrs Other) { +      assert(!isRemapped()); +      Link.Attrs |= Other; +    } + +    bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; } + +    /// For initial remapping to another set +    void remapTo(StratifiedIndex Other) { +      assert(!isRemapped()); +      Remap = Other; +    } + +    StratifiedIndex getRemapIndex() const { +      assert(isRemapped()); +      return Remap; +    } + +    /// Should only be called when we're already remapped. +    void updateRemap(StratifiedIndex Other) { +      assert(isRemapped()); +      Remap = Other; +    } + +    /// Prefer the above functions to calling things directly on what's returned +    /// from this -- they guard against unexpected calls when the current +    /// BuilderLink is remapped. +    const StratifiedLink &getLink() const { return Link; } + +  private: +    StratifiedLink Link; +    StratifiedIndex Remap; +  }; + +  /// This function performs all of the set unioning/value renumbering +  /// that we've been putting off, and generates a vector<StratifiedLink> that +  /// may be placed in a StratifiedSets instance. +  void finalizeSets(std::vector<StratifiedLink> &StratLinks) { +    DenseMap<StratifiedIndex, StratifiedIndex> Remaps; +    for (auto &Link : Links) { +      if (Link.isRemapped()) +        continue; + +      StratifiedIndex Number = StratLinks.size(); +      Remaps.insert(std::make_pair(Link.Number, Number)); +      StratLinks.push_back(Link.getLink()); +    } + +    for (auto &Link : StratLinks) { +      if (Link.hasAbove()) { +        auto &Above = linksAt(Link.Above); +        auto Iter = Remaps.find(Above.Number); +        assert(Iter != Remaps.end()); +        Link.Above = Iter->second; +      } + +      if (Link.hasBelow()) { +        auto &Below = linksAt(Link.Below); +        auto Iter = Remaps.find(Below.Number); +        assert(Iter != Remaps.end()); +        Link.Below = Iter->second; +      } +    } + +    for (auto &Pair : Values) { +      auto &Info = Pair.second; +      auto &Link = linksAt(Info.Index); +      auto Iter = Remaps.find(Link.Number); +      assert(Iter != Remaps.end()); +      Info.Index = Iter->second; +    } +  } + +  /// There's a guarantee in StratifiedLink where all bits set in a +  /// Link.externals will be set in all Link.externals "below" it. 
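+  ///
+  /// For example, if a set with attributes {X} sits above a set with no
+  /// attributes, which in turn sits above a set with attributes {Y}, then
+  /// after propagateAttrs the three sets carry {X}, {X}, and {X, Y}
+  /// respectively.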
+  static void propagateAttrs(std::vector<StratifiedLink> &Links) { +    const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) { +      const auto *Link = &Links[Idx]; +      while (Link->hasAbove()) { +        Idx = Link->Above; +        Link = &Links[Idx]; +      } +      return Idx; +    }; + +    SmallSet<StratifiedIndex, 16> Visited; +    for (unsigned I = 0, E = Links.size(); I < E; ++I) { +      auto CurrentIndex = getHighestParentAbove(I); +      if (!Visited.insert(CurrentIndex).second) +        continue; + +      while (Links[CurrentIndex].hasBelow()) { +        auto &CurrentBits = Links[CurrentIndex].Attrs; +        auto NextIndex = Links[CurrentIndex].Below; +        auto &NextBits = Links[NextIndex].Attrs; +        NextBits |= CurrentBits; +        CurrentIndex = NextIndex; +      } +    } +  } + +public: +  /// Builds a StratifiedSet from the information we've been given since either +  /// construction or the prior build() call. +  StratifiedSets<T> build() { +    std::vector<StratifiedLink> StratLinks; +    finalizeSets(StratLinks); +    propagateAttrs(StratLinks); +    Links.clear(); +    return StratifiedSets<T>(std::move(Values), std::move(StratLinks)); +  } + +  bool has(const T &Elem) const { return get(Elem).hasValue(); } + +  bool add(const T &Main) { +    if (get(Main).hasValue()) +      return false; + +    auto NewIndex = getNewUnlinkedIndex(); +    return addAtMerging(Main, NewIndex); +  } + +  /// Restructures the stratified sets as necessary to make "ToAdd" in a +  /// set above "Main". There are some cases where this is not possible (see +  /// above), so we merge them such that ToAdd and Main are in the same set. +  bool addAbove(const T &Main, const T &ToAdd) { +    assert(has(Main)); +    auto Index = *indexOf(Main); +    if (!linksAt(Index).hasAbove()) +      addLinkAbove(Index); + +    auto Above = linksAt(Index).getAbove(); +    return addAtMerging(ToAdd, Above); +  } + +  /// Restructures the stratified sets as necessary to make "ToAdd" in a +  /// set below "Main". There are some cases where this is not possible (see +  /// above), so we merge them such that ToAdd and Main are in the same set. +  bool addBelow(const T &Main, const T &ToAdd) { +    assert(has(Main)); +    auto Index = *indexOf(Main); +    if (!linksAt(Index).hasBelow()) +      addLinkBelow(Index); + +    auto Below = linksAt(Index).getBelow(); +    return addAtMerging(ToAdd, Below); +  } + +  bool addWith(const T &Main, const T &ToAdd) { +    assert(has(Main)); +    auto MainIndex = *indexOf(Main); +    return addAtMerging(ToAdd, MainIndex); +  } + +  void noteAttributes(const T &Main, AliasAttrs NewAttrs) { +    assert(has(Main)); +    auto *Info = *get(Main); +    auto &Link = linksAt(Info->Index); +    Link.setAttrs(NewAttrs); +  } + +private: +  DenseMap<T, StratifiedInfo> Values; +  std::vector<BuilderLink> Links; + +  /// Adds the given element at the given index, merging sets if necessary. +  bool addAtMerging(const T &ToAdd, StratifiedIndex Index) { +    StratifiedInfo Info = {Index}; +    auto Pair = Values.insert(std::make_pair(ToAdd, Info)); +    if (Pair.second) +      return true; + +    auto &Iter = Pair.first; +    auto &IterSet = linksAt(Iter->second.Index); +    auto &ReqSet = linksAt(Index); + +    // Failed to add where we wanted to. Merge the sets. +    if (&IterSet != &ReqSet) +      merge(IterSet.Number, ReqSet.Number); + +    return false; +  } + +  /// Gets the BuilderLink at the given index, taking set remapping into +  /// account. 
+  BuilderLink &linksAt(StratifiedIndex Index) { +    auto *Start = &Links[Index]; +    if (!Start->isRemapped()) +      return *Start; + +    auto *Current = Start; +    while (Current->isRemapped()) +      Current = &Links[Current->getRemapIndex()]; + +    auto NewRemap = Current->Number; + +    // Run through everything that has yet to be updated, and update them to +    // remap to NewRemap +    Current = Start; +    while (Current->isRemapped()) { +      auto *Next = &Links[Current->getRemapIndex()]; +      Current->updateRemap(NewRemap); +      Current = Next; +    } + +    return *Current; +  } + +  /// Merges two sets into one another. Assumes that these sets are not +  /// already one in the same. +  void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) { +    assert(inbounds(Idx1) && inbounds(Idx2)); +    assert(&linksAt(Idx1) != &linksAt(Idx2) && +           "Merging a set into itself is not allowed"); + +    // CASE 1: If the set at `Idx1` is above or below `Idx2`, we need to merge +    // both the +    // given sets, and all sets between them, into one. +    if (tryMergeUpwards(Idx1, Idx2)) +      return; + +    if (tryMergeUpwards(Idx2, Idx1)) +      return; + +    // CASE 2: The set at `Idx1` is not in the same chain as the set at `Idx2`. +    // We therefore need to merge the two chains together. +    mergeDirect(Idx1, Idx2); +  } + +  /// Merges two sets assuming that the set at `Idx1` is unreachable from +  /// traversing above or below the set at `Idx2`. +  void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) { +    assert(inbounds(Idx1) && inbounds(Idx2)); + +    auto *LinksInto = &linksAt(Idx1); +    auto *LinksFrom = &linksAt(Idx2); +    // Merging everything above LinksInto then proceeding to merge everything +    // below LinksInto becomes problematic, so we go as far "up" as possible! +    while (LinksInto->hasAbove() && LinksFrom->hasAbove()) { +      LinksInto = &linksAt(LinksInto->getAbove()); +      LinksFrom = &linksAt(LinksFrom->getAbove()); +    } + +    if (LinksFrom->hasAbove()) { +      LinksInto->setAbove(LinksFrom->getAbove()); +      auto &NewAbove = linksAt(LinksInto->getAbove()); +      NewAbove.setBelow(LinksInto->Number); +    } + +    // Merging strategy: +    //  > If neither has links below, stop. +    //  > If only `LinksInto` has links below, stop. +    //  > If only `LinksFrom` has links below, reset `LinksInto.Below` to +    //  match `LinksFrom.Below` +    //  > If both have links above, deal with those next. +    while (LinksInto->hasBelow() && LinksFrom->hasBelow()) { +      auto FromAttrs = LinksFrom->getAttrs(); +      LinksInto->setAttrs(FromAttrs); + +      // Remap needs to happen after getBelow(), but before +      // assignment of LinksFrom +      auto *NewLinksFrom = &linksAt(LinksFrom->getBelow()); +      LinksFrom->remapTo(LinksInto->Number); +      LinksFrom = NewLinksFrom; +      LinksInto = &linksAt(LinksInto->getBelow()); +    } + +    if (LinksFrom->hasBelow()) { +      LinksInto->setBelow(LinksFrom->getBelow()); +      auto &NewBelow = linksAt(LinksInto->getBelow()); +      NewBelow.setAbove(LinksInto->Number); +    } + +    LinksInto->setAttrs(LinksFrom->getAttrs()); +    LinksFrom->remapTo(LinksInto->Number); +  } + +  /// Checks to see if lowerIndex is at a level lower than upperIndex. If so, it +  /// will merge lowerIndex with upperIndex (and all of the sets between) and +  /// return true. Otherwise, it will return false. 
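+  ///
+  /// For example, with a chain Low -> Mid -> Up (from bottom to top), calling
+  /// this with (Low, Up) ORs the attributes of Low and Mid into Up, reattaches
+  /// whatever was below Low directly below Up, and remaps Low and Mid to Up.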
+  bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) { +    assert(inbounds(LowerIndex) && inbounds(UpperIndex)); +    auto *Lower = &linksAt(LowerIndex); +    auto *Upper = &linksAt(UpperIndex); +    if (Lower == Upper) +      return true; + +    SmallVector<BuilderLink *, 8> Found; +    auto *Current = Lower; +    auto Attrs = Current->getAttrs(); +    while (Current->hasAbove() && Current != Upper) { +      Found.push_back(Current); +      Attrs |= Current->getAttrs(); +      Current = &linksAt(Current->getAbove()); +    } + +    if (Current != Upper) +      return false; + +    Upper->setAttrs(Attrs); + +    if (Lower->hasBelow()) { +      auto NewBelowIndex = Lower->getBelow(); +      Upper->setBelow(NewBelowIndex); +      auto &NewBelow = linksAt(NewBelowIndex); +      NewBelow.setAbove(UpperIndex); +    } else { +      Upper->clearBelow(); +    } + +    for (const auto &Ptr : Found) +      Ptr->remapTo(Upper->Number); + +    return true; +  } + +  Optional<const StratifiedInfo *> get(const T &Val) const { +    auto Result = Values.find(Val); +    if (Result == Values.end()) +      return None; +    return &Result->second; +  } + +  Optional<StratifiedInfo *> get(const T &Val) { +    auto Result = Values.find(Val); +    if (Result == Values.end()) +      return None; +    return &Result->second; +  } + +  Optional<StratifiedIndex> indexOf(const T &Val) { +    auto MaybeVal = get(Val); +    if (!MaybeVal.hasValue()) +      return None; +    auto *Info = *MaybeVal; +    auto &Link = linksAt(Info->Index); +    return Link.Number; +  } + +  StratifiedIndex addLinkBelow(StratifiedIndex Set) { +    auto At = addLinks(); +    Links[Set].setBelow(At); +    Links[At].setAbove(Set); +    return At; +  } + +  StratifiedIndex addLinkAbove(StratifiedIndex Set) { +    auto At = addLinks(); +    Links[At].setBelow(Set); +    Links[Set].setAbove(At); +    return At; +  } + +  StratifiedIndex getNewUnlinkedIndex() { return addLinks(); } + +  StratifiedIndex addLinks() { +    auto Link = Links.size(); +    Links.push_back(BuilderLink(Link)); +    return Link; +  } + +  bool inbounds(StratifiedIndex N) const { return N < Links.size(); } +}; +} +} +#endif // LLVM_ADT_STRATIFIEDSETS_H diff --git a/contrib/llvm/lib/Analysis/SyntheticCountsUtils.cpp b/contrib/llvm/lib/Analysis/SyntheticCountsUtils.cpp new file mode 100644 index 000000000000..b085fa274d7f --- /dev/null +++ b/contrib/llvm/lib/Analysis/SyntheticCountsUtils.cpp @@ -0,0 +1,113 @@ +//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines utilities for propagating synthetic counts. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SyntheticCountsUtils.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/CallGraph.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +// Given an SCC, propagate entry counts along the edge of the SCC nodes. 
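+//
+// For instance, suppose an SCC {A, B} has edges A->B with relative frequency
+// 0.5 and B->A with relative frequency 0.25, and the current counts are
+// A = 100 and B = 40.  The additional counts are computed from the existing
+// counts first (B gains 0.5 * 100 = 50, A gains 0.25 * 40 = 10) and only then
+// applied, giving A = 110 and B = 90 regardless of the order in which the
+// nodes are visited.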
+template <typename CallGraphType>
+void SyntheticCountsUtils<CallGraphType>::propagateFromSCC(
+    const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount,
+    AddCountTy AddCount) {
+
+  SmallPtrSet<NodeRef, 8> SCCNodes;
+  SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges;
+
+  for (auto &Node : SCC)
+    SCCNodes.insert(Node);
+
+  // Partition the edges coming out of the SCC into those whose destination is
+  // in the SCC and the rest.
+  for (const auto &Node : SCCNodes) {
+    for (auto &E : children_edges<CallGraphType>(Node)) {
+      if (SCCNodes.count(CGT::edge_dest(E)))
+        SCCEdges.emplace_back(Node, E);
+      else
+        NonSCCEdges.emplace_back(Node, E);
+    }
+  }
+
+  // For nodes in the same SCC, update the counts in two steps:
+  // 1. Compute the additional count for each node by propagating the counts
+  // along all incoming edges to the node that originate from within the same
+  // SCC and summing them up.
+  // 2. Add the additional counts to the nodes in the SCC.
+  // This ensures that the order of traversal of nodes within the SCC doesn't
+  // affect the final result.
+
+  DenseMap<NodeRef, uint64_t> AdditionalCounts;
+  for (auto &E : SCCEdges) {
+    auto OptRelFreq = GetRelBBFreq(E.second);
+    if (!OptRelFreq)
+      continue;
+    Scaled64 RelFreq = OptRelFreq.getValue();
+    auto Caller = E.first;
+    auto Callee = CGT::edge_dest(E.second);
+    RelFreq *= Scaled64(GetCount(Caller), 0);
+    uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
+    AdditionalCounts[Callee] += AdditionalCount;
+  }
+
+  // Update the counts for the nodes in the SCC.
+  for (auto &Entry : AdditionalCounts)
+    AddCount(Entry.first, Entry.second);
+
+  // Now update the counts for nodes outside the SCC.
+  for (auto &E : NonSCCEdges) {
+    auto OptRelFreq = GetRelBBFreq(E.second);
+    if (!OptRelFreq)
+      continue;
+    Scaled64 RelFreq = OptRelFreq.getValue();
+    auto Caller = E.first;
+    auto Callee = CGT::edge_dest(E.second);
+    RelFreq *= Scaled64(GetCount(Caller), 0);
+    AddCount(Callee, RelFreq.toInt<uint64_t>());
+  }
+}
+
+/// Propagate synthetic entry counts on a callgraph \p CG.
+///
+/// This performs a reverse post-order traversal of the callgraph SCCs. For each
+/// SCC, it first propagates the entry counts to the nodes within the SCC
+/// through call edges and updates them in one shot. Then the entry counts are
+/// propagated to nodes outside the SCC. This requires \p GraphTraits
+/// to have a specialization for \p CallGraphType.
+
+template <typename CallGraphType>
+void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG,
+                                                    GetRelBBFreqTy GetRelBBFreq,
+                                                    GetCountTy GetCount,
+                                                    AddCountTy AddCount) {
+  std::vector<SccTy> SCCs;
+
+  // Collect all the SCCs.
+  for (auto I = scc_begin(CG); !I.isAtEnd(); ++I)
+    SCCs.push_back(*I);
+
+  // The callgraph-scc needs to be visited in top-down order for propagation.
+  // The scc iterator returns the scc in bottom-up order, so reverse the SCCs
+  // and call propagateFromSCC.
+  for (auto &SCC : reverse(SCCs)) +    propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount); +} + +template class llvm::SyntheticCountsUtils<const CallGraph *>; diff --git a/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp b/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp new file mode 100644 index 000000000000..102135fbf313 --- /dev/null +++ b/contrib/llvm/lib/Analysis/TargetLibraryInfo.cpp @@ -0,0 +1,1663 @@ +//===-- TargetLibraryInfo.cpp - Runtime library information ----------------==// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the TargetLibraryInfo class. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/ADT/Triple.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/CommandLine.h" +using namespace llvm; + +static cl::opt<TargetLibraryInfoImpl::VectorLibrary> ClVectorLibrary( +    "vector-library", cl::Hidden, cl::desc("Vector functions library"), +    cl::init(TargetLibraryInfoImpl::NoLibrary), +    cl::values(clEnumValN(TargetLibraryInfoImpl::NoLibrary, "none", +                          "No vector functions library"), +               clEnumValN(TargetLibraryInfoImpl::Accelerate, "Accelerate", +                          "Accelerate framework"), +               clEnumValN(TargetLibraryInfoImpl::SVML, "SVML", +                          "Intel SVML library"))); + +StringRef const TargetLibraryInfoImpl::StandardNames[LibFunc::NumLibFuncs] = { +#define TLI_DEFINE_STRING +#include "llvm/Analysis/TargetLibraryInfo.def" +}; + +static bool hasSinCosPiStret(const Triple &T) { +  // Only Darwin variants have _stret versions of combined trig functions. +  if (!T.isOSDarwin()) +    return false; + +  // The ABI is rather complicated on x86, so don't do anything special there. +  if (T.getArch() == Triple::x86) +    return false; + +  if (T.isMacOSX() && T.isMacOSXVersionLT(10, 9)) +    return false; + +  if (T.isiOS() && T.isOSVersionLT(7, 0)) +    return false; + +  return true; +} + +/// Initialize the set of available library functions based on the specified +/// target triple. This should be carefully written so that a missing target +/// triple gets a sane set of defaults. +static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, +                       ArrayRef<StringRef> StandardNames) { +  // Verify that the StandardNames array is in alphabetical order. 
+  assert(std::is_sorted(StandardNames.begin(), StandardNames.end(),
+                        [](StringRef LHS, StringRef RHS) {
+                          return LHS < RHS;
+                        }) &&
+         "TargetLibraryInfoImpl function names must be sorted");
+
+  // Set IO unlocked variants as unavailable.
+  // Set them as available per system below.
+  TLI.setUnavailable(LibFunc_getchar_unlocked);
+  TLI.setUnavailable(LibFunc_putc_unlocked);
+  TLI.setUnavailable(LibFunc_putchar_unlocked);
+  TLI.setUnavailable(LibFunc_fputc_unlocked);
+  TLI.setUnavailable(LibFunc_fgetc_unlocked);
+  TLI.setUnavailable(LibFunc_fread_unlocked);
+  TLI.setUnavailable(LibFunc_fwrite_unlocked);
+  TLI.setUnavailable(LibFunc_fputs_unlocked);
+  TLI.setUnavailable(LibFunc_fgets_unlocked);
+
+  bool ShouldExtI32Param = false, ShouldExtI32Return = false,
+       ShouldSignExtI32Param = false;
+  // PowerPC64, Sparc64, SystemZ need signext/zeroext on i32 parameters and
+  // returns corresponding to C-level ints and unsigned ints.
+  if (T.getArch() == Triple::ppc64 || T.getArch() == Triple::ppc64le ||
+      T.getArch() == Triple::sparcv9 || T.getArch() == Triple::systemz) {
+    ShouldExtI32Param = true;
+    ShouldExtI32Return = true;
+  }
+  // Mips, on the other hand, needs signext on i32 parameters corresponding
+  // to both signed and unsigned ints.
+  if (T.isMIPS()) {
+    ShouldSignExtI32Param = true;
+  }
+  TLI.setShouldExtI32Param(ShouldExtI32Param);
+  TLI.setShouldExtI32Return(ShouldExtI32Return);
+  TLI.setShouldSignExtI32Param(ShouldSignExtI32Param);
+
+  if (T.getArch() == Triple::r600 ||
+      T.getArch() == Triple::amdgcn) {
+    TLI.setUnavailable(LibFunc_ldexp);
+    TLI.setUnavailable(LibFunc_ldexpf);
+    TLI.setUnavailable(LibFunc_ldexpl);
+    TLI.setUnavailable(LibFunc_exp10);
+    TLI.setUnavailable(LibFunc_exp10f);
+    TLI.setUnavailable(LibFunc_exp10l);
+    TLI.setUnavailable(LibFunc_log10);
+    TLI.setUnavailable(LibFunc_log10f);
+    TLI.setUnavailable(LibFunc_log10l);
+  }
+
+  // There are no library implementations of memcpy and memset for AMD GPUs and
+  // these can be difficult to lower in the backend.
+  if (T.getArch() == Triple::r600 ||
+      T.getArch() == Triple::amdgcn) {
+    TLI.setUnavailable(LibFunc_memcpy);
+    TLI.setUnavailable(LibFunc_memset);
+    TLI.setUnavailable(LibFunc_memset_pattern16);
+    return;
+  }
+
+  // memset_pattern16 is only available on iOS 3.0 and Mac OS X 10.5 and later.
+  // All versions of watchOS support it.
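+  // For example, a Darwin triple such as x86_64-apple-macosx10.4 ends up with
+  // memset_pattern16 marked unavailable below, while x86_64-apple-macosx10.6
+  // keeps it available.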
+  if (T.isMacOSX()) { +    // available IO unlocked variants on Mac OS X +    TLI.setAvailable(LibFunc_getc_unlocked); +    TLI.setAvailable(LibFunc_getchar_unlocked); +    TLI.setAvailable(LibFunc_putc_unlocked); +    TLI.setAvailable(LibFunc_putchar_unlocked); + +    if (T.isMacOSXVersionLT(10, 5)) +      TLI.setUnavailable(LibFunc_memset_pattern16); +  } else if (T.isiOS()) { +    if (T.isOSVersionLT(3, 0)) +      TLI.setUnavailable(LibFunc_memset_pattern16); +  } else if (!T.isWatchOS()) { +    TLI.setUnavailable(LibFunc_memset_pattern16); +  } + +  if (!hasSinCosPiStret(T)) { +    TLI.setUnavailable(LibFunc_sinpi); +    TLI.setUnavailable(LibFunc_sinpif); +    TLI.setUnavailable(LibFunc_cospi); +    TLI.setUnavailable(LibFunc_cospif); +    TLI.setUnavailable(LibFunc_sincospi_stret); +    TLI.setUnavailable(LibFunc_sincospif_stret); +  } + +  if (T.isMacOSX() && T.getArch() == Triple::x86 && +      !T.isMacOSXVersionLT(10, 7)) { +    // x86-32 OSX has a scheme where fwrite and fputs (and some other functions +    // we don't care about) have two versions; on recent OSX, the one we want +    // has a $UNIX2003 suffix. The two implementations are identical except +    // for the return value in some edge cases.  However, we don't want to +    // generate code that depends on the old symbols. +    TLI.setAvailableWithName(LibFunc_fwrite, "fwrite$UNIX2003"); +    TLI.setAvailableWithName(LibFunc_fputs, "fputs$UNIX2003"); +  } + +  // iprintf and friends are only available on XCore and TCE. +  if (T.getArch() != Triple::xcore && T.getArch() != Triple::tce) { +    TLI.setUnavailable(LibFunc_iprintf); +    TLI.setUnavailable(LibFunc_siprintf); +    TLI.setUnavailable(LibFunc_fiprintf); +  } + +  if (T.isOSWindows() && !T.isOSCygMing()) { +    // Win32 does not support long double +    TLI.setUnavailable(LibFunc_acosl); +    TLI.setUnavailable(LibFunc_asinl); +    TLI.setUnavailable(LibFunc_atanl); +    TLI.setUnavailable(LibFunc_atan2l); +    TLI.setUnavailable(LibFunc_ceill); +    TLI.setUnavailable(LibFunc_copysignl); +    TLI.setUnavailable(LibFunc_cosl); +    TLI.setUnavailable(LibFunc_coshl); +    TLI.setUnavailable(LibFunc_expl); +    TLI.setUnavailable(LibFunc_fabsf); // Win32 and Win64 both lack fabsf +    TLI.setUnavailable(LibFunc_fabsl); +    TLI.setUnavailable(LibFunc_floorl); +    TLI.setUnavailable(LibFunc_fmaxl); +    TLI.setUnavailable(LibFunc_fminl); +    TLI.setUnavailable(LibFunc_fmodl); +    TLI.setUnavailable(LibFunc_frexpl); +    TLI.setUnavailable(LibFunc_ldexpf); +    TLI.setUnavailable(LibFunc_ldexpl); +    TLI.setUnavailable(LibFunc_logl); +    TLI.setUnavailable(LibFunc_modfl); +    TLI.setUnavailable(LibFunc_powl); +    TLI.setUnavailable(LibFunc_sinl); +    TLI.setUnavailable(LibFunc_sinhl); +    TLI.setUnavailable(LibFunc_sqrtl); +    TLI.setUnavailable(LibFunc_tanl); +    TLI.setUnavailable(LibFunc_tanhl); + +    // Win32 only has C89 math +    TLI.setUnavailable(LibFunc_acosh); +    TLI.setUnavailable(LibFunc_acoshf); +    TLI.setUnavailable(LibFunc_acoshl); +    TLI.setUnavailable(LibFunc_asinh); +    TLI.setUnavailable(LibFunc_asinhf); +    TLI.setUnavailable(LibFunc_asinhl); +    TLI.setUnavailable(LibFunc_atanh); +    TLI.setUnavailable(LibFunc_atanhf); +    TLI.setUnavailable(LibFunc_atanhl); +    TLI.setUnavailable(LibFunc_cabs); +    TLI.setUnavailable(LibFunc_cabsf); +    TLI.setUnavailable(LibFunc_cabsl); +    TLI.setUnavailable(LibFunc_cbrt); +    TLI.setUnavailable(LibFunc_cbrtf); +    TLI.setUnavailable(LibFunc_cbrtl); +    
TLI.setUnavailable(LibFunc_exp2); +    TLI.setUnavailable(LibFunc_exp2f); +    TLI.setUnavailable(LibFunc_exp2l); +    TLI.setUnavailable(LibFunc_expm1); +    TLI.setUnavailable(LibFunc_expm1f); +    TLI.setUnavailable(LibFunc_expm1l); +    TLI.setUnavailable(LibFunc_log2); +    TLI.setUnavailable(LibFunc_log2f); +    TLI.setUnavailable(LibFunc_log2l); +    TLI.setUnavailable(LibFunc_log1p); +    TLI.setUnavailable(LibFunc_log1pf); +    TLI.setUnavailable(LibFunc_log1pl); +    TLI.setUnavailable(LibFunc_logb); +    TLI.setUnavailable(LibFunc_logbf); +    TLI.setUnavailable(LibFunc_logbl); +    TLI.setUnavailable(LibFunc_nearbyint); +    TLI.setUnavailable(LibFunc_nearbyintf); +    TLI.setUnavailable(LibFunc_nearbyintl); +    TLI.setUnavailable(LibFunc_rint); +    TLI.setUnavailable(LibFunc_rintf); +    TLI.setUnavailable(LibFunc_rintl); +    TLI.setUnavailable(LibFunc_round); +    TLI.setUnavailable(LibFunc_roundf); +    TLI.setUnavailable(LibFunc_roundl); +    TLI.setUnavailable(LibFunc_trunc); +    TLI.setUnavailable(LibFunc_truncf); +    TLI.setUnavailable(LibFunc_truncl); + +    // Win32 provides some C99 math with mangled names +    TLI.setAvailableWithName(LibFunc_copysign, "_copysign"); + +    if (T.getArch() == Triple::x86) { +      // Win32 on x86 implements single-precision math functions as macros +      TLI.setUnavailable(LibFunc_acosf); +      TLI.setUnavailable(LibFunc_asinf); +      TLI.setUnavailable(LibFunc_atanf); +      TLI.setUnavailable(LibFunc_atan2f); +      TLI.setUnavailable(LibFunc_ceilf); +      TLI.setUnavailable(LibFunc_copysignf); +      TLI.setUnavailable(LibFunc_cosf); +      TLI.setUnavailable(LibFunc_coshf); +      TLI.setUnavailable(LibFunc_expf); +      TLI.setUnavailable(LibFunc_floorf); +      TLI.setUnavailable(LibFunc_fminf); +      TLI.setUnavailable(LibFunc_fmaxf); +      TLI.setUnavailable(LibFunc_fmodf); +      TLI.setUnavailable(LibFunc_logf); +      TLI.setUnavailable(LibFunc_log10f); +      TLI.setUnavailable(LibFunc_modff); +      TLI.setUnavailable(LibFunc_powf); +      TLI.setUnavailable(LibFunc_sinf); +      TLI.setUnavailable(LibFunc_sinhf); +      TLI.setUnavailable(LibFunc_sqrtf); +      TLI.setUnavailable(LibFunc_tanf); +      TLI.setUnavailable(LibFunc_tanhf); +    } + +    // Win32 does *not* provide these functions, but they are +    // generally available on POSIX-compliant systems: +    TLI.setUnavailable(LibFunc_access); +    TLI.setUnavailable(LibFunc_bcmp); +    TLI.setUnavailable(LibFunc_bcopy); +    TLI.setUnavailable(LibFunc_bzero); +    TLI.setUnavailable(LibFunc_chmod); +    TLI.setUnavailable(LibFunc_chown); +    TLI.setUnavailable(LibFunc_closedir); +    TLI.setUnavailable(LibFunc_ctermid); +    TLI.setUnavailable(LibFunc_fdopen); +    TLI.setUnavailable(LibFunc_ffs); +    TLI.setUnavailable(LibFunc_fileno); +    TLI.setUnavailable(LibFunc_flockfile); +    TLI.setUnavailable(LibFunc_fseeko); +    TLI.setUnavailable(LibFunc_fstat); +    TLI.setUnavailable(LibFunc_fstatvfs); +    TLI.setUnavailable(LibFunc_ftello); +    TLI.setUnavailable(LibFunc_ftrylockfile); +    TLI.setUnavailable(LibFunc_funlockfile); +    TLI.setUnavailable(LibFunc_getitimer); +    TLI.setUnavailable(LibFunc_getlogin_r); +    TLI.setUnavailable(LibFunc_getpwnam); +    TLI.setUnavailable(LibFunc_gettimeofday); +    TLI.setUnavailable(LibFunc_htonl); +    TLI.setUnavailable(LibFunc_htons); +    TLI.setUnavailable(LibFunc_lchown); +    TLI.setUnavailable(LibFunc_lstat); +    TLI.setUnavailable(LibFunc_memccpy); +    TLI.setUnavailable(LibFunc_mkdir); +  
  TLI.setUnavailable(LibFunc_ntohl);
+    TLI.setUnavailable(LibFunc_ntohs);
+    TLI.setUnavailable(LibFunc_open);
+    TLI.setUnavailable(LibFunc_opendir);
+    TLI.setUnavailable(LibFunc_pclose);
+    TLI.setUnavailable(LibFunc_popen);
+    TLI.setUnavailable(LibFunc_pread);
+    TLI.setUnavailable(LibFunc_pwrite);
+    TLI.setUnavailable(LibFunc_read);
+    TLI.setUnavailable(LibFunc_readlink);
+    TLI.setUnavailable(LibFunc_realpath);
+    TLI.setUnavailable(LibFunc_rmdir);
+    TLI.setUnavailable(LibFunc_setitimer);
+    TLI.setUnavailable(LibFunc_stat);
+    TLI.setUnavailable(LibFunc_statvfs);
+    TLI.setUnavailable(LibFunc_stpcpy);
+    TLI.setUnavailable(LibFunc_stpncpy);
+    TLI.setUnavailable(LibFunc_strcasecmp);
+    TLI.setUnavailable(LibFunc_strncasecmp);
+    TLI.setUnavailable(LibFunc_times);
+    TLI.setUnavailable(LibFunc_uname);
+    TLI.setUnavailable(LibFunc_unlink);
+    TLI.setUnavailable(LibFunc_unsetenv);
+    TLI.setUnavailable(LibFunc_utime);
+    TLI.setUnavailable(LibFunc_utimes);
+    TLI.setUnavailable(LibFunc_write);
+
+    // Win32 does *not* provide these functions, but they are
+    // specified by C99:
+    TLI.setUnavailable(LibFunc_atoll);
+    TLI.setUnavailable(LibFunc_frexpf);
+    TLI.setUnavailable(LibFunc_llabs);
+  }
+
+  switch (T.getOS()) {
+  case Triple::MacOSX:
+    // exp10 and exp10f are not available on OS X until 10.9 and iOS until 7.0
+    // and their names are __exp10 and __exp10f. exp10l is not available on
+    // OS X or iOS.
+    TLI.setUnavailable(LibFunc_exp10l);
+    if (T.isMacOSXVersionLT(10, 9)) {
+      TLI.setUnavailable(LibFunc_exp10);
+      TLI.setUnavailable(LibFunc_exp10f);
+    } else {
+      TLI.setAvailableWithName(LibFunc_exp10, "__exp10");
+      TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f");
+    }
+    break;
+  case Triple::IOS:
+  case Triple::TvOS:
+  case Triple::WatchOS:
+    TLI.setUnavailable(LibFunc_exp10l);
+    if (!T.isWatchOS() && (T.isOSVersionLT(7, 0) ||
+                           (T.isOSVersionLT(9, 0) &&
+                            (T.getArch() == Triple::x86 ||
+                             T.getArch() == Triple::x86_64)))) {
+      TLI.setUnavailable(LibFunc_exp10);
+      TLI.setUnavailable(LibFunc_exp10f);
+    } else {
+      TLI.setAvailableWithName(LibFunc_exp10, "__exp10");
+      TLI.setAvailableWithName(LibFunc_exp10f, "__exp10f");
+    }
+    break;
+  case Triple::Linux:
+    // exp10, exp10f, and exp10l are available on Linux (GLIBC) but are
+    // extremely buggy prior to glibc version 2.18. Until this version is
+    // widely deployed or we have a reasonable detection strategy, we cannot
+    // use exp10 reliably on Linux.
+    //
+    // Fall through to disable all of them.
+    LLVM_FALLTHROUGH; +  default: +    TLI.setUnavailable(LibFunc_exp10); +    TLI.setUnavailable(LibFunc_exp10f); +    TLI.setUnavailable(LibFunc_exp10l); +  } + +  // ffsl is available on at least Darwin, Mac OS X, iOS, FreeBSD, and +  // Linux (GLIBC): +  // http://developer.apple.com/library/mac/#documentation/Darwin/Reference/ManPages/man3/ffsl.3.html +  // http://svn.freebsd.org/base/head/lib/libc/string/ffsl.c +  // http://www.gnu.org/software/gnulib/manual/html_node/ffsl.html +  switch (T.getOS()) { +  case Triple::Darwin: +  case Triple::MacOSX: +  case Triple::IOS: +  case Triple::TvOS: +  case Triple::WatchOS: +  case Triple::FreeBSD: +  case Triple::Linux: +    break; +  default: +    TLI.setUnavailable(LibFunc_ffsl); +  } + +  // ffsll is available on at least FreeBSD and Linux (GLIBC): +  // http://svn.freebsd.org/base/head/lib/libc/string/ffsll.c +  // http://www.gnu.org/software/gnulib/manual/html_node/ffsll.html +  switch (T.getOS()) { +  case Triple::Darwin: +  case Triple::MacOSX: +  case Triple::IOS: +  case Triple::TvOS: +  case Triple::WatchOS: +  case Triple::FreeBSD: +  case Triple::Linux: +    break; +  default: +    TLI.setUnavailable(LibFunc_ffsll); +  } + +  // The following functions are available on at least FreeBSD: +  // http://svn.freebsd.org/base/head/lib/libc/string/fls.c +  // http://svn.freebsd.org/base/head/lib/libc/string/flsl.c +  // http://svn.freebsd.org/base/head/lib/libc/string/flsll.c +  if (!T.isOSFreeBSD()) { +    TLI.setUnavailable(LibFunc_fls); +    TLI.setUnavailable(LibFunc_flsl); +    TLI.setUnavailable(LibFunc_flsll); +  } + +  // The following functions are available on Linux, +  // but Android uses bionic instead of glibc. +  if (!T.isOSLinux() || T.isAndroid()) { +    TLI.setUnavailable(LibFunc_dunder_strdup); +    TLI.setUnavailable(LibFunc_dunder_strtok_r); +    TLI.setUnavailable(LibFunc_dunder_isoc99_scanf); +    TLI.setUnavailable(LibFunc_dunder_isoc99_sscanf); +    TLI.setUnavailable(LibFunc_under_IO_getc); +    TLI.setUnavailable(LibFunc_under_IO_putc); +    // But, Android has memalign. +    if (!T.isAndroid()) +      TLI.setUnavailable(LibFunc_memalign); +    TLI.setUnavailable(LibFunc_fopen64); +    TLI.setUnavailable(LibFunc_fseeko64); +    TLI.setUnavailable(LibFunc_fstat64); +    TLI.setUnavailable(LibFunc_fstatvfs64); +    TLI.setUnavailable(LibFunc_ftello64); +    TLI.setUnavailable(LibFunc_lstat64); +    TLI.setUnavailable(LibFunc_open64); +    TLI.setUnavailable(LibFunc_stat64); +    TLI.setUnavailable(LibFunc_statvfs64); +    TLI.setUnavailable(LibFunc_tmpfile64); + +    // Relaxed math functions are included in math-finite.h on Linux (GLIBC). 
+    TLI.setUnavailable(LibFunc_acos_finite); +    TLI.setUnavailable(LibFunc_acosf_finite); +    TLI.setUnavailable(LibFunc_acosl_finite); +    TLI.setUnavailable(LibFunc_acosh_finite); +    TLI.setUnavailable(LibFunc_acoshf_finite); +    TLI.setUnavailable(LibFunc_acoshl_finite); +    TLI.setUnavailable(LibFunc_asin_finite); +    TLI.setUnavailable(LibFunc_asinf_finite); +    TLI.setUnavailable(LibFunc_asinl_finite); +    TLI.setUnavailable(LibFunc_atan2_finite); +    TLI.setUnavailable(LibFunc_atan2f_finite); +    TLI.setUnavailable(LibFunc_atan2l_finite); +    TLI.setUnavailable(LibFunc_atanh_finite); +    TLI.setUnavailable(LibFunc_atanhf_finite); +    TLI.setUnavailable(LibFunc_atanhl_finite); +    TLI.setUnavailable(LibFunc_cosh_finite); +    TLI.setUnavailable(LibFunc_coshf_finite); +    TLI.setUnavailable(LibFunc_coshl_finite); +    TLI.setUnavailable(LibFunc_exp10_finite); +    TLI.setUnavailable(LibFunc_exp10f_finite); +    TLI.setUnavailable(LibFunc_exp10l_finite); +    TLI.setUnavailable(LibFunc_exp2_finite); +    TLI.setUnavailable(LibFunc_exp2f_finite); +    TLI.setUnavailable(LibFunc_exp2l_finite); +    TLI.setUnavailable(LibFunc_exp_finite); +    TLI.setUnavailable(LibFunc_expf_finite); +    TLI.setUnavailable(LibFunc_expl_finite); +    TLI.setUnavailable(LibFunc_log10_finite); +    TLI.setUnavailable(LibFunc_log10f_finite); +    TLI.setUnavailable(LibFunc_log10l_finite); +    TLI.setUnavailable(LibFunc_log2_finite); +    TLI.setUnavailable(LibFunc_log2f_finite); +    TLI.setUnavailable(LibFunc_log2l_finite); +    TLI.setUnavailable(LibFunc_log_finite); +    TLI.setUnavailable(LibFunc_logf_finite); +    TLI.setUnavailable(LibFunc_logl_finite); +    TLI.setUnavailable(LibFunc_pow_finite); +    TLI.setUnavailable(LibFunc_powf_finite); +    TLI.setUnavailable(LibFunc_powl_finite); +    TLI.setUnavailable(LibFunc_sinh_finite); +    TLI.setUnavailable(LibFunc_sinhf_finite); +    TLI.setUnavailable(LibFunc_sinhl_finite); +  } + +  if ((T.isOSLinux() && T.isGNUEnvironment()) || +      (T.isAndroid() && !T.isAndroidVersionLT(28))) { +    // available IO unlocked variants on GNU/Linux and Android P or later +    TLI.setAvailable(LibFunc_getc_unlocked); +    TLI.setAvailable(LibFunc_getchar_unlocked); +    TLI.setAvailable(LibFunc_putc_unlocked); +    TLI.setAvailable(LibFunc_putchar_unlocked); +    TLI.setAvailable(LibFunc_fputc_unlocked); +    TLI.setAvailable(LibFunc_fgetc_unlocked); +    TLI.setAvailable(LibFunc_fread_unlocked); +    TLI.setAvailable(LibFunc_fwrite_unlocked); +    TLI.setAvailable(LibFunc_fputs_unlocked); +    TLI.setAvailable(LibFunc_fgets_unlocked); +  } + +  // As currently implemented in clang, NVPTX code has no standard library to +  // speak of.  Headers provide a standard-ish library implementation, but many +  // of the signatures are wrong -- for example, many libm functions are not +  // extern "C". +  // +  // libdevice, an IR library provided by nvidia, is linked in by the front-end, +  // but only used functions are provided to llvm.  Moreover, most of the +  // functions in libdevice don't map precisely to standard library functions. +  // +  // FIXME: Having no standard library prevents e.g. many fastmath +  // optimizations, so this situation should be fixed. 
+  if (T.isNVPTX()) { +    TLI.disableAllFunctions(); +    TLI.setAvailable(LibFunc_nvvm_reflect); +  } else { +    TLI.setUnavailable(LibFunc_nvvm_reflect); +  } + +  TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary); +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl() { +  // Default to everything being available. +  memset(AvailableArray, -1, sizeof(AvailableArray)); + +  initialize(*this, Triple(), StandardNames); +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl(const Triple &T) { +  // Default to everything being available. +  memset(AvailableArray, -1, sizeof(AvailableArray)); + +  initialize(*this, T, StandardNames); +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl(const TargetLibraryInfoImpl &TLI) +    : CustomNames(TLI.CustomNames), ShouldExtI32Param(TLI.ShouldExtI32Param), +      ShouldExtI32Return(TLI.ShouldExtI32Return), +      ShouldSignExtI32Param(TLI.ShouldSignExtI32Param) { +  memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray)); +  VectorDescs = TLI.VectorDescs; +  ScalarDescs = TLI.ScalarDescs; +} + +TargetLibraryInfoImpl::TargetLibraryInfoImpl(TargetLibraryInfoImpl &&TLI) +    : CustomNames(std::move(TLI.CustomNames)), +      ShouldExtI32Param(TLI.ShouldExtI32Param), +      ShouldExtI32Return(TLI.ShouldExtI32Return), +      ShouldSignExtI32Param(TLI.ShouldSignExtI32Param) { +  std::move(std::begin(TLI.AvailableArray), std::end(TLI.AvailableArray), +            AvailableArray); +  VectorDescs = TLI.VectorDescs; +  ScalarDescs = TLI.ScalarDescs; +} + +TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(const TargetLibraryInfoImpl &TLI) { +  CustomNames = TLI.CustomNames; +  ShouldExtI32Param = TLI.ShouldExtI32Param; +  ShouldExtI32Return = TLI.ShouldExtI32Return; +  ShouldSignExtI32Param = TLI.ShouldSignExtI32Param; +  memcpy(AvailableArray, TLI.AvailableArray, sizeof(AvailableArray)); +  return *this; +} + +TargetLibraryInfoImpl &TargetLibraryInfoImpl::operator=(TargetLibraryInfoImpl &&TLI) { +  CustomNames = std::move(TLI.CustomNames); +  ShouldExtI32Param = TLI.ShouldExtI32Param; +  ShouldExtI32Return = TLI.ShouldExtI32Return; +  ShouldSignExtI32Param = TLI.ShouldSignExtI32Param; +  std::move(std::begin(TLI.AvailableArray), std::end(TLI.AvailableArray), +            AvailableArray); +  return *this; +} + +static StringRef sanitizeFunctionName(StringRef funcName) { +  // Filter out empty names and names containing null bytes, those can't be in +  // our table. +  if (funcName.empty() || funcName.find('\0') != StringRef::npos) +    return StringRef(); + +  // Check for \01 prefix that is used to mangle __asm declarations and +  // strip it if present. 
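+  // (For example, a name of the form "\01foo" is looked up as just "foo".)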
+  return GlobalValue::dropLLVMManglingEscape(funcName); +} + +bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, +                                       LibFunc &F) const { +  StringRef const *Start = &StandardNames[0]; +  StringRef const *End = &StandardNames[NumLibFuncs]; + +  funcName = sanitizeFunctionName(funcName); +  if (funcName.empty()) +    return false; + +  StringRef const *I = std::lower_bound( +      Start, End, funcName, [](StringRef LHS, StringRef RHS) { +        return LHS < RHS; +      }); +  if (I != End && *I == funcName) { +    F = (LibFunc)(I - Start); +    return true; +  } +  return false; +} + +bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, +                                                   LibFunc F, +                                                   const DataLayout *DL) const { +  LLVMContext &Ctx = FTy.getContext(); +  Type *PCharTy = Type::getInt8PtrTy(Ctx); +  Type *SizeTTy = DL ? DL->getIntPtrType(Ctx, /*AS=*/0) : nullptr; +  auto IsSizeTTy = [SizeTTy](Type *Ty) { +    return SizeTTy ? Ty == SizeTTy : Ty->isIntegerTy(); +  }; +  unsigned NumParams = FTy.getNumParams(); + +  switch (F) { +  case LibFunc_strlen: +    return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && +            FTy.getReturnType()->isIntegerTy()); + +  case LibFunc_strchr: +  case LibFunc_strrchr: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0) == FTy.getReturnType() && +            FTy.getParamType(1)->isIntegerTy()); + +  case LibFunc_strtol: +  case LibFunc_strtod: +  case LibFunc_strtof: +  case LibFunc_strtoul: +  case LibFunc_strtoll: +  case LibFunc_strtold: +  case LibFunc_strtoull: +    return ((NumParams == 2 || NumParams == 3) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_strcat: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0) == FTy.getReturnType() && +            FTy.getParamType(1) == FTy.getReturnType()); + +  case LibFunc_strncat: +    return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0) == FTy.getReturnType() && +            FTy.getParamType(1) == FTy.getReturnType() && +            IsSizeTTy(FTy.getParamType(2))); + +  case LibFunc_strcpy_chk: +  case LibFunc_stpcpy_chk: +    --NumParams; +    if (!IsSizeTTy(FTy.getParamType(NumParams))) +      return false; +    LLVM_FALLTHROUGH; +  case LibFunc_strcpy: +  case LibFunc_stpcpy: +    return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(0) == FTy.getParamType(1) && +            FTy.getParamType(0) == PCharTy); + +  case LibFunc_strncpy_chk: +  case LibFunc_stpncpy_chk: +    --NumParams; +    if (!IsSizeTTy(FTy.getParamType(NumParams))) +      return false; +    LLVM_FALLTHROUGH; +  case LibFunc_strncpy: +  case LibFunc_stpncpy: +    return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(0) == FTy.getParamType(1) && +            FTy.getParamType(0) == PCharTy && +            IsSizeTTy(FTy.getParamType(2))); + +  case LibFunc_strxfrm: +    return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); + +  case LibFunc_strcmp: +    return (NumParams == 2 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(0) == FTy.getParamType(1)); + +  
case LibFunc_strncmp: +    return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(0) == FTy.getParamType(1) && +            IsSizeTTy(FTy.getParamType(2))); + +  case LibFunc_strspn: +  case LibFunc_strcspn: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(0) == FTy.getParamType(1) && +            FTy.getReturnType()->isIntegerTy()); + +  case LibFunc_strcoll: +  case LibFunc_strcasecmp: +  case LibFunc_strncasecmp: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); + +  case LibFunc_strstr: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); + +  case LibFunc_strpbrk: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(0) == FTy.getParamType(1)); + +  case LibFunc_strtok: +  case LibFunc_strtok_r: +    return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_scanf: +  case LibFunc_setbuf: +  case LibFunc_setvbuf: +    return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_strdup: +  case LibFunc_strndup: +    return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy()); +  case LibFunc_sscanf: +  case LibFunc_stat: +  case LibFunc_statvfs: +  case LibFunc_siprintf: +  case LibFunc_sprintf: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy() && +            FTy.getReturnType()->isIntegerTy(32)); +  case LibFunc_snprintf: +    return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy() && +            FTy.getReturnType()->isIntegerTy(32)); +  case LibFunc_setitimer: +    return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy()); +  case LibFunc_system: +    return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_malloc: +    return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); +  case LibFunc_memcmp: +    return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); + +  case LibFunc_memchr: +  case LibFunc_memrchr: +    return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && +            FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(1)->isIntegerTy(32) && +            IsSizeTTy(FTy.getParamType(2))); +  case LibFunc_modf: +  case LibFunc_modff: +  case LibFunc_modfl: +    return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + +  case LibFunc_memcpy_chk: +  case LibFunc_memmove_chk: +    --NumParams; +    if (!IsSizeTTy(FTy.getParamType(NumParams))) +      return false; +    LLVM_FALLTHROUGH; +  case LibFunc_memcpy: +  case LibFunc_mempcpy: +  case LibFunc_memmove: +    return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy() && +            IsSizeTTy(FTy.getParamType(2))); + +  case LibFunc_memset_chk: +    --NumParams; +    if (!IsSizeTTy(FTy.getParamType(NumParams))) +      return false; +    
LLVM_FALLTHROUGH; +  case LibFunc_memset: +    return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isIntegerTy() && +            IsSizeTTy(FTy.getParamType(2))); + +  case LibFunc_memccpy: +    return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_memalign: +    return (FTy.getReturnType()->isPointerTy()); +  case LibFunc_realloc: +  case LibFunc_reallocf: +    return (NumParams == 2 && FTy.getReturnType() == PCharTy && +            FTy.getParamType(0) == FTy.getReturnType() && +            IsSizeTTy(FTy.getParamType(1))); +  case LibFunc_read: +    return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_rewind: +  case LibFunc_rmdir: +  case LibFunc_remove: +  case LibFunc_realpath: +    return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_rename: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_readlink: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_write: +    return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_bcopy: +  case LibFunc_bcmp: +    return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_bzero: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_calloc: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); + +  case LibFunc_atof: +  case LibFunc_atoi: +  case LibFunc_atol: +  case LibFunc_atoll: +  case LibFunc_ferror: +  case LibFunc_getenv: +  case LibFunc_getpwnam: +  case LibFunc_iprintf: +  case LibFunc_pclose: +  case LibFunc_perror: +  case LibFunc_printf: +  case LibFunc_puts: +  case LibFunc_uname: +  case LibFunc_under_IO_getc: +  case LibFunc_unlink: +  case LibFunc_unsetenv: +    return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + +  case LibFunc_access: +  case LibFunc_chmod: +  case LibFunc_chown: +  case LibFunc_clearerr: +  case LibFunc_closedir: +  case LibFunc_ctermid: +  case LibFunc_fclose: +  case LibFunc_feof: +  case LibFunc_fflush: +  case LibFunc_fgetc: +  case LibFunc_fgetc_unlocked: +  case LibFunc_fileno: +  case LibFunc_flockfile: +  case LibFunc_free: +  case LibFunc_fseek: +  case LibFunc_fseeko64: +  case LibFunc_fseeko: +  case LibFunc_fsetpos: +  case LibFunc_ftell: +  case LibFunc_ftello64: +  case LibFunc_ftello: +  case LibFunc_ftrylockfile: +  case LibFunc_funlockfile: +  case LibFunc_getc: +  case LibFunc_getc_unlocked: +  case LibFunc_getlogin_r: +  case LibFunc_mkdir: +  case LibFunc_mktime: +  case LibFunc_times: +    return (NumParams != 0 && FTy.getParamType(0)->isPointerTy()); + +  case LibFunc_fopen: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_fdopen: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_fputc: +  case LibFunc_fputc_unlocked: +  case LibFunc_fstat: +  case LibFunc_frexp: +  case LibFunc_frexpf: +  case LibFunc_frexpl: +  case LibFunc_fstatvfs: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_fgets: +  case LibFunc_fgets_unlocked: +    return (NumParams == 
3 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy()); +  case LibFunc_fread: +  case LibFunc_fread_unlocked: +    return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(3)->isPointerTy()); +  case LibFunc_fwrite: +  case LibFunc_fwrite_unlocked: +    return (NumParams == 4 && FTy.getReturnType()->isIntegerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isIntegerTy() && +            FTy.getParamType(2)->isIntegerTy() && +            FTy.getParamType(3)->isPointerTy()); +  case LibFunc_fputs: +  case LibFunc_fputs_unlocked: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_fscanf: +  case LibFunc_fiprintf: +  case LibFunc_fprintf: +    return (NumParams >= 2 && FTy.getReturnType()->isIntegerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_fgetpos: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_getchar: +  case LibFunc_getchar_unlocked: +    return (NumParams == 0 && FTy.getReturnType()->isIntegerTy()); +  case LibFunc_gets: +    return (NumParams == 1 && FTy.getParamType(0) == PCharTy); +  case LibFunc_getitimer: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_ungetc: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_utime: +  case LibFunc_utimes: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_putc: +  case LibFunc_putc_unlocked: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_pread: +  case LibFunc_pwrite: +    return (NumParams == 4 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_popen: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_vscanf: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_vsscanf: +    return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy()); +  case LibFunc_vfscanf: +    return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy()); +  case LibFunc_valloc: +    return (FTy.getReturnType()->isPointerTy()); +  case LibFunc_vprintf: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_vfprintf: +  case LibFunc_vsprintf: +    return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_vsnprintf: +    return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy()); +  case LibFunc_open: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_opendir: +    return (NumParams == 1 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy()); +  case LibFunc_tmpfile: +    return (FTy.getReturnType()->isPointerTy()); +  case LibFunc_htonl: +  case LibFunc_ntohl: +    return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getReturnType() == FTy.getParamType(0)); +  case LibFunc_htons: +  case LibFunc_ntohs: +    
return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(16) && +            FTy.getReturnType() == FTy.getParamType(0)); +  case LibFunc_lstat: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_lchown: +    return (NumParams == 3 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_qsort: +    return (NumParams == 4 && FTy.getParamType(3)->isPointerTy()); +  case LibFunc_dunder_strdup: +  case LibFunc_dunder_strndup: +    return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy()); +  case LibFunc_dunder_strtok_r: +    return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_under_IO_putc: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_dunder_isoc99_scanf: +    return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_stat64: +  case LibFunc_lstat64: +  case LibFunc_statvfs64: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_dunder_isoc99_sscanf: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_fopen64: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); +  case LibFunc_tmpfile64: +    return (FTy.getReturnType()->isPointerTy()); +  case LibFunc_fstat64: +  case LibFunc_fstatvfs64: +    return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); +  case LibFunc_open64: +    return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); +  case LibFunc_gettimeofday: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy()); + +  // new(unsigned int); +  case LibFunc_Znwj: +  // new(unsigned long); +  case LibFunc_Znwm: +  // new[](unsigned int); +  case LibFunc_Znaj: +  // new[](unsigned long); +  case LibFunc_Znam: +  // new(unsigned int); +  case LibFunc_msvc_new_int: +  // new(unsigned long long); +  case LibFunc_msvc_new_longlong: +  // new[](unsigned int); +  case LibFunc_msvc_new_array_int: +  // new[](unsigned long long); +  case LibFunc_msvc_new_array_longlong: +    return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); + +  // new(unsigned int, nothrow); +  case LibFunc_ZnwjRKSt9nothrow_t: +  // new(unsigned long, nothrow); +  case LibFunc_ZnwmRKSt9nothrow_t: +  // new[](unsigned int, nothrow); +  case LibFunc_ZnajRKSt9nothrow_t: +  // new[](unsigned long, nothrow); +  case LibFunc_ZnamRKSt9nothrow_t: +  // new(unsigned int, nothrow); +  case LibFunc_msvc_new_int_nothrow: +  // new(unsigned long long, nothrow); +  case LibFunc_msvc_new_longlong_nothrow: +  // new[](unsigned int, nothrow); +  case LibFunc_msvc_new_array_int_nothrow: +  // new[](unsigned long long, nothrow); +  case LibFunc_msvc_new_array_longlong_nothrow: +  // new(unsigned int, align_val_t) +  case LibFunc_ZnwjSt11align_val_t: +  // new(unsigned long, align_val_t) +  case LibFunc_ZnwmSt11align_val_t: +  // new[](unsigned int, align_val_t) +  case LibFunc_ZnajSt11align_val_t: +  // new[](unsigned long, align_val_t) +  case LibFunc_ZnamSt11align_val_t: +    return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); + +  // new(unsigned int, align_val_t, nothrow) +  case LibFunc_ZnwjSt11align_val_tRKSt9nothrow_t: +  // new(unsigned long, 
align_val_t, nothrow) +  case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t: +  // new[](unsigned int, align_val_t, nothrow) +  case LibFunc_ZnajSt11align_val_tRKSt9nothrow_t: +  // new[](unsigned long, align_val_t, nothrow) +  case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t: +    return (NumParams == 3 && FTy.getReturnType()->isPointerTy()); + +  // void operator delete[](void*); +  case LibFunc_ZdaPv: +  // void operator delete(void*); +  case LibFunc_ZdlPv: +  // void operator delete[](void*); +  case LibFunc_msvc_delete_array_ptr32: +  // void operator delete[](void*); +  case LibFunc_msvc_delete_array_ptr64: +  // void operator delete(void*); +  case LibFunc_msvc_delete_ptr32: +  // void operator delete(void*); +  case LibFunc_msvc_delete_ptr64: +    return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + +  // void operator delete[](void*, nothrow); +  case LibFunc_ZdaPvRKSt9nothrow_t: +  // void operator delete[](void*, unsigned int); +  case LibFunc_ZdaPvj: +  // void operator delete[](void*, unsigned long); +  case LibFunc_ZdaPvm: +  // void operator delete(void*, nothrow); +  case LibFunc_ZdlPvRKSt9nothrow_t: +  // void operator delete(void*, unsigned int); +  case LibFunc_ZdlPvj: +  // void operator delete(void*, unsigned long); +  case LibFunc_ZdlPvm: +  // void operator delete(void*, align_val_t) +  case LibFunc_ZdlPvSt11align_val_t: +  // void operator delete[](void*, align_val_t) +  case LibFunc_ZdaPvSt11align_val_t: +  // void operator delete[](void*, unsigned int); +  case LibFunc_msvc_delete_array_ptr32_int: +  // void operator delete[](void*, nothrow); +  case LibFunc_msvc_delete_array_ptr32_nothrow: +  // void operator delete[](void*, unsigned long long); +  case LibFunc_msvc_delete_array_ptr64_longlong: +  // void operator delete[](void*, nothrow); +  case LibFunc_msvc_delete_array_ptr64_nothrow: +  // void operator delete(void*, unsigned int); +  case LibFunc_msvc_delete_ptr32_int: +  // void operator delete(void*, nothrow); +  case LibFunc_msvc_delete_ptr32_nothrow: +  // void operator delete(void*, unsigned long long); +  case LibFunc_msvc_delete_ptr64_longlong: +  // void operator delete(void*, nothrow); +  case LibFunc_msvc_delete_ptr64_nothrow: +    return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + +  // void operator delete(void*, align_val_t, nothrow) +  case LibFunc_ZdlPvSt11align_val_tRKSt9nothrow_t: +  // void operator delete[](void*, align_val_t, nothrow) +  case LibFunc_ZdaPvSt11align_val_tRKSt9nothrow_t: +    return (NumParams == 3 && FTy.getParamType(0)->isPointerTy()); + +  case LibFunc_memset_pattern16: +    return (!FTy.isVarArg() && NumParams == 3 && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy() && +            FTy.getParamType(2)->isIntegerTy()); + +  case LibFunc_cxa_guard_abort: +  case LibFunc_cxa_guard_acquire: +  case LibFunc_cxa_guard_release: +  case LibFunc_nvvm_reflect: +    return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + +  case LibFunc_sincospi_stret: +  case LibFunc_sincospif_stret: +    return (NumParams == 1 && FTy.getParamType(0)->isFloatingPointTy()); + +  case LibFunc_acos: +  case LibFunc_acos_finite: +  case LibFunc_acosf: +  case LibFunc_acosf_finite: +  case LibFunc_acosh: +  case LibFunc_acosh_finite: +  case LibFunc_acoshf: +  case LibFunc_acoshf_finite: +  case LibFunc_acoshl: +  case LibFunc_acoshl_finite: +  case LibFunc_acosl: +  case LibFunc_acosl_finite: +  case LibFunc_asin: +  case LibFunc_asin_finite: +  case LibFunc_asinf: +  
case LibFunc_asinf_finite: +  case LibFunc_asinh: +  case LibFunc_asinhf: +  case LibFunc_asinhl: +  case LibFunc_asinl: +  case LibFunc_asinl_finite: +  case LibFunc_atan: +  case LibFunc_atanf: +  case LibFunc_atanh: +  case LibFunc_atanh_finite: +  case LibFunc_atanhf: +  case LibFunc_atanhf_finite: +  case LibFunc_atanhl: +  case LibFunc_atanhl_finite: +  case LibFunc_atanl: +  case LibFunc_cbrt: +  case LibFunc_cbrtf: +  case LibFunc_cbrtl: +  case LibFunc_ceil: +  case LibFunc_ceilf: +  case LibFunc_ceill: +  case LibFunc_cos: +  case LibFunc_cosf: +  case LibFunc_cosh: +  case LibFunc_cosh_finite: +  case LibFunc_coshf: +  case LibFunc_coshf_finite: +  case LibFunc_coshl: +  case LibFunc_coshl_finite: +  case LibFunc_cosl: +  case LibFunc_exp10: +  case LibFunc_exp10_finite: +  case LibFunc_exp10f: +  case LibFunc_exp10f_finite: +  case LibFunc_exp10l: +  case LibFunc_exp10l_finite: +  case LibFunc_exp2: +  case LibFunc_exp2_finite: +  case LibFunc_exp2f: +  case LibFunc_exp2f_finite: +  case LibFunc_exp2l: +  case LibFunc_exp2l_finite: +  case LibFunc_exp: +  case LibFunc_exp_finite: +  case LibFunc_expf: +  case LibFunc_expf_finite: +  case LibFunc_expl: +  case LibFunc_expl_finite: +  case LibFunc_expm1: +  case LibFunc_expm1f: +  case LibFunc_expm1l: +  case LibFunc_fabs: +  case LibFunc_fabsf: +  case LibFunc_fabsl: +  case LibFunc_floor: +  case LibFunc_floorf: +  case LibFunc_floorl: +  case LibFunc_log10: +  case LibFunc_log10_finite: +  case LibFunc_log10f: +  case LibFunc_log10f_finite: +  case LibFunc_log10l: +  case LibFunc_log10l_finite: +  case LibFunc_log1p: +  case LibFunc_log1pf: +  case LibFunc_log1pl: +  case LibFunc_log2: +  case LibFunc_log2_finite: +  case LibFunc_log2f: +  case LibFunc_log2f_finite: +  case LibFunc_log2l: +  case LibFunc_log2l_finite: +  case LibFunc_log: +  case LibFunc_log_finite: +  case LibFunc_logb: +  case LibFunc_logbf: +  case LibFunc_logbl: +  case LibFunc_logf: +  case LibFunc_logf_finite: +  case LibFunc_logl: +  case LibFunc_logl_finite: +  case LibFunc_nearbyint: +  case LibFunc_nearbyintf: +  case LibFunc_nearbyintl: +  case LibFunc_rint: +  case LibFunc_rintf: +  case LibFunc_rintl: +  case LibFunc_round: +  case LibFunc_roundf: +  case LibFunc_roundl: +  case LibFunc_sin: +  case LibFunc_sinf: +  case LibFunc_sinh: +  case LibFunc_sinh_finite: +  case LibFunc_sinhf: +  case LibFunc_sinhf_finite: +  case LibFunc_sinhl: +  case LibFunc_sinhl_finite: +  case LibFunc_sinl: +  case LibFunc_sqrt: +  case LibFunc_sqrt_finite: +  case LibFunc_sqrtf: +  case LibFunc_sqrtf_finite: +  case LibFunc_sqrtl: +  case LibFunc_sqrtl_finite: +  case LibFunc_tan: +  case LibFunc_tanf: +  case LibFunc_tanh: +  case LibFunc_tanhf: +  case LibFunc_tanhl: +  case LibFunc_tanl: +  case LibFunc_trunc: +  case LibFunc_truncf: +  case LibFunc_truncl: +    return (NumParams == 1 && FTy.getReturnType()->isFloatingPointTy() && +            FTy.getReturnType() == FTy.getParamType(0)); + +  case LibFunc_atan2: +  case LibFunc_atan2_finite: +  case LibFunc_atan2f: +  case LibFunc_atan2f_finite: +  case LibFunc_atan2l: +  case LibFunc_atan2l_finite: +  case LibFunc_fmin: +  case LibFunc_fminf: +  case LibFunc_fminl: +  case LibFunc_fmax: +  case LibFunc_fmaxf: +  case LibFunc_fmaxl: +  case LibFunc_fmod: +  case LibFunc_fmodf: +  case LibFunc_fmodl: +  case LibFunc_copysign: +  case LibFunc_copysignf: +  case LibFunc_copysignl: +  case LibFunc_pow: +  case LibFunc_pow_finite: +  case LibFunc_powf: +  case LibFunc_powf_finite: +  case LibFunc_powl: +  case 
LibFunc_powl_finite: +    return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && +            FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getReturnType() == FTy.getParamType(1)); + +  case LibFunc_ldexp: +  case LibFunc_ldexpf: +  case LibFunc_ldexpl: +    return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && +            FTy.getReturnType() == FTy.getParamType(0) && +            FTy.getParamType(1)->isIntegerTy(32)); + +  case LibFunc_ffs: +  case LibFunc_ffsl: +  case LibFunc_ffsll: +  case LibFunc_fls: +  case LibFunc_flsl: +  case LibFunc_flsll: +    return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getParamType(0)->isIntegerTy()); + +  case LibFunc_isdigit: +  case LibFunc_isascii: +  case LibFunc_toascii: +  case LibFunc_putchar: +  case LibFunc_putchar_unlocked: +    return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getReturnType() == FTy.getParamType(0)); + +  case LibFunc_abs: +  case LibFunc_labs: +  case LibFunc_llabs: +    return (NumParams == 1 && FTy.getReturnType()->isIntegerTy() && +            FTy.getReturnType() == FTy.getParamType(0)); + +  case LibFunc_cxa_atexit: +    return (NumParams == 3 && FTy.getReturnType()->isIntegerTy() && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1)->isPointerTy() && +            FTy.getParamType(2)->isPointerTy()); + +  case LibFunc_sinpi: +  case LibFunc_cospi: +    return (NumParams == 1 && FTy.getReturnType()->isDoubleTy() && +            FTy.getReturnType() == FTy.getParamType(0)); + +  case LibFunc_sinpif: +  case LibFunc_cospif: +    return (NumParams == 1 && FTy.getReturnType()->isFloatTy() && +            FTy.getReturnType() == FTy.getParamType(0)); + +  case LibFunc_strnlen: +    return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(1) && +            FTy.getParamType(0) == PCharTy && +            FTy.getParamType(1) == SizeTTy); + +  case LibFunc_posix_memalign: +    return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && +            FTy.getParamType(0)->isPointerTy() && +            FTy.getParamType(1) == SizeTTy && FTy.getParamType(2) == SizeTTy); + +  case LibFunc_wcslen: +    return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && +            FTy.getReturnType()->isIntegerTy()); + +  case LibFunc_cabs: +  case LibFunc_cabsf: +  case LibFunc_cabsl: { +    Type* RetTy = FTy.getReturnType(); +    if (!RetTy->isFloatingPointTy()) +      return false; + +    // NOTE: These prototypes are target specific and currently support +    // "complex" passed as an array or discrete real & imaginary parameters. +    // Add other calling conventions to enable libcall optimizations. +    if (NumParams == 1) +      return (FTy.getParamType(0)->isArrayTy() && +              FTy.getParamType(0)->getArrayNumElements() == 2 && +              FTy.getParamType(0)->getArrayElementType() == RetTy); +    else if (NumParams == 2) +      return (FTy.getParamType(0) == RetTy && FTy.getParamType(1) == RetTy); +    else +      return false; +  } +  case LibFunc::NumLibFuncs: +    break; +  } + +  llvm_unreachable("Invalid libfunc"); +} + +bool TargetLibraryInfoImpl::getLibFunc(const Function &FDecl, +                                       LibFunc &F) const { +  const DataLayout *DL = +      FDecl.getParent() ? 
&FDecl.getParent()->getDataLayout() : nullptr; +  return getLibFunc(FDecl.getName(), F) && +         isValidProtoForLibFunc(*FDecl.getFunctionType(), F, DL); +} + +void TargetLibraryInfoImpl::disableAllFunctions() { +  memset(AvailableArray, 0, sizeof(AvailableArray)); +} + +static bool compareByScalarFnName(const VecDesc &LHS, const VecDesc &RHS) { +  return LHS.ScalarFnName < RHS.ScalarFnName; +} + +static bool compareByVectorFnName(const VecDesc &LHS, const VecDesc &RHS) { +  return LHS.VectorFnName < RHS.VectorFnName; +} + +static bool compareWithScalarFnName(const VecDesc &LHS, StringRef S) { +  return LHS.ScalarFnName < S; +} + +static bool compareWithVectorFnName(const VecDesc &LHS, StringRef S) { +  return LHS.VectorFnName < S; +} + +void TargetLibraryInfoImpl::addVectorizableFunctions(ArrayRef<VecDesc> Fns) { +  VectorDescs.insert(VectorDescs.end(), Fns.begin(), Fns.end()); +  llvm::sort(VectorDescs.begin(), VectorDescs.end(), compareByScalarFnName); + +  ScalarDescs.insert(ScalarDescs.end(), Fns.begin(), Fns.end()); +  llvm::sort(ScalarDescs.begin(), ScalarDescs.end(), compareByVectorFnName); +} + +void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( +    enum VectorLibrary VecLib) { +  switch (VecLib) { +  case Accelerate: { +    const VecDesc VecFuncs[] = { +        // Floating-Point Arithmetic and Auxiliary Functions +        {"ceilf", "vceilf", 4}, +        {"fabsf", "vfabsf", 4}, +        {"llvm.fabs.f32", "vfabsf", 4}, +        {"floorf", "vfloorf", 4}, +        {"sqrtf", "vsqrtf", 4}, +        {"llvm.sqrt.f32", "vsqrtf", 4}, + +        // Exponential and Logarithmic Functions +        {"expf", "vexpf", 4}, +        {"llvm.exp.f32", "vexpf", 4}, +        {"expm1f", "vexpm1f", 4}, +        {"logf", "vlogf", 4}, +        {"llvm.log.f32", "vlogf", 4}, +        {"log1pf", "vlog1pf", 4}, +        {"log10f", "vlog10f", 4}, +        {"llvm.log10.f32", "vlog10f", 4}, +        {"logbf", "vlogbf", 4}, + +        // Trigonometric Functions +        {"sinf", "vsinf", 4}, +        {"llvm.sin.f32", "vsinf", 4}, +        {"cosf", "vcosf", 4}, +        {"llvm.cos.f32", "vcosf", 4}, +        {"tanf", "vtanf", 4}, +        {"asinf", "vasinf", 4}, +        {"acosf", "vacosf", 4}, +        {"atanf", "vatanf", 4}, + +        // Hyperbolic Functions +        {"sinhf", "vsinhf", 4}, +        {"coshf", "vcoshf", 4}, +        {"tanhf", "vtanhf", 4}, +        {"asinhf", "vasinhf", 4}, +        {"acoshf", "vacoshf", 4}, +        {"atanhf", "vatanhf", 4}, +    }; +    addVectorizableFunctions(VecFuncs); +    break; +  } +  case SVML: { +    const VecDesc VecFuncs[] = { +        {"sin", "__svml_sin2", 2}, +        {"sin", "__svml_sin4", 4}, +        {"sin", "__svml_sin8", 8}, + +        {"sinf", "__svml_sinf4", 4}, +        {"sinf", "__svml_sinf8", 8}, +        {"sinf", "__svml_sinf16", 16}, + +        {"llvm.sin.f64", "__svml_sin2", 2}, +        {"llvm.sin.f64", "__svml_sin4", 4}, +        {"llvm.sin.f64", "__svml_sin8", 8}, + +        {"llvm.sin.f32", "__svml_sinf4", 4}, +        {"llvm.sin.f32", "__svml_sinf8", 8}, +        {"llvm.sin.f32", "__svml_sinf16", 16}, + +        {"cos", "__svml_cos2", 2}, +        {"cos", "__svml_cos4", 4}, +        {"cos", "__svml_cos8", 8}, + +        {"cosf", "__svml_cosf4", 4}, +        {"cosf", "__svml_cosf8", 8}, +        {"cosf", "__svml_cosf16", 16}, + +        {"llvm.cos.f64", "__svml_cos2", 2}, +        {"llvm.cos.f64", "__svml_cos4", 4}, +        {"llvm.cos.f64", "__svml_cos8", 8}, + +        {"llvm.cos.f32", "__svml_cosf4", 4}, +        
{"llvm.cos.f32", "__svml_cosf8", 8}, +        {"llvm.cos.f32", "__svml_cosf16", 16}, + +        {"pow", "__svml_pow2", 2}, +        {"pow", "__svml_pow4", 4}, +        {"pow", "__svml_pow8", 8}, + +        {"powf", "__svml_powf4", 4}, +        {"powf", "__svml_powf8", 8}, +        {"powf", "__svml_powf16", 16}, + +        { "__pow_finite", "__svml_pow2", 2 }, +        { "__pow_finite", "__svml_pow4", 4 }, +        { "__pow_finite", "__svml_pow8", 8 }, + +        { "__powf_finite", "__svml_powf4", 4 }, +        { "__powf_finite", "__svml_powf8", 8 }, +        { "__powf_finite", "__svml_powf16", 16 }, + +        {"llvm.pow.f64", "__svml_pow2", 2}, +        {"llvm.pow.f64", "__svml_pow4", 4}, +        {"llvm.pow.f64", "__svml_pow8", 8}, + +        {"llvm.pow.f32", "__svml_powf4", 4}, +        {"llvm.pow.f32", "__svml_powf8", 8}, +        {"llvm.pow.f32", "__svml_powf16", 16}, + +        {"exp", "__svml_exp2", 2}, +        {"exp", "__svml_exp4", 4}, +        {"exp", "__svml_exp8", 8}, + +        {"expf", "__svml_expf4", 4}, +        {"expf", "__svml_expf8", 8}, +        {"expf", "__svml_expf16", 16}, + +        { "__exp_finite", "__svml_exp2", 2 }, +        { "__exp_finite", "__svml_exp4", 4 }, +        { "__exp_finite", "__svml_exp8", 8 }, + +        { "__expf_finite", "__svml_expf4", 4 }, +        { "__expf_finite", "__svml_expf8", 8 }, +        { "__expf_finite", "__svml_expf16", 16 }, + +        {"llvm.exp.f64", "__svml_exp2", 2}, +        {"llvm.exp.f64", "__svml_exp4", 4}, +        {"llvm.exp.f64", "__svml_exp8", 8}, + +        {"llvm.exp.f32", "__svml_expf4", 4}, +        {"llvm.exp.f32", "__svml_expf8", 8}, +        {"llvm.exp.f32", "__svml_expf16", 16}, + +        {"log", "__svml_log2", 2}, +        {"log", "__svml_log4", 4}, +        {"log", "__svml_log8", 8}, + +        {"logf", "__svml_logf4", 4}, +        {"logf", "__svml_logf8", 8}, +        {"logf", "__svml_logf16", 16}, + +        { "__log_finite", "__svml_log2", 2 }, +        { "__log_finite", "__svml_log4", 4 }, +        { "__log_finite", "__svml_log8", 8 }, + +        { "__logf_finite", "__svml_logf4", 4 }, +        { "__logf_finite", "__svml_logf8", 8 }, +        { "__logf_finite", "__svml_logf16", 16 }, + +        {"llvm.log.f64", "__svml_log2", 2}, +        {"llvm.log.f64", "__svml_log4", 4}, +        {"llvm.log.f64", "__svml_log8", 8}, + +        {"llvm.log.f32", "__svml_logf4", 4}, +        {"llvm.log.f32", "__svml_logf8", 8}, +        {"llvm.log.f32", "__svml_logf16", 16}, +    }; +    addVectorizableFunctions(VecFuncs); +    break; +  } +  case NoLibrary: +    break; +  } +} + +bool TargetLibraryInfoImpl::isFunctionVectorizable(StringRef funcName) const { +  funcName = sanitizeFunctionName(funcName); +  if (funcName.empty()) +    return false; + +  std::vector<VecDesc>::const_iterator I = std::lower_bound( +      VectorDescs.begin(), VectorDescs.end(), funcName, +      compareWithScalarFnName); +  return I != VectorDescs.end() && StringRef(I->ScalarFnName) == funcName; +} + +StringRef TargetLibraryInfoImpl::getVectorizedFunction(StringRef F, +                                                       unsigned VF) const { +  F = sanitizeFunctionName(F); +  if (F.empty()) +    return F; +  std::vector<VecDesc>::const_iterator I = std::lower_bound( +      VectorDescs.begin(), VectorDescs.end(), F, compareWithScalarFnName); +  while (I != VectorDescs.end() && StringRef(I->ScalarFnName) == F) { +    if (I->VectorizationFactor == VF) +      return I->VectorFnName; +    ++I; +  } +  return StringRef(); +} + +StringRef 
TargetLibraryInfoImpl::getScalarizedFunction(StringRef F,
+                                                       unsigned &VF) const {
+  F = sanitizeFunctionName(F);
+  if (F.empty())
+    return F;
+
+  // ScalarDescs is sorted by vector-function name, so search it by the vector
+  // name and map the match back to its scalar counterpart.
+  std::vector<VecDesc>::const_iterator I = std::lower_bound(
+      ScalarDescs.begin(), ScalarDescs.end(), F, compareWithVectorFnName);
+  if (I == ScalarDescs.end() || StringRef(I->VectorFnName) != F)
+    return StringRef();
+  VF = I->VectorizationFactor;
+  return I->ScalarFnName;
+}
+
+TargetLibraryInfo TargetLibraryAnalysis::run(Module &M,
+                                             ModuleAnalysisManager &) {
+  if (PresetInfoImpl)
+    return TargetLibraryInfo(*PresetInfoImpl);
+
+  return TargetLibraryInfo(lookupInfoImpl(Triple(M.getTargetTriple())));
+}
+
+TargetLibraryInfo TargetLibraryAnalysis::run(Function &F,
+                                             FunctionAnalysisManager &) {
+  if (PresetInfoImpl)
+    return TargetLibraryInfo(*PresetInfoImpl);
+
+  return TargetLibraryInfo(
+      lookupInfoImpl(Triple(F.getParent()->getTargetTriple())));
+}
+
+TargetLibraryInfoImpl &TargetLibraryAnalysis::lookupInfoImpl(const Triple &T) {
+  std::unique_ptr<TargetLibraryInfoImpl> &Impl =
+      Impls[T.normalize()];
+  if (!Impl)
+    Impl.reset(new TargetLibraryInfoImpl(T));
+
+  return *Impl;
+}
+
+unsigned TargetLibraryInfoImpl::getWCharSize(const Module &M) const {
+  if (auto *ShortWChar = cast_or_null<ConstantAsMetadata>(
+      M.getModuleFlag("wchar_size")))
+    return cast<ConstantInt>(ShortWChar->getValue())->getZExtValue();
+  return 0;
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
+    : ImmutablePass(ID), TLIImpl(), TLI(TLIImpl) {
+  initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass(const Triple &T)
+    : ImmutablePass(ID), TLIImpl(T), TLI(TLIImpl) {
+  initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass(
+    const TargetLibraryInfoImpl &TLIImpl)
+    : ImmutablePass(ID), TLIImpl(TLIImpl), TLI(this->TLIImpl) {
+  initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+AnalysisKey TargetLibraryAnalysis::Key;
+
+// Register the basic pass.
+INITIALIZE_PASS(TargetLibraryInfoWrapperPass, "targetlibinfo",
+                "Target Library Information", false, true)
+char TargetLibraryInfoWrapperPass::ID = 0;
+
+void TargetLibraryInfoWrapperPass::anchor() {}
diff --git a/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp b/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp
new file mode 100644
index 000000000000..7233a86e5daf
--- /dev/null
+++ b/contrib/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -0,0 +1,1195 @@
+//===- llvm/Analysis/TargetTransformInfo.cpp ------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/TargetTransformInfoImpl.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include <utility> + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "tti" + +static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false), +                                     cl::Hidden, +                                     cl::desc("Recognize reduction patterns.")); + +namespace { +/// No-op implementation of the TTI interface using the utility base +/// classes. +/// +/// This is used when no target specific information is available. +struct NoTTIImpl : TargetTransformInfoImplCRTPBase<NoTTIImpl> { +  explicit NoTTIImpl(const DataLayout &DL) +      : TargetTransformInfoImplCRTPBase<NoTTIImpl>(DL) {} +}; +} + +TargetTransformInfo::TargetTransformInfo(const DataLayout &DL) +    : TTIImpl(new Model<NoTTIImpl>(NoTTIImpl(DL))) {} + +TargetTransformInfo::~TargetTransformInfo() {} + +TargetTransformInfo::TargetTransformInfo(TargetTransformInfo &&Arg) +    : TTIImpl(std::move(Arg.TTIImpl)) {} + +TargetTransformInfo &TargetTransformInfo::operator=(TargetTransformInfo &&RHS) { +  TTIImpl = std::move(RHS.TTIImpl); +  return *this; +} + +int TargetTransformInfo::getOperationCost(unsigned Opcode, Type *Ty, +                                          Type *OpTy) const { +  int Cost = TTIImpl->getOperationCost(Opcode, Ty, OpTy); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getCallCost(FunctionType *FTy, int NumArgs) const { +  int Cost = TTIImpl->getCallCost(FTy, NumArgs); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getCallCost(const Function *F, +                                     ArrayRef<const Value *> Arguments) const { +  int Cost = TTIImpl->getCallCost(F, Arguments); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +unsigned TargetTransformInfo::getInliningThresholdMultiplier() const { +  return TTIImpl->getInliningThresholdMultiplier(); +} + +int TargetTransformInfo::getGEPCost(Type *PointeeType, const Value *Ptr, +                                    ArrayRef<const Value *> Operands) const { +  return TTIImpl->getGEPCost(PointeeType, Ptr, Operands); +} + +int TargetTransformInfo::getExtCost(const Instruction *I, +                                    const Value *Src) const { +  return TTIImpl->getExtCost(I, Src); +} + +int TargetTransformInfo::getIntrinsicCost( +    Intrinsic::ID IID, Type *RetTy, ArrayRef<const Value *> Arguments) const { +  int Cost = TTIImpl->getIntrinsicCost(IID, RetTy, Arguments); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +unsigned +TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI, +                                                      unsigned &JTSize) const { +  return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize); +} + +int TargetTransformInfo::getUserCost(const User *U, +    ArrayRef<const Value *> Operands) const { +  int Cost = TTIImpl->getUserCost(U, 
Operands); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +bool TargetTransformInfo::hasBranchDivergence() const { +  return TTIImpl->hasBranchDivergence(); +} + +bool TargetTransformInfo::isSourceOfDivergence(const Value *V) const { +  return TTIImpl->isSourceOfDivergence(V); +} + +bool llvm::TargetTransformInfo::isAlwaysUniform(const Value *V) const { +  return TTIImpl->isAlwaysUniform(V); +} + +unsigned TargetTransformInfo::getFlatAddressSpace() const { +  return TTIImpl->getFlatAddressSpace(); +} + +bool TargetTransformInfo::isLoweredToCall(const Function *F) const { +  return TTIImpl->isLoweredToCall(F); +} + +void TargetTransformInfo::getUnrollingPreferences( +    Loop *L, ScalarEvolution &SE, UnrollingPreferences &UP) const { +  return TTIImpl->getUnrollingPreferences(L, SE, UP); +} + +bool TargetTransformInfo::isLegalAddImmediate(int64_t Imm) const { +  return TTIImpl->isLegalAddImmediate(Imm); +} + +bool TargetTransformInfo::isLegalICmpImmediate(int64_t Imm) const { +  return TTIImpl->isLegalICmpImmediate(Imm); +} + +bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, +                                                int64_t BaseOffset, +                                                bool HasBaseReg, +                                                int64_t Scale, +                                                unsigned AddrSpace, +                                                Instruction *I) const { +  return TTIImpl->isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, +                                        Scale, AddrSpace, I); +} + +bool TargetTransformInfo::isLSRCostLess(LSRCost &C1, LSRCost &C2) const { +  return TTIImpl->isLSRCostLess(C1, C2); +} + +bool TargetTransformInfo::canMacroFuseCmp() const { +  return TTIImpl->canMacroFuseCmp(); +} + +bool TargetTransformInfo::shouldFavorPostInc() const { +  return TTIImpl->shouldFavorPostInc(); +} + +bool TargetTransformInfo::isLegalMaskedStore(Type *DataType) const { +  return TTIImpl->isLegalMaskedStore(DataType); +} + +bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType) const { +  return TTIImpl->isLegalMaskedLoad(DataType); +} + +bool TargetTransformInfo::isLegalMaskedGather(Type *DataType) const { +  return TTIImpl->isLegalMaskedGather(DataType); +} + +bool TargetTransformInfo::isLegalMaskedScatter(Type *DataType) const { +  return TTIImpl->isLegalMaskedScatter(DataType); +} + +bool TargetTransformInfo::hasDivRemOp(Type *DataType, bool IsSigned) const { +  return TTIImpl->hasDivRemOp(DataType, IsSigned); +} + +bool TargetTransformInfo::hasVolatileVariant(Instruction *I, +                                             unsigned AddrSpace) const { +  return TTIImpl->hasVolatileVariant(I, AddrSpace); +} + +bool TargetTransformInfo::prefersVectorizedAddressing() const { +  return TTIImpl->prefersVectorizedAddressing(); +} + +int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, +                                              int64_t BaseOffset, +                                              bool HasBaseReg, +                                              int64_t Scale, +                                              unsigned AddrSpace) const { +  int Cost = TTIImpl->getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, +                                           Scale, AddrSpace); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +bool TargetTransformInfo::LSRWithInstrQueries() const { +  
return TTIImpl->LSRWithInstrQueries(); +} + +bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const { +  return TTIImpl->isTruncateFree(Ty1, Ty2); +} + +bool TargetTransformInfo::isProfitableToHoist(Instruction *I) const { +  return TTIImpl->isProfitableToHoist(I); +} + +bool TargetTransformInfo::useAA() const { return TTIImpl->useAA(); } + +bool TargetTransformInfo::isTypeLegal(Type *Ty) const { +  return TTIImpl->isTypeLegal(Ty); +} + +unsigned TargetTransformInfo::getJumpBufAlignment() const { +  return TTIImpl->getJumpBufAlignment(); +} + +unsigned TargetTransformInfo::getJumpBufSize() const { +  return TTIImpl->getJumpBufSize(); +} + +bool TargetTransformInfo::shouldBuildLookupTables() const { +  return TTIImpl->shouldBuildLookupTables(); +} +bool TargetTransformInfo::shouldBuildLookupTablesForConstant(Constant *C) const { +  return TTIImpl->shouldBuildLookupTablesForConstant(C); +} + +bool TargetTransformInfo::useColdCCForColdCall(Function &F) const { +  return TTIImpl->useColdCCForColdCall(F); +} + +unsigned TargetTransformInfo:: +getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) const { +  return TTIImpl->getScalarizationOverhead(Ty, Insert, Extract); +} + +unsigned TargetTransformInfo:: +getOperandsScalarizationOverhead(ArrayRef<const Value *> Args, +                                 unsigned VF) const { +  return TTIImpl->getOperandsScalarizationOverhead(Args, VF); +} + +bool TargetTransformInfo::supportsEfficientVectorElementLoadStore() const { +  return TTIImpl->supportsEfficientVectorElementLoadStore(); +} + +bool TargetTransformInfo::enableAggressiveInterleaving(bool LoopHasReductions) const { +  return TTIImpl->enableAggressiveInterleaving(LoopHasReductions); +} + +const TargetTransformInfo::MemCmpExpansionOptions * +TargetTransformInfo::enableMemCmpExpansion(bool IsZeroCmp) const { +  return TTIImpl->enableMemCmpExpansion(IsZeroCmp); +} + +bool TargetTransformInfo::enableInterleavedAccessVectorization() const { +  return TTIImpl->enableInterleavedAccessVectorization(); +} + +bool TargetTransformInfo::isFPVectorizationPotentiallyUnsafe() const { +  return TTIImpl->isFPVectorizationPotentiallyUnsafe(); +} + +bool TargetTransformInfo::allowsMisalignedMemoryAccesses(LLVMContext &Context, +                                                         unsigned BitWidth, +                                                         unsigned AddressSpace, +                                                         unsigned Alignment, +                                                         bool *Fast) const { +  return TTIImpl->allowsMisalignedMemoryAccesses(Context, BitWidth, AddressSpace, +                                                 Alignment, Fast); +} + +TargetTransformInfo::PopcntSupportKind +TargetTransformInfo::getPopcntSupport(unsigned IntTyWidthInBit) const { +  return TTIImpl->getPopcntSupport(IntTyWidthInBit); +} + +bool TargetTransformInfo::haveFastSqrt(Type *Ty) const { +  return TTIImpl->haveFastSqrt(Ty); +} + +bool TargetTransformInfo::isFCmpOrdCheaperThanFCmpZero(Type *Ty) const { +  return TTIImpl->isFCmpOrdCheaperThanFCmpZero(Ty); +} + +int TargetTransformInfo::getFPOpCost(Type *Ty) const { +  int Cost = TTIImpl->getFPOpCost(Ty); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getIntImmCodeSizeCost(unsigned Opcode, unsigned Idx, +                                               const APInt &Imm, +                                               Type *Ty) const { +  int Cost = 
TTIImpl->getIntImmCodeSizeCost(Opcode, Idx, Imm, Ty); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getIntImmCost(const APInt &Imm, Type *Ty) const { +  int Cost = TTIImpl->getIntImmCost(Imm, Ty); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getIntImmCost(unsigned Opcode, unsigned Idx, +                                       const APInt &Imm, Type *Ty) const { +  int Cost = TTIImpl->getIntImmCost(Opcode, Idx, Imm, Ty); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getIntImmCost(Intrinsic::ID IID, unsigned Idx, +                                       const APInt &Imm, Type *Ty) const { +  int Cost = TTIImpl->getIntImmCost(IID, Idx, Imm, Ty); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +unsigned TargetTransformInfo::getNumberOfRegisters(bool Vector) const { +  return TTIImpl->getNumberOfRegisters(Vector); +} + +unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const { +  return TTIImpl->getRegisterBitWidth(Vector); +} + +unsigned TargetTransformInfo::getMinVectorRegisterBitWidth() const { +  return TTIImpl->getMinVectorRegisterBitWidth(); +} + +bool TargetTransformInfo::shouldMaximizeVectorBandwidth(bool OptSize) const { +  return TTIImpl->shouldMaximizeVectorBandwidth(OptSize); +} + +unsigned TargetTransformInfo::getMinimumVF(unsigned ElemWidth) const { +  return TTIImpl->getMinimumVF(ElemWidth); +} + +bool TargetTransformInfo::shouldConsiderAddressTypePromotion( +    const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const { +  return TTIImpl->shouldConsiderAddressTypePromotion( +      I, AllowPromotionWithoutCommonHeader); +} + +unsigned TargetTransformInfo::getCacheLineSize() const { +  return TTIImpl->getCacheLineSize(); +} + +llvm::Optional<unsigned> TargetTransformInfo::getCacheSize(CacheLevel Level) +  const { +  return TTIImpl->getCacheSize(Level); +} + +llvm::Optional<unsigned> TargetTransformInfo::getCacheAssociativity( +  CacheLevel Level) const { +  return TTIImpl->getCacheAssociativity(Level); +} + +unsigned TargetTransformInfo::getPrefetchDistance() const { +  return TTIImpl->getPrefetchDistance(); +} + +unsigned TargetTransformInfo::getMinPrefetchStride() const { +  return TTIImpl->getMinPrefetchStride(); +} + +unsigned TargetTransformInfo::getMaxPrefetchIterationsAhead() const { +  return TTIImpl->getMaxPrefetchIterationsAhead(); +} + +unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const { +  return TTIImpl->getMaxInterleaveFactor(VF); +} + +int TargetTransformInfo::getArithmeticInstrCost( +    unsigned Opcode, Type *Ty, OperandValueKind Opd1Info, +    OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo, +    OperandValueProperties Opd2PropInfo, +    ArrayRef<const Value *> Args) const { +  int Cost = TTIImpl->getArithmeticInstrCost(Opcode, Ty, Opd1Info, Opd2Info, +                                             Opd1PropInfo, Opd2PropInfo, Args); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Ty, int Index, +                                        Type *SubTp) const { +  int Cost = TTIImpl->getShuffleCost(Kind, Ty, Index, SubTp); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type 
*Dst, +                                 Type *Src, const Instruction *I) const { +  assert ((I == nullptr || I->getOpcode() == Opcode) && +          "Opcode should reflect passed instruction."); +  int Cost = TTIImpl->getCastInstrCost(Opcode, Dst, Src, I); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getExtractWithExtendCost(unsigned Opcode, Type *Dst, +                                                  VectorType *VecTy, +                                                  unsigned Index) const { +  int Cost = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const { +  int Cost = TTIImpl->getCFInstrCost(Opcode); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, +                                 Type *CondTy, const Instruction *I) const { +  assert ((I == nullptr || I->getOpcode() == Opcode) && +          "Opcode should reflect passed instruction."); +  int Cost = TTIImpl->getCmpSelInstrCost(Opcode, ValTy, CondTy, I); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getVectorInstrCost(unsigned Opcode, Type *Val, +                                            unsigned Index) const { +  int Cost = TTIImpl->getVectorInstrCost(Opcode, Val, Index); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getMemoryOpCost(unsigned Opcode, Type *Src, +                                         unsigned Alignment, +                                         unsigned AddressSpace, +                                         const Instruction *I) const { +  assert ((I == nullptr || I->getOpcode() == Opcode) && +          "Opcode should reflect passed instruction."); +  int Cost = TTIImpl->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace, I); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, +                                               unsigned Alignment, +                                               unsigned AddressSpace) const { +  int Cost = +      TTIImpl->getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getGatherScatterOpCost(unsigned Opcode, Type *DataTy, +                                                Value *Ptr, bool VariableMask, +                                                unsigned Alignment) const { +  int Cost = TTIImpl->getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, +                                             Alignment); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getInterleavedMemoryOpCost( +    unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices, +    unsigned Alignment, unsigned AddressSpace) const { +  int Cost = TTIImpl->getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, +                                                 Alignment, AddressSpace); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int 
TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, +                                    ArrayRef<Type *> Tys, FastMathFlags FMF, +                                    unsigned ScalarizationCostPassed) const { +  int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys, FMF, +                                            ScalarizationCostPassed); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, +           ArrayRef<Value *> Args, FastMathFlags FMF, unsigned VF) const { +  int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF, VF); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getCallInstrCost(Function *F, Type *RetTy, +                                          ArrayRef<Type *> Tys) const { +  int Cost = TTIImpl->getCallInstrCost(F, RetTy, Tys); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +unsigned TargetTransformInfo::getNumberOfParts(Type *Tp) const { +  return TTIImpl->getNumberOfParts(Tp); +} + +int TargetTransformInfo::getAddressComputationCost(Type *Tp, +                                                   ScalarEvolution *SE, +                                                   const SCEV *Ptr) const { +  int Cost = TTIImpl->getAddressComputationCost(Tp, SE, Ptr); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getArithmeticReductionCost(unsigned Opcode, Type *Ty, +                                                    bool IsPairwiseForm) const { +  int Cost = TTIImpl->getArithmeticReductionCost(Opcode, Ty, IsPairwiseForm); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +int TargetTransformInfo::getMinMaxReductionCost(Type *Ty, Type *CondTy, +                                                bool IsPairwiseForm, +                                                bool IsUnsigned) const { +  int Cost = +      TTIImpl->getMinMaxReductionCost(Ty, CondTy, IsPairwiseForm, IsUnsigned); +  assert(Cost >= 0 && "TTI should not produce negative costs!"); +  return Cost; +} + +unsigned +TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const { +  return TTIImpl->getCostOfKeepingLiveOverCall(Tys); +} + +bool TargetTransformInfo::getTgtMemIntrinsic(IntrinsicInst *Inst, +                                             MemIntrinsicInfo &Info) const { +  return TTIImpl->getTgtMemIntrinsic(Inst, Info); +} + +unsigned TargetTransformInfo::getAtomicMemIntrinsicMaxElementSize() const { +  return TTIImpl->getAtomicMemIntrinsicMaxElementSize(); +} + +Value *TargetTransformInfo::getOrCreateResultFromMemIntrinsic( +    IntrinsicInst *Inst, Type *ExpectedType) const { +  return TTIImpl->getOrCreateResultFromMemIntrinsic(Inst, ExpectedType); +} + +Type *TargetTransformInfo::getMemcpyLoopLoweringType(LLVMContext &Context, +                                                     Value *Length, +                                                     unsigned SrcAlign, +                                                     unsigned DestAlign) const { +  return TTIImpl->getMemcpyLoopLoweringType(Context, Length, SrcAlign, +                                            DestAlign); +} + +void TargetTransformInfo::getMemcpyLoopResidualLoweringType( +    SmallVectorImpl<Type *> &OpsOut, LLVMContext &Context, +    unsigned RemainingBytes, unsigned SrcAlign, 
unsigned DestAlign) const { +  TTIImpl->getMemcpyLoopResidualLoweringType(OpsOut, Context, RemainingBytes, +                                             SrcAlign, DestAlign); +} + +bool TargetTransformInfo::areInlineCompatible(const Function *Caller, +                                              const Function *Callee) const { +  return TTIImpl->areInlineCompatible(Caller, Callee); +} + +bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode, +                                             Type *Ty) const { +  return TTIImpl->isIndexedLoadLegal(Mode, Ty); +} + +bool TargetTransformInfo::isIndexedStoreLegal(MemIndexedMode Mode, +                                              Type *Ty) const { +  return TTIImpl->isIndexedStoreLegal(Mode, Ty); +} + +unsigned TargetTransformInfo::getLoadStoreVecRegBitWidth(unsigned AS) const { +  return TTIImpl->getLoadStoreVecRegBitWidth(AS); +} + +bool TargetTransformInfo::isLegalToVectorizeLoad(LoadInst *LI) const { +  return TTIImpl->isLegalToVectorizeLoad(LI); +} + +bool TargetTransformInfo::isLegalToVectorizeStore(StoreInst *SI) const { +  return TTIImpl->isLegalToVectorizeStore(SI); +} + +bool TargetTransformInfo::isLegalToVectorizeLoadChain( +    unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const { +  return TTIImpl->isLegalToVectorizeLoadChain(ChainSizeInBytes, Alignment, +                                              AddrSpace); +} + +bool TargetTransformInfo::isLegalToVectorizeStoreChain( +    unsigned ChainSizeInBytes, unsigned Alignment, unsigned AddrSpace) const { +  return TTIImpl->isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment, +                                               AddrSpace); +} + +unsigned TargetTransformInfo::getLoadVectorFactor(unsigned VF, +                                                  unsigned LoadSize, +                                                  unsigned ChainSizeInBytes, +                                                  VectorType *VecTy) const { +  return TTIImpl->getLoadVectorFactor(VF, LoadSize, ChainSizeInBytes, VecTy); +} + +unsigned TargetTransformInfo::getStoreVectorFactor(unsigned VF, +                                                   unsigned StoreSize, +                                                   unsigned ChainSizeInBytes, +                                                   VectorType *VecTy) const { +  return TTIImpl->getStoreVectorFactor(VF, StoreSize, ChainSizeInBytes, VecTy); +} + +bool TargetTransformInfo::useReductionIntrinsic(unsigned Opcode, +                                                Type *Ty, ReductionFlags Flags) const { +  return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags); +} + +bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const { +  return TTIImpl->shouldExpandReduction(II); +} + +int TargetTransformInfo::getInstructionLatency(const Instruction *I) const { +  return TTIImpl->getInstructionLatency(I); +} + +static TargetTransformInfo::OperandValueKind +getOperandInfo(Value *V, TargetTransformInfo::OperandValueProperties &OpProps) { +  TargetTransformInfo::OperandValueKind OpInfo = +      TargetTransformInfo::OK_AnyValue; +  OpProps = TargetTransformInfo::OP_None; + +  if (auto *CI = dyn_cast<ConstantInt>(V)) { +    if (CI->getValue().isPowerOf2()) +      OpProps = TargetTransformInfo::OP_PowerOf2; +    return TargetTransformInfo::OK_UniformConstantValue; +  } + +  const Value *Splat = getSplatValue(V); + +  // Check for a splat of a constant or for a non uniform vector of constants +  // and check if the 
constant(s) are all powers of two. +  if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) { +    OpInfo = TargetTransformInfo::OK_NonUniformConstantValue; +    if (Splat) { +      OpInfo = TargetTransformInfo::OK_UniformConstantValue; +      if (auto *CI = dyn_cast<ConstantInt>(Splat)) +        if (CI->getValue().isPowerOf2()) +          OpProps = TargetTransformInfo::OP_PowerOf2; +    } else if (auto *CDS = dyn_cast<ConstantDataSequential>(V)) { +      OpProps = TargetTransformInfo::OP_PowerOf2; +      for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) { +        if (auto *CI = dyn_cast<ConstantInt>(CDS->getElementAsConstant(I))) +          if (CI->getValue().isPowerOf2()) +            continue; +        OpProps = TargetTransformInfo::OP_None; +        break; +      } +    } +  } + +  // Check for a splat of a uniform value. This is not loop aware, so return +  // true only for the obviously uniform cases (argument, globalvalue) +  if (Splat && (isa<Argument>(Splat) || isa<GlobalValue>(Splat))) +    OpInfo = TargetTransformInfo::OK_UniformValue; + +  return OpInfo; +} + +static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, +                                     unsigned Level) { +  // We don't need a shuffle if we just want to have element 0 in position 0 of +  // the vector. +  if (!SI && Level == 0 && IsLeft) +    return true; +  else if (!SI) +    return false; + +  SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1); + +  // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether +  // we look at the left or right side. +  for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2) +    Mask[i] = val; + +  SmallVector<int, 16> ActualMask = SI->getShuffleMask(); +  return Mask == ActualMask; +} + +namespace { +/// Kind of the reduction data. +enum ReductionKind { +  RK_None,           /// Not a reduction. +  RK_Arithmetic,     /// Binary reduction data. +  RK_MinMax,         /// Min/max reduction data. +  RK_UnsignedMinMax, /// Unsigned min/max reduction data. +}; +/// Contains opcode + LHS/RHS parts of the reduction operations. 
+struct ReductionData { +  ReductionData() = delete; +  ReductionData(ReductionKind Kind, unsigned Opcode, Value *LHS, Value *RHS) +      : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind) { +    assert(Kind != RK_None && "expected binary or min/max reduction only."); +  } +  unsigned Opcode = 0; +  Value *LHS = nullptr; +  Value *RHS = nullptr; +  ReductionKind Kind = RK_None; +  bool hasSameData(ReductionData &RD) const { +    return Kind == RD.Kind && Opcode == RD.Opcode; +  } +}; +} // namespace + +static Optional<ReductionData> getReductionData(Instruction *I) { +  Value *L, *R; +  if (m_BinOp(m_Value(L), m_Value(R)).match(I)) +    return ReductionData(RK_Arithmetic, I->getOpcode(), L, R); +  if (auto *SI = dyn_cast<SelectInst>(I)) { +    if (m_SMin(m_Value(L), m_Value(R)).match(SI) || +        m_SMax(m_Value(L), m_Value(R)).match(SI) || +        m_OrdFMin(m_Value(L), m_Value(R)).match(SI) || +        m_OrdFMax(m_Value(L), m_Value(R)).match(SI) || +        m_UnordFMin(m_Value(L), m_Value(R)).match(SI) || +        m_UnordFMax(m_Value(L), m_Value(R)).match(SI)) { +      auto *CI = cast<CmpInst>(SI->getCondition()); +      return ReductionData(RK_MinMax, CI->getOpcode(), L, R); +    } +    if (m_UMin(m_Value(L), m_Value(R)).match(SI) || +        m_UMax(m_Value(L), m_Value(R)).match(SI)) { +      auto *CI = cast<CmpInst>(SI->getCondition()); +      return ReductionData(RK_UnsignedMinMax, CI->getOpcode(), L, R); +    } +  } +  return llvm::None; +} + +static ReductionKind matchPairwiseReductionAtLevel(Instruction *I, +                                                   unsigned Level, +                                                   unsigned NumLevels) { +  // Match one level of pairwise operations. +  // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, +  //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> +  // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, +  //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> +  // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 +  if (!I) +    return RK_None; + +  assert(I->getType()->isVectorTy() && "Expecting a vector type"); + +  Optional<ReductionData> RD = getReductionData(I); +  if (!RD) +    return RK_None; + +  ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(RD->LHS); +  if (!LS && Level) +    return RK_None; +  ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(RD->RHS); +  if (!RS && Level) +    return RK_None; + +  // On level 0 we can omit one shufflevector instruction. +  if (!Level && !RS && !LS) +    return RK_None; + +  // Shuffle inputs must match. +  Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr; +  Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr; +  Value *NextLevelOp = nullptr; +  if (NextLevelOpR && NextLevelOpL) { +    // If we have two shuffles their operands must match. +    if (NextLevelOpL != NextLevelOpR) +      return RK_None; + +    NextLevelOp = NextLevelOpL; +  } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) { +    // On the first level we can omit the shufflevector <0, undef,...>. So the +    // input to the other shufflevector <1, undef> must match with one of the +    // inputs to the current binary operation. +    // Example: +    //  %NextLevelOpL = shufflevector %R, <1, undef ...> +    //  %BinOp        = fadd          %NextLevelOpL, %R +    if (NextLevelOpL && NextLevelOpL != RD->RHS) +      return RK_None; +    else if (NextLevelOpR && NextLevelOpR != RD->LHS) +      return RK_None; + +    NextLevelOp = NextLevelOpL ? 
RD->RHS : RD->LHS; +  } else +    return RK_None; + +  // Check that the next levels binary operation exists and matches with the +  // current one. +  if (Level + 1 != NumLevels) { +    Optional<ReductionData> NextLevelRD = +        getReductionData(cast<Instruction>(NextLevelOp)); +    if (!NextLevelRD || !RD->hasSameData(*NextLevelRD)) +      return RK_None; +  } + +  // Shuffle mask for pairwise operation must match. +  if (matchPairwiseShuffleMask(LS, /*IsLeft=*/true, Level)) { +    if (!matchPairwiseShuffleMask(RS, /*IsLeft=*/false, Level)) +      return RK_None; +  } else if (matchPairwiseShuffleMask(RS, /*IsLeft=*/true, Level)) { +    if (!matchPairwiseShuffleMask(LS, /*IsLeft=*/false, Level)) +      return RK_None; +  } else { +    return RK_None; +  } + +  if (++Level == NumLevels) +    return RD->Kind; + +  // Match next level. +  return matchPairwiseReductionAtLevel(cast<Instruction>(NextLevelOp), Level, +                                       NumLevels); +} + +static ReductionKind matchPairwiseReduction(const ExtractElementInst *ReduxRoot, +                                            unsigned &Opcode, Type *&Ty) { +  if (!EnableReduxCost) +    return RK_None; + +  // Need to extract the first element. +  ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); +  unsigned Idx = ~0u; +  if (CI) +    Idx = CI->getZExtValue(); +  if (Idx != 0) +    return RK_None; + +  auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); +  if (!RdxStart) +    return RK_None; +  Optional<ReductionData> RD = getReductionData(RdxStart); +  if (!RD) +    return RK_None; + +  Type *VecTy = RdxStart->getType(); +  unsigned NumVecElems = VecTy->getVectorNumElements(); +  if (!isPowerOf2_32(NumVecElems)) +    return RK_None; + +  // We look for a sequence of shuffle,shuffle,add triples like the following +  // that builds a pairwise reduction tree. +  // +  //  (X0, X1, X2, X3) +  //   (X0 + X1, X2 + X3, undef, undef) +  //    ((X0 + X1) + (X2 + X3), undef, undef, undef) +  // +  // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, +  //       <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> +  // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, +  //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> +  // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 +  // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, +  //       <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef> +  // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, +  //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> +  // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 +  // %r = extractelement <4 x float> %bin.rdx8, i32 0 +  if (matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)) == +      RK_None) +    return RK_None; + +  Opcode = RD->Opcode; +  Ty = VecTy; + +  return RD->Kind; +} + +static std::pair<Value *, ShuffleVectorInst *> +getShuffleAndOtherOprd(Value *L, Value *R) { +  ShuffleVectorInst *S = nullptr; + +  if ((S = dyn_cast<ShuffleVectorInst>(L))) +    return std::make_pair(R, S); + +  S = dyn_cast<ShuffleVectorInst>(R); +  return std::make_pair(L, S); +} + +static ReductionKind +matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, +                              unsigned &Opcode, Type *&Ty) { +  if (!EnableReduxCost) +    return RK_None; + +  // Need to extract the first element. 
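+  // (Illustratively, the root handled here is expected to look like
+  //  `%r = extractelement <4 x float> %red, i32 0`; any extract index other
+  //  than 0 is rejected just below. The value names are hypothetical.)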
+  ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); +  unsigned Idx = ~0u; +  if (CI) +    Idx = CI->getZExtValue(); +  if (Idx != 0) +    return RK_None; + +  auto *RdxStart = dyn_cast<Instruction>(ReduxRoot->getOperand(0)); +  if (!RdxStart) +    return RK_None; +  Optional<ReductionData> RD = getReductionData(RdxStart); +  if (!RD) +    return RK_None; + +  Type *VecTy = ReduxRoot->getOperand(0)->getType(); +  unsigned NumVecElems = VecTy->getVectorNumElements(); +  if (!isPowerOf2_32(NumVecElems)) +    return RK_None; + +  // We look for a sequence of shuffles and adds like the following matching one +  // fadd, shuffle vector pair at a time. +  // +  // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef, +  //                           <4 x i32> <i32 2, i32 3, i32 undef, i32 undef> +  // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf +  // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef, +  //                          <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> +  // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7 +  // %r = extractelement <4 x float> %bin.rdx8, i32 0 + +  unsigned MaskStart = 1; +  Instruction *RdxOp = RdxStart; +  SmallVector<int, 32> ShuffleMask(NumVecElems, 0); +  unsigned NumVecElemsRemain = NumVecElems; +  while (NumVecElemsRemain - 1) { +    // Check for the right reduction operation. +    if (!RdxOp) +      return RK_None; +    Optional<ReductionData> RDLevel = getReductionData(RdxOp); +    if (!RDLevel || !RDLevel->hasSameData(*RD)) +      return RK_None; + +    Value *NextRdxOp; +    ShuffleVectorInst *Shuffle; +    std::tie(NextRdxOp, Shuffle) = +        getShuffleAndOtherOprd(RDLevel->LHS, RDLevel->RHS); + +    // Check the current reduction operation and the shuffle use the same value. +    if (Shuffle == nullptr) +      return RK_None; +    if (Shuffle->getOperand(0) != NextRdxOp) +      return RK_None; + +    // Check that shuffle masks matches. +    for (unsigned j = 0; j != MaskStart; ++j) +      ShuffleMask[j] = MaskStart + j; +    // Fill the rest of the mask with -1 for undef. 
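+    // (Sketch of what this walk expects, using the fadd example above:
+    //  starting from the final reduction operation and working backwards, a
+    //  4-wide reduction should use the masks <1, undef, undef, undef> and
+    //  then <2, 3, undef, undef>.)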
+    std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1); + +    SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); +    if (ShuffleMask != Mask) +      return RK_None; + +    RdxOp = dyn_cast<Instruction>(NextRdxOp); +    NumVecElemsRemain /= 2; +    MaskStart *= 2; +  } + +  Opcode = RD->Opcode; +  Ty = VecTy; +  return RD->Kind; +} + +int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { +  switch (I->getOpcode()) { +  case Instruction::GetElementPtr: +    return getUserCost(I); + +  case Instruction::Ret: +  case Instruction::PHI: +  case Instruction::Br: { +    return getCFInstrCost(I->getOpcode()); +  } +  case Instruction::Add: +  case Instruction::FAdd: +  case Instruction::Sub: +  case Instruction::FSub: +  case Instruction::Mul: +  case Instruction::FMul: +  case Instruction::UDiv: +  case Instruction::SDiv: +  case Instruction::FDiv: +  case Instruction::URem: +  case Instruction::SRem: +  case Instruction::FRem: +  case Instruction::Shl: +  case Instruction::LShr: +  case Instruction::AShr: +  case Instruction::And: +  case Instruction::Or: +  case Instruction::Xor: { +    TargetTransformInfo::OperandValueKind Op1VK, Op2VK; +    TargetTransformInfo::OperandValueProperties Op1VP, Op2VP; +    Op1VK = getOperandInfo(I->getOperand(0), Op1VP); +    Op2VK = getOperandInfo(I->getOperand(1), Op2VP); +    SmallVector<const Value *, 2> Operands(I->operand_values()); +    return getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, Op2VK, +                                  Op1VP, Op2VP, Operands); +  } +  case Instruction::Select: { +    const SelectInst *SI = cast<SelectInst>(I); +    Type *CondTy = SI->getCondition()->getType(); +    return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, I); +  } +  case Instruction::ICmp: +  case Instruction::FCmp: { +    Type *ValTy = I->getOperand(0)->getType(); +    return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), I); +  } +  case Instruction::Store: { +    const StoreInst *SI = cast<StoreInst>(I); +    Type *ValTy = SI->getValueOperand()->getType(); +    return getMemoryOpCost(I->getOpcode(), ValTy, +                                SI->getAlignment(), +                                SI->getPointerAddressSpace(), I); +  } +  case Instruction::Load: { +    const LoadInst *LI = cast<LoadInst>(I); +    return getMemoryOpCost(I->getOpcode(), I->getType(), +                                LI->getAlignment(), +                                LI->getPointerAddressSpace(), I); +  } +  case Instruction::ZExt: +  case Instruction::SExt: +  case Instruction::FPToUI: +  case Instruction::FPToSI: +  case Instruction::FPExt: +  case Instruction::PtrToInt: +  case Instruction::IntToPtr: +  case Instruction::SIToFP: +  case Instruction::UIToFP: +  case Instruction::Trunc: +  case Instruction::FPTrunc: +  case Instruction::BitCast: +  case Instruction::AddrSpaceCast: { +    Type *SrcTy = I->getOperand(0)->getType(); +    return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, I); +  } +  case Instruction::ExtractElement: { +    const ExtractElementInst * EEI = cast<ExtractElementInst>(I); +    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1)); +    unsigned Idx = -1; +    if (CI) +      Idx = CI->getZExtValue(); + +    // Try to match a reduction sequence (series of shufflevector and vector +    // adds followed by a extractelement). 
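+    // (For illustration only: two reduction shapes are tried below. The
+    //  vector-splitting form reduces <a, b, c, d> as (a + c) + (b + d) by
+    //  folding the upper half onto the lower half at each step, while the
+    //  pairwise form reduces it as (a + b) + (c + d) by combining adjacent
+    //  lanes; both end with an `extractelement ..., i32 0` of the final
+    //  vector.)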
+    unsigned ReduxOpCode; +    Type *ReduxType; + +    switch (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) { +    case RK_Arithmetic: +      return getArithmeticReductionCost(ReduxOpCode, ReduxType, +                                             /*IsPairwiseForm=*/false); +    case RK_MinMax: +      return getMinMaxReductionCost( +          ReduxType, CmpInst::makeCmpResultType(ReduxType), +          /*IsPairwiseForm=*/false, /*IsUnsigned=*/false); +    case RK_UnsignedMinMax: +      return getMinMaxReductionCost( +          ReduxType, CmpInst::makeCmpResultType(ReduxType), +          /*IsPairwiseForm=*/false, /*IsUnsigned=*/true); +    case RK_None: +      break; +    } + +    switch (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) { +    case RK_Arithmetic: +      return getArithmeticReductionCost(ReduxOpCode, ReduxType, +                                             /*IsPairwiseForm=*/true); +    case RK_MinMax: +      return getMinMaxReductionCost( +          ReduxType, CmpInst::makeCmpResultType(ReduxType), +          /*IsPairwiseForm=*/true, /*IsUnsigned=*/false); +    case RK_UnsignedMinMax: +      return getMinMaxReductionCost( +          ReduxType, CmpInst::makeCmpResultType(ReduxType), +          /*IsPairwiseForm=*/true, /*IsUnsigned=*/true); +    case RK_None: +      break; +    } + +    return getVectorInstrCost(I->getOpcode(), +                                   EEI->getOperand(0)->getType(), Idx); +  } +  case Instruction::InsertElement: { +    const InsertElementInst * IE = cast<InsertElementInst>(I); +    ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2)); +    unsigned Idx = -1; +    if (CI) +      Idx = CI->getZExtValue(); +    return getVectorInstrCost(I->getOpcode(), +                                   IE->getType(), Idx); +  } +  case Instruction::ShuffleVector: { +    const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); +    // TODO: Identify and add costs for insert/extract subvector, etc. +    if (Shuffle->changesLength()) +      return -1; + +    if (Shuffle->isIdentity()) +      return 0; + +    Type *Ty = Shuffle->getType(); +    if (Shuffle->isReverse()) +      return TTIImpl->getShuffleCost(SK_Reverse, Ty, 0, nullptr); + +    if (Shuffle->isSelect()) +      return TTIImpl->getShuffleCost(SK_Select, Ty, 0, nullptr); + +    if (Shuffle->isTranspose()) +      return TTIImpl->getShuffleCost(SK_Transpose, Ty, 0, nullptr); + +    if (Shuffle->isZeroEltSplat()) +      return TTIImpl->getShuffleCost(SK_Broadcast, Ty, 0, nullptr); + +    if (Shuffle->isSingleSource()) +      return TTIImpl->getShuffleCost(SK_PermuteSingleSrc, Ty, 0, nullptr); + +    return TTIImpl->getShuffleCost(SK_PermuteTwoSrc, Ty, 0, nullptr); +  } +  case Instruction::Call: +    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { +      SmallVector<Value *, 4> Args(II->arg_operands()); + +      FastMathFlags FMF; +      if (auto *FPMO = dyn_cast<FPMathOperator>(II)) +        FMF = FPMO->getFastMathFlags(); + +      return getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), +                                        Args, FMF); +    } +    return -1; +  default: +    // We don't have any information on this instruction. 
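+    // (As with the other unhandled cases in this switch, -1 means "no
+    //  throughput estimate available" rather than an actual cost.)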
+    return -1; +  } +} + +TargetTransformInfo::Concept::~Concept() {} + +TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} + +TargetIRAnalysis::TargetIRAnalysis( +    std::function<Result(const Function &)> TTICallback) +    : TTICallback(std::move(TTICallback)) {} + +TargetIRAnalysis::Result TargetIRAnalysis::run(const Function &F, +                                               FunctionAnalysisManager &) { +  return TTICallback(F); +} + +AnalysisKey TargetIRAnalysis::Key; + +TargetIRAnalysis::Result TargetIRAnalysis::getDefaultTTI(const Function &F) { +  return Result(F.getParent()->getDataLayout()); +} + +// Register the basic pass. +INITIALIZE_PASS(TargetTransformInfoWrapperPass, "tti", +                "Target Transform Information", false, true) +char TargetTransformInfoWrapperPass::ID = 0; + +void TargetTransformInfoWrapperPass::anchor() {} + +TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass() +    : ImmutablePass(ID) { +  initializeTargetTransformInfoWrapperPassPass( +      *PassRegistry::getPassRegistry()); +} + +TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass( +    TargetIRAnalysis TIRA) +    : ImmutablePass(ID), TIRA(std::move(TIRA)) { +  initializeTargetTransformInfoWrapperPassPass( +      *PassRegistry::getPassRegistry()); +} + +TargetTransformInfo &TargetTransformInfoWrapperPass::getTTI(const Function &F) { +  FunctionAnalysisManager DummyFAM; +  TTI = TIRA.run(F, DummyFAM); +  return *TTI; +} + +ImmutablePass * +llvm::createTargetTransformInfoWrapperPass(TargetIRAnalysis TIRA) { +  return new TargetTransformInfoWrapperPass(std::move(TIRA)); +} diff --git a/contrib/llvm/lib/Analysis/Trace.cpp b/contrib/llvm/lib/Analysis/Trace.cpp new file mode 100644 index 000000000000..4dec53151ed6 --- /dev/null +++ b/contrib/llvm/lib/Analysis/Trace.cpp @@ -0,0 +1,54 @@ +//===- Trace.cpp - Implementation of Trace class --------------------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This class represents a single trace of LLVM basic blocks.  A trace is a +// single entry, multiple exit, region of code that is often hot.  Trace-based +// optimizations treat traces almost like they are a large, strange, basic +// block: because the trace path is assumed to be hot, optimizations for the +// fall-through path are made at the expense of the non-fall-through paths. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/Trace.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +Function *Trace::getFunction() const { +  return getEntryBasicBlock()->getParent(); +} + +Module *Trace::getModule() const { +  return getFunction()->getParent(); +} + +/// print - Write trace to output stream. 
+void Trace::print(raw_ostream &O) const { +  Function *F = getFunction(); +  O << "; Trace from function " << F->getName() << ", blocks:\n"; +  for (const_iterator i = begin(), e = end(); i != e; ++i) { +    O << "; "; +    (*i)->printAsOperand(O, true, getModule()); +    O << "\n"; +  } +  O << "; Trace parent function: \n" << *F; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +/// dump - Debugger convenience method; writes trace to standard error +/// output stream. +LLVM_DUMP_METHOD void Trace::dump() const { +  print(dbgs()); +} +#endif diff --git a/contrib/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp b/contrib/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp new file mode 100644 index 000000000000..25a154edf4ac --- /dev/null +++ b/contrib/llvm/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -0,0 +1,740 @@ +//===- TypeBasedAliasAnalysis.cpp - Type-Based Alias Analysis -------------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the TypeBasedAliasAnalysis pass, which implements +// metadata-based TBAA. +// +// In LLVM IR, memory does not have types, so LLVM's own type system is not +// suitable for doing TBAA. Instead, metadata is added to the IR to describe +// a type system of a higher level language. This can be used to implement +// typical C/C++ TBAA, but it can also be used to implement custom alias +// analysis behavior for other languages. +// +// We now support two types of metadata format: scalar TBAA and struct-path +// aware TBAA. After all testing cases are upgraded to use struct-path aware +// TBAA and we can auto-upgrade existing bc files, the support for scalar TBAA +// can be dropped. +// +// The scalar TBAA metadata format is very simple. TBAA MDNodes have up to +// three fields, e.g.: +//   !0 = !{ !"an example type tree" } +//   !1 = !{ !"int", !0 } +//   !2 = !{ !"float", !0 } +//   !3 = !{ !"const float", !2, i64 1 } +// +// The first field is an identity field. It can be any value, usually +// an MDString, which uniquely identifies the type. The most important +// name in the tree is the name of the root node. Two trees with +// different root node names are entirely disjoint, even if they +// have leaves with common names. +// +// The second field identifies the type's parent node in the tree, or +// is null or omitted for a root node. A type is considered to alias +// all of its descendants and all of its ancestors in the tree. Also, +// a type is considered to alias all types in other trees, so that +// bitcode produced from multiple front-ends is handled conservatively. +// +// If the third field is present, it's an integer which if equal to 1 +// indicates that the type is "constant" (meaning pointsToConstantMemory +// should return true; see +// http://llvm.org/docs/AliasAnalysis.html#OtherItfs). +// +// With struct-path aware TBAA, the MDNodes attached to an instruction using +// "!tbaa" are called path tag nodes. +// +// The path tag node has 4 fields with the last field being optional. +// +// The first field is the base type node, it can be a struct type node +// or a scalar type node. The second field is the access type node, it +// must be a scalar type node. The third field is the offset into the base type. 
+// The last field has the same meaning as the last field of our scalar TBAA: +// it's an integer which if equal to 1 indicates that the access is "constant". +// +// The struct type node has a name and a list of pairs, one pair for each member +// of the struct. The first element of each pair is a type node (a struct type +// node or a scalar type node), specifying the type of the member, the second +// element of each pair is the offset of the member. +// +// Given an example +// typedef struct { +//   short s; +// } A; +// typedef struct { +//   uint16_t s; +//   A a; +// } B; +// +// For an access to B.a.s, we attach !5 (a path tag node) to the load/store +// instruction. The base type is !4 (struct B), the access type is !2 (scalar +// type short) and the offset is 4. +// +// !0 = !{!"Simple C/C++ TBAA"} +// !1 = !{!"omnipotent char", !0} // Scalar type node +// !2 = !{!"short", !1}           // Scalar type node +// !3 = !{!"A", !2, i64 0}        // Struct type node +// !4 = !{!"B", !2, i64 0, !3, i64 4} +//                                                           // Struct type node +// !5 = !{!4, !2, i64 4}          // Path tag node +// +// The struct type nodes and the scalar type nodes form a type DAG. +//         Root (!0) +//         char (!1)  -- edge to Root +//         short (!2) -- edge to char +//         A (!3) -- edge with offset 0 to short +//         B (!4) -- edge with offset 0 to short and edge with offset 4 to A +// +// To check if two tags (tagX and tagY) can alias, we start from the base type +// of tagX, follow the edge with the correct offset in the type DAG and adjust +// the offset until we reach the base type of tagY or until we reach the Root +// node. +// If we reach the base type of tagY, compare the adjusted offset with +// offset of tagY, return Alias if the offsets are the same, return NoAlias +// otherwise. +// If we reach the Root node, perform the above starting from base type of tagY +// to see if we reach base type of tagX. +// +// If they have different roots, they're part of different potentially +// unrelated type systems, so we return Alias to be conservative. +// If neither node is an ancestor of the other and they have the same root, +// then we say NoAlias. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" +#include <cassert> +#include <cstdint> + +using namespace llvm; + +// A handy option for disabling TBAA functionality. The same effect can also be +// achieved by stripping the !tbaa tags from IR, but this option is sometimes +// more convenient. +static cl::opt<bool> EnableTBAA("enable-tbaa", cl::init(true), cl::Hidden); + +namespace { + +/// isNewFormatTypeNode - Return true iff the given type node is in the new +/// size-aware format. +static bool isNewFormatTypeNode(const MDNode *N) { +  if (N->getNumOperands() < 3) +    return false; +  // In the old format the first operand is a string. 
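+  // Roughly (the node shapes below are paraphrased for illustration, not
+  // copied from a test):
+  //   old format scalar type node:  !1 = !{!"int", !0}
+  //   new format scalar type node:  !1 = !{!0, i64 4, !"int"}
+  // i.e. a leading MDNode operand (the parent type) marks the new format.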
+  if (!isa<MDNode>(N->getOperand(0))) +    return false; +  return true; +} + +/// This is a simple wrapper around an MDNode which provides a higher-level +/// interface by hiding the details of how alias analysis information is encoded +/// in its operands. +template<typename MDNodeTy> +class TBAANodeImpl { +  MDNodeTy *Node = nullptr; + +public: +  TBAANodeImpl() = default; +  explicit TBAANodeImpl(MDNodeTy *N) : Node(N) {} + +  /// getNode - Get the MDNode for this TBAANode. +  MDNodeTy *getNode() const { return Node; } + +  /// isNewFormat - Return true iff the wrapped type node is in the new +  /// size-aware format. +  bool isNewFormat() const { return isNewFormatTypeNode(Node); } + +  /// getParent - Get this TBAANode's Alias tree parent. +  TBAANodeImpl<MDNodeTy> getParent() const { +    if (isNewFormat()) +      return TBAANodeImpl(cast<MDNodeTy>(Node->getOperand(0))); + +    if (Node->getNumOperands() < 2) +      return TBAANodeImpl<MDNodeTy>(); +    MDNodeTy *P = dyn_cast_or_null<MDNodeTy>(Node->getOperand(1)); +    if (!P) +      return TBAANodeImpl<MDNodeTy>(); +    // Ok, this node has a valid parent. Return it. +    return TBAANodeImpl<MDNodeTy>(P); +  } + +  /// Test if this TBAANode represents a type for objects which are +  /// not modified (by any means) in the context where this +  /// AliasAnalysis is relevant. +  bool isTypeImmutable() const { +    if (Node->getNumOperands() < 3) +      return false; +    ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(2)); +    if (!CI) +      return false; +    return CI->getValue()[0]; +  } +}; + +/// \name Specializations of \c TBAANodeImpl for const and non const qualified +/// \c MDNode. +/// @{ +using TBAANode = TBAANodeImpl<const MDNode>; +using MutableTBAANode = TBAANodeImpl<MDNode>; +/// @} + +/// This is a simple wrapper around an MDNode which provides a +/// higher-level interface by hiding the details of how alias analysis +/// information is encoded in its operands. +template<typename MDNodeTy> +class TBAAStructTagNodeImpl { +  /// This node should be created with createTBAAAccessTag(). +  MDNodeTy *Node; + +public: +  explicit TBAAStructTagNodeImpl(MDNodeTy *N) : Node(N) {} + +  /// Get the MDNode for this TBAAStructTagNode. +  MDNodeTy *getNode() const { return Node; } + +  /// isNewFormat - Return true iff the wrapped access tag is in the new +  /// size-aware format. +  bool isNewFormat() const { +    if (Node->getNumOperands() < 4) +      return false; +    if (MDNodeTy *AccessType = getAccessType()) +      if (!TBAANodeImpl<MDNodeTy>(AccessType).isNewFormat()) +        return false; +    return true; +  } + +  MDNodeTy *getBaseType() const { +    return dyn_cast_or_null<MDNode>(Node->getOperand(0)); +  } + +  MDNodeTy *getAccessType() const { +    return dyn_cast_or_null<MDNode>(Node->getOperand(1)); +  } + +  uint64_t getOffset() const { +    return mdconst::extract<ConstantInt>(Node->getOperand(2))->getZExtValue(); +  } + +  uint64_t getSize() const { +    if (!isNewFormat()) +      return UINT64_MAX; +    return mdconst::extract<ConstantInt>(Node->getOperand(3))->getZExtValue(); +  } + +  /// Test if this TBAAStructTagNode represents a type for objects +  /// which are not modified (by any means) in the context where this +  /// AliasAnalysis is relevant. +  bool isTypeImmutable() const { +    unsigned OpNo = isNewFormat() ? 
4 : 3; +    if (Node->getNumOperands() < OpNo + 1) +      return false; +    ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(OpNo)); +    if (!CI) +      return false; +    return CI->getValue()[0]; +  } +}; + +/// \name Specializations of \c TBAAStructTagNodeImpl for const and non const +/// qualified \c MDNods. +/// @{ +using TBAAStructTagNode = TBAAStructTagNodeImpl<const MDNode>; +using MutableTBAAStructTagNode = TBAAStructTagNodeImpl<MDNode>; +/// @} + +/// This is a simple wrapper around an MDNode which provides a +/// higher-level interface by hiding the details of how alias analysis +/// information is encoded in its operands. +class TBAAStructTypeNode { +  /// This node should be created with createTBAATypeNode(). +  const MDNode *Node = nullptr; + +public: +  TBAAStructTypeNode() = default; +  explicit TBAAStructTypeNode(const MDNode *N) : Node(N) {} + +  /// Get the MDNode for this TBAAStructTypeNode. +  const MDNode *getNode() const { return Node; } + +  /// isNewFormat - Return true iff the wrapped type node is in the new +  /// size-aware format. +  bool isNewFormat() const { return isNewFormatTypeNode(Node); } + +  bool operator==(const TBAAStructTypeNode &Other) const { +    return getNode() == Other.getNode(); +  } + +  /// getId - Return type identifier. +  Metadata *getId() const { +    return Node->getOperand(isNewFormat() ? 2 : 0); +  } + +  unsigned getNumFields() const { +    unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; +    unsigned NumOpsPerField = isNewFormat() ? 3 : 2; +    return (getNode()->getNumOperands() - FirstFieldOpNo) / NumOpsPerField; +  } + +  TBAAStructTypeNode getFieldType(unsigned FieldIndex) const { +    unsigned FirstFieldOpNo = isNewFormat() ? 3 : 1; +    unsigned NumOpsPerField = isNewFormat() ? 3 : 2; +    unsigned OpIndex = FirstFieldOpNo + FieldIndex * NumOpsPerField; +    auto *TypeNode = cast<MDNode>(getNode()->getOperand(OpIndex)); +    return TBAAStructTypeNode(TypeNode); +  } + +  /// Get this TBAAStructTypeNode's field in the type DAG with +  /// given offset. Update the offset to be relative to the field type. +  TBAAStructTypeNode getField(uint64_t &Offset) const { +    bool NewFormat = isNewFormat(); +    if (NewFormat) { +      // New-format root and scalar type nodes have no fields. +      if (Node->getNumOperands() < 6) +        return TBAAStructTypeNode(); +    } else { +      // Parent can be omitted for the root node. +      if (Node->getNumOperands() < 2) +        return TBAAStructTypeNode(); + +      // Fast path for a scalar type node and a struct type node with a single +      // field. +      if (Node->getNumOperands() <= 3) { +        uint64_t Cur = Node->getNumOperands() == 2 +                           ? 0 +                           : mdconst::extract<ConstantInt>(Node->getOperand(2)) +                                 ->getZExtValue(); +        Offset -= Cur; +        MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(1)); +        if (!P) +          return TBAAStructTypeNode(); +        return TBAAStructTypeNode(P); +      } +    } + +    // Assume the offsets are in order. We return the previous field if +    // the current offset is bigger than the given offset. +    unsigned FirstFieldOpNo = NewFormat ? 3 : 1; +    unsigned NumOpsPerField = NewFormat ? 
3 : 2; +    unsigned TheIdx = 0; +    for (unsigned Idx = FirstFieldOpNo; Idx < Node->getNumOperands(); +         Idx += NumOpsPerField) { +      uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(Idx + 1)) +                         ->getZExtValue(); +      if (Cur > Offset) { +        assert(Idx >= FirstFieldOpNo + NumOpsPerField && +               "TBAAStructTypeNode::getField should have an offset match!"); +        TheIdx = Idx - NumOpsPerField; +        break; +      } +    } +    // Move along the last field. +    if (TheIdx == 0) +      TheIdx = Node->getNumOperands() - NumOpsPerField; +    uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(TheIdx + 1)) +                       ->getZExtValue(); +    Offset -= Cur; +    MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(TheIdx)); +    if (!P) +      return TBAAStructTypeNode(); +    return TBAAStructTypeNode(P); +  } +}; + +} // end anonymous namespace + +/// Check the first operand of the tbaa tag node, if it is a MDNode, we treat +/// it as struct-path aware TBAA format, otherwise, we treat it as scalar TBAA +/// format. +static bool isStructPathTBAA(const MDNode *MD) { +  // Anonymous TBAA root starts with a MDNode and dragonegg uses it as +  // a TBAA tag. +  return isa<MDNode>(MD->getOperand(0)) && MD->getNumOperands() >= 3; +} + +AliasResult TypeBasedAAResult::alias(const MemoryLocation &LocA, +                                     const MemoryLocation &LocB) { +  if (!EnableTBAA) +    return AAResultBase::alias(LocA, LocB); + +  // If accesses may alias, chain to the next AliasAnalysis. +  if (Aliases(LocA.AATags.TBAA, LocB.AATags.TBAA)) +    return AAResultBase::alias(LocA, LocB); + +  // Otherwise return a definitive result. +  return NoAlias; +} + +bool TypeBasedAAResult::pointsToConstantMemory(const MemoryLocation &Loc, +                                               bool OrLocal) { +  if (!EnableTBAA) +    return AAResultBase::pointsToConstantMemory(Loc, OrLocal); + +  const MDNode *M = Loc.AATags.TBAA; +  if (!M) +    return AAResultBase::pointsToConstantMemory(Loc, OrLocal); + +  // If this is an "immutable" type, we can assume the pointer is pointing +  // to constant memory. +  if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) || +      (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable())) +    return true; + +  return AAResultBase::pointsToConstantMemory(Loc, OrLocal); +} + +FunctionModRefBehavior +TypeBasedAAResult::getModRefBehavior(ImmutableCallSite CS) { +  if (!EnableTBAA) +    return AAResultBase::getModRefBehavior(CS); + +  FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; + +  // If this is an "immutable" type, we can assume the call doesn't write +  // to memory. +  if (const MDNode *M = CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) +    if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) || +        (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable())) +      Min = FMRB_OnlyReadsMemory; + +  return FunctionModRefBehavior(AAResultBase::getModRefBehavior(CS) & Min); +} + +FunctionModRefBehavior TypeBasedAAResult::getModRefBehavior(const Function *F) { +  // Functions don't have metadata. Just chain to the next implementation. 
+  return AAResultBase::getModRefBehavior(F); +} + +ModRefInfo TypeBasedAAResult::getModRefInfo(ImmutableCallSite CS, +                                            const MemoryLocation &Loc) { +  if (!EnableTBAA) +    return AAResultBase::getModRefInfo(CS, Loc); + +  if (const MDNode *L = Loc.AATags.TBAA) +    if (const MDNode *M = +            CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) +      if (!Aliases(L, M)) +        return ModRefInfo::NoModRef; + +  return AAResultBase::getModRefInfo(CS, Loc); +} + +ModRefInfo TypeBasedAAResult::getModRefInfo(ImmutableCallSite CS1, +                                            ImmutableCallSite CS2) { +  if (!EnableTBAA) +    return AAResultBase::getModRefInfo(CS1, CS2); + +  if (const MDNode *M1 = +          CS1.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) +    if (const MDNode *M2 = +            CS2.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) +      if (!Aliases(M1, M2)) +        return ModRefInfo::NoModRef; + +  return AAResultBase::getModRefInfo(CS1, CS2); +} + +bool MDNode::isTBAAVtableAccess() const { +  if (!isStructPathTBAA(this)) { +    if (getNumOperands() < 1) +      return false; +    if (MDString *Tag1 = dyn_cast<MDString>(getOperand(0))) { +      if (Tag1->getString() == "vtable pointer") +        return true; +    } +    return false; +  } + +  // For struct-path aware TBAA, we use the access type of the tag. +  TBAAStructTagNode Tag(this); +  TBAAStructTypeNode AccessType(Tag.getAccessType()); +  if(auto *Id = dyn_cast<MDString>(AccessType.getId())) +    if (Id->getString() == "vtable pointer") +      return true; +  return false; +} + +static bool matchAccessTags(const MDNode *A, const MDNode *B, +                            const MDNode **GenericTag = nullptr); + +MDNode *MDNode::getMostGenericTBAA(MDNode *A, MDNode *B) { +  const MDNode *GenericTag; +  matchAccessTags(A, B, &GenericTag); +  return const_cast<MDNode*>(GenericTag); +} + +static const MDNode *getLeastCommonType(const MDNode *A, const MDNode *B) { +  if (!A || !B) +    return nullptr; + +  if (A == B) +    return A; + +  SmallSetVector<const MDNode *, 4> PathA; +  TBAANode TA(A); +  while (TA.getNode()) { +    if (PathA.count(TA.getNode())) +      report_fatal_error("Cycle found in TBAA metadata."); +    PathA.insert(TA.getNode()); +    TA = TA.getParent(); +  } + +  SmallSetVector<const MDNode *, 4> PathB; +  TBAANode TB(B); +  while (TB.getNode()) { +    if (PathB.count(TB.getNode())) +      report_fatal_error("Cycle found in TBAA metadata."); +    PathB.insert(TB.getNode()); +    TB = TB.getParent(); +  } + +  int IA = PathA.size() - 1; +  int IB = PathB.size() - 1; + +  const MDNode *Ret = nullptr; +  while (IA >= 0 && IB >= 0) { +    if (PathA[IA] == PathB[IB]) +      Ret = PathA[IA]; +    else +      break; +    --IA; +    --IB; +  } + +  return Ret; +} + +void Instruction::getAAMetadata(AAMDNodes &N, bool Merge) const { +  if (Merge) +    N.TBAA = +        MDNode::getMostGenericTBAA(N.TBAA, getMetadata(LLVMContext::MD_tbaa)); +  else +    N.TBAA = getMetadata(LLVMContext::MD_tbaa); + +  if (Merge) +    N.Scope = MDNode::getMostGenericAliasScope( +        N.Scope, getMetadata(LLVMContext::MD_alias_scope)); +  else +    N.Scope = getMetadata(LLVMContext::MD_alias_scope); + +  if (Merge) +    N.NoAlias = +        MDNode::intersect(N.NoAlias, getMetadata(LLVMContext::MD_noalias)); +  else +    N.NoAlias = getMetadata(LLVMContext::MD_noalias); +} + +static const MDNode *createAccessTag(const MDNode *AccessType) { +  // If there is no 
access type or the access type is the root node, then +  // we don't have any useful access tag to return. +  if (!AccessType || AccessType->getNumOperands() < 2) +    return nullptr; + +  Type *Int64 = IntegerType::get(AccessType->getContext(), 64); +  auto *OffsetNode = ConstantAsMetadata::get(ConstantInt::get(Int64, 0)); + +  if (TBAAStructTypeNode(AccessType).isNewFormat()) { +    // TODO: Take access ranges into account when matching access tags and +    // fix this code to generate actual access sizes for generic tags. +    uint64_t AccessSize = UINT64_MAX; +    auto *SizeNode = +        ConstantAsMetadata::get(ConstantInt::get(Int64, AccessSize)); +    Metadata *Ops[] = {const_cast<MDNode*>(AccessType), +                       const_cast<MDNode*>(AccessType), +                       OffsetNode, SizeNode}; +    return MDNode::get(AccessType->getContext(), Ops); +  } + +  Metadata *Ops[] = {const_cast<MDNode*>(AccessType), +                     const_cast<MDNode*>(AccessType), +                     OffsetNode}; +  return MDNode::get(AccessType->getContext(), Ops); +} + +static bool hasField(TBAAStructTypeNode BaseType, +                     TBAAStructTypeNode FieldType) { +  for (unsigned I = 0, E = BaseType.getNumFields(); I != E; ++I) { +    TBAAStructTypeNode T = BaseType.getFieldType(I); +    if (T == FieldType || hasField(T, FieldType)) +      return true; +  } +  return false; +} + +/// Return true if for two given accesses, one of the accessed objects may be a +/// subobject of the other. The \p BaseTag and \p SubobjectTag parameters +/// describe the accesses to the base object and the subobject respectively. +/// \p CommonType must be the metadata node describing the common type of the +/// accessed objects. On return, \p MayAlias is set to true iff these accesses +/// may alias and \p Generic, if not null, points to the most generic access +/// tag for the given two. +static bool mayBeAccessToSubobjectOf(TBAAStructTagNode BaseTag, +                                     TBAAStructTagNode SubobjectTag, +                                     const MDNode *CommonType, +                                     const MDNode **GenericTag, +                                     bool &MayAlias) { +  // If the base object is of the least common type, then this may be an access +  // to its subobject. +  if (BaseTag.getAccessType() == BaseTag.getBaseType() && +      BaseTag.getAccessType() == CommonType) { +    if (GenericTag) +      *GenericTag = createAccessTag(CommonType); +    MayAlias = true; +    return true; +  } + +  // If the access to the base object is through a field of the subobject's +  // type, then this may be an access to that field. To check for that we start +  // from the base type, follow the edge with the correct offset in the type DAG +  // and adjust the offset until we reach the field type or until we reach the +  // access type. +  bool NewFormat = BaseTag.isNewFormat(); +  TBAAStructTypeNode BaseType(BaseTag.getBaseType()); +  uint64_t OffsetInBase = BaseTag.getOffset(); + +  for (;;) { +    // In the old format there is no distinction between fields and parent +    // types, so in this case we consider all nodes up to the root. +    if (!BaseType.getNode()) { +      assert(!NewFormat && "Did not see access type in access path!"); +      break; +    } + +    if (BaseType.getNode() == SubobjectTag.getBaseType()) { +      bool SameMemberAccess = OffsetInBase == SubobjectTag.getOffset(); +      if (GenericTag) { +        *GenericTag = SameMemberAccess ? 
SubobjectTag.getNode() : +                                         createAccessTag(CommonType); +      } +      MayAlias = SameMemberAccess; +      return true; +    } + +    // With new-format nodes we stop at the access type. +    if (NewFormat && BaseType.getNode() == BaseTag.getAccessType()) +      break; + +    // Follow the edge with the correct offset. Offset will be adjusted to +    // be relative to the field type. +    BaseType = BaseType.getField(OffsetInBase); +  } + +  // If the base object has a direct or indirect field of the subobject's type, +  // then this may be an access to that field. We need this to check now that +  // we support aggregates as access types. +  if (NewFormat) { +    // TBAAStructTypeNode BaseAccessType(BaseTag.getAccessType()); +    TBAAStructTypeNode FieldType(SubobjectTag.getBaseType()); +    if (hasField(BaseType, FieldType)) { +      if (GenericTag) +        *GenericTag = createAccessTag(CommonType); +      MayAlias = true; +      return true; +    } +  } + +  return false; +} + +/// matchTags - Return true if the given couple of accesses are allowed to +/// overlap. If \arg GenericTag is not null, then on return it points to the +/// most generic access descriptor for the given two. +static bool matchAccessTags(const MDNode *A, const MDNode *B, +                            const MDNode **GenericTag) { +  if (A == B) { +    if (GenericTag) +      *GenericTag = A; +    return true; +  } + +  // Accesses with no TBAA information may alias with any other accesses. +  if (!A || !B) { +    if (GenericTag) +      *GenericTag = nullptr; +    return true; +  } + +  // Verify that both input nodes are struct-path aware.  Auto-upgrade should +  // have taken care of this. +  assert(isStructPathTBAA(A) && "Access A is not struct-path aware!"); +  assert(isStructPathTBAA(B) && "Access B is not struct-path aware!"); + +  TBAAStructTagNode TagA(A), TagB(B); +  const MDNode *CommonType = getLeastCommonType(TagA.getAccessType(), +                                                TagB.getAccessType()); + +  // If the final access types have different roots, they're part of different +  // potentially unrelated type systems, so we must be conservative. +  if (!CommonType) { +    if (GenericTag) +      *GenericTag = nullptr; +    return true; +  } + +  // If one of the accessed objects may be a subobject of the other, then such +  // accesses may alias. +  bool MayAlias; +  if (mayBeAccessToSubobjectOf(/* BaseTag= */ TagA, /* SubobjectTag= */ TagB, +                               CommonType, GenericTag, MayAlias) || +      mayBeAccessToSubobjectOf(/* BaseTag= */ TagB, /* SubobjectTag= */ TagA, +                               CommonType, GenericTag, MayAlias)) +    return MayAlias; + +  // Otherwise, we've proved there's no alias. +  if (GenericTag) +    *GenericTag = createAccessTag(CommonType); +  return false; +} + +/// Aliases - Test whether the access represented by tag A may alias the +/// access represented by tag B. 
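+/// For illustration (reusing the examples from the file header comment rather
+/// than a specific test): two accesses whose access types are the unrelated
+/// scalar types "int" and "float" in the same tree come out as NoAlias, while
+/// an access to B.a.s and a direct access of type "short" share an access
+/// type and therefore may alias.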
+bool TypeBasedAAResult::Aliases(const MDNode *A, const MDNode *B) const { +  return matchAccessTags(A, B); +} + +AnalysisKey TypeBasedAA::Key; + +TypeBasedAAResult TypeBasedAA::run(Function &F, FunctionAnalysisManager &AM) { +  return TypeBasedAAResult(); +} + +char TypeBasedAAWrapperPass::ID = 0; +INITIALIZE_PASS(TypeBasedAAWrapperPass, "tbaa", "Type-Based Alias Analysis", +                false, true) + +ImmutablePass *llvm::createTypeBasedAAWrapperPass() { +  return new TypeBasedAAWrapperPass(); +} + +TypeBasedAAWrapperPass::TypeBasedAAWrapperPass() : ImmutablePass(ID) { +  initializeTypeBasedAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool TypeBasedAAWrapperPass::doInitialization(Module &M) { +  Result.reset(new TypeBasedAAResult()); +  return false; +} + +bool TypeBasedAAWrapperPass::doFinalization(Module &M) { +  Result.reset(); +  return false; +} + +void TypeBasedAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { +  AU.setPreservesAll(); +} diff --git a/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp b/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp new file mode 100644 index 000000000000..6871e4887c9e --- /dev/null +++ b/contrib/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -0,0 +1,118 @@ +//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains functions that make it easier to manipulate type metadata +// for devirtualization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" + +using namespace llvm; + +// Search for virtual calls that call FPtr and add them to DevirtCalls. +static void +findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, +                          bool *HasNonCallUses, Value *FPtr, uint64_t Offset) { +  for (const Use &U : FPtr->uses()) { +    Value *User = U.getUser(); +    if (isa<BitCastInst>(User)) { +      findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset); +    } else if (auto CI = dyn_cast<CallInst>(User)) { +      DevirtCalls.push_back({Offset, CI}); +    } else if (auto II = dyn_cast<InvokeInst>(User)) { +      DevirtCalls.push_back({Offset, II}); +    } else if (HasNonCallUses) { +      *HasNonCallUses = true; +    } +  } +} + +// Search for virtual calls that load from VPtr and add them to DevirtCalls. +static void +findLoadCallsAtConstantOffset(const Module *M, +                              SmallVectorImpl<DevirtCallSite> &DevirtCalls, +                              Value *VPtr, int64_t Offset) { +  for (const Use &U : VPtr->uses()) { +    Value *User = U.getUser(); +    if (isa<BitCastInst>(User)) { +      findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset); +    } else if (isa<LoadInst>(User)) { +      findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset); +    } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { +      // Take into account the GEP offset. 
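+      // (Illustrative example with hypothetical value names: for
+      //  `%slot = getelementptr i8, i8* %vtable, i64 8`, GEPOffset is 8, so
+      //  the search continues at Offset + 8.)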
+      if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { +        SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end()); +        int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType( +            GEP->getSourceElementType(), Indices); +        findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset); +      } +    } +  } +} + +void llvm::findDevirtualizableCallsForTypeTest( +    SmallVectorImpl<DevirtCallSite> &DevirtCalls, +    SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI) { +  assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); + +  const Module *M = CI->getParent()->getParent()->getParent(); + +  // Find llvm.assume intrinsics for this llvm.type.test call. +  for (const Use &CIU : CI->uses()) { +    if (auto *AssumeCI = dyn_cast<CallInst>(CIU.getUser())) { +      Function *F = AssumeCI->getCalledFunction(); +      if (F && F->getIntrinsicID() == Intrinsic::assume) +        Assumes.push_back(AssumeCI); +    } +  } + +  // If we found any, search for virtual calls based on %p and add them to +  // DevirtCalls. +  if (!Assumes.empty()) +    findLoadCallsAtConstantOffset(M, DevirtCalls, +                                  CI->getArgOperand(0)->stripPointerCasts(), 0); +} + +void llvm::findDevirtualizableCallsForTypeCheckedLoad( +    SmallVectorImpl<DevirtCallSite> &DevirtCalls, +    SmallVectorImpl<Instruction *> &LoadedPtrs, +    SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, +    const CallInst *CI) { +  assert(CI->getCalledFunction()->getIntrinsicID() == +         Intrinsic::type_checked_load); + +  auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); +  if (!Offset) { +    HasNonCallUses = true; +    return; +  } + +  for (const Use &U : CI->uses()) { +    auto CIU = U.getUser(); +    if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) { +      if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) { +        LoadedPtrs.push_back(EVI); +        continue; +      } +      if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) { +        Preds.push_back(EVI); +        continue; +      } +    } +    HasNonCallUses = true; +  } + +  for (Value *LoadedPtr : LoadedPtrs) +    findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr, +                              Offset->getZExtValue()); +} diff --git a/contrib/llvm/lib/Analysis/ValueLattice.cpp b/contrib/llvm/lib/Analysis/ValueLattice.cpp new file mode 100644 index 000000000000..7de437ca480e --- /dev/null +++ b/contrib/llvm/lib/Analysis/ValueLattice.cpp @@ -0,0 +1,26 @@ +//===- ValueLattice.cpp - Value constraint analysis -------------*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ValueLattice.h" + +namespace llvm { +raw_ostream &operator<<(raw_ostream &OS, const ValueLatticeElement &Val) { +  if (Val.isUndefined()) +    return OS << "undefined"; +  if (Val.isOverdefined()) +    return OS << "overdefined"; + +  if (Val.isNotConstant()) +    return OS << "notconstant<" << *Val.getNotConstant() << ">"; +  if (Val.isConstantRange()) +    return OS << "constantrange<" << Val.getConstantRange().getLower() << ", " +              << Val.getConstantRange().getUpper() << ">"; +  return OS << "constant<" << *Val.getConstant() << ">"; +} +} // end namespace llvm diff --git a/contrib/llvm/lib/Analysis/ValueLatticeUtils.cpp b/contrib/llvm/lib/Analysis/ValueLatticeUtils.cpp new file mode 100644 index 000000000000..22c9de4fe94d --- /dev/null +++ b/contrib/llvm/lib/Analysis/ValueLatticeUtils.cpp @@ -0,0 +1,44 @@ +//===-- ValueLatticeUtils.cpp - Utils for solving lattices ------*- C++ -*-===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements common functions useful for performing data-flow +// analyses that propagate values across function boundaries. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instructions.h" +using namespace llvm; + +bool llvm::canTrackArgumentsInterprocedurally(Function *F) { +  return F->hasLocalLinkage() && !F->hasAddressTaken(); +} + +bool llvm::canTrackReturnsInterprocedurally(Function *F) { +  return F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked); +} + +bool llvm::canTrackGlobalVariableInterprocedurally(GlobalVariable *GV) { +  if (GV->isConstant() || !GV->hasLocalLinkage() || +      !GV->hasDefinitiveInitializer()) +    return false; +  return !any_of(GV->users(), [&](User *U) { +    if (auto *Store = dyn_cast<StoreInst>(U)) { +      if (Store->getValueOperand() == GV || Store->isVolatile()) +        return true; +    } else if (auto *Load = dyn_cast<LoadInst>(U)) { +      if (Load->isVolatile()) +        return true; +    } else { +      return true; +    } +    return false; +  }); +} diff --git a/contrib/llvm/lib/Analysis/ValueTracking.cpp b/contrib/llvm/lib/Analysis/ValueTracking.cpp new file mode 100644 index 000000000000..edd46c5fe362 --- /dev/null +++ b/contrib/llvm/lib/Analysis/ValueTracking.cpp @@ -0,0 +1,5135 @@ +//===- ValueTracking.cpp - Walk computations to compute properties --------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains routines that help analyze properties that chains of +// computations have. 
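The printer above enumerates the states of the value lattice used by interprocedural propagation: undefined, constant, notconstant, constantrange, and overdefined. As a rough illustration of how such states combine when two paths meet, here is a minimal three-level sketch in plain C++ (hypothetical names; the notconstant and range states are omitted).

// Minimal constant-propagation lattice: Undefined < Constant(c) < Overdefined.
struct ToyLattice {
  enum State { Undefined, Constant, Overdefined };
  State St;
  int Val;  // meaningful only when St == Constant
};

// Join of two lattice values at a control-flow merge point: agreeing
// constants stay constant, disagreeing ones fall to Overdefined.
static ToyLattice merge(ToyLattice A, ToyLattice B) {
  if (A.St == ToyLattice::Undefined) return B;
  if (B.St == ToyLattice::Undefined) return A;
  if (A.St == ToyLattice::Constant && B.St == ToyLattice::Constant &&
      A.Val == B.Val)
    return A;
  return {ToyLattice::Overdefined, 0};
}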
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" +#include <algorithm> +#include <array> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; +using namespace llvm::PatternMatch; + +const unsigned MaxDepth = 6; + +// Controls the number of uses of the value searched for possible +// dominating comparisons. +static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses", +                                              cl::Hidden, cl::init(20)); + +/// Returns the bitwidth of the given scalar or pointer type. For vector types, +/// returns the element type's bitwidth. +static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { +  if (unsigned BitWidth = Ty->getScalarSizeInBits()) +    return BitWidth; + +  return DL.getIndexTypeSizeInBits(Ty); +} + +namespace { + +// Simplifying using an assume can only be done in a particular control-flow +// context (the context instruction provides that context). If an assume and +// the context instruction are not in the same block then the DT helps in +// figuring out if we can use it. +struct Query { +  const DataLayout &DL; +  AssumptionCache *AC; +  const Instruction *CxtI; +  const DominatorTree *DT; + +  // Unlike the other analyses, this may be a nullptr because not all clients +  // provide it currently. +  OptimizationRemarkEmitter *ORE; + +  /// Set of assumptions that should be excluded from further queries. +  /// This is because of the potential for mutual recursion to cause +  /// computeKnownBits to repeatedly visit the same assume intrinsic. 
The +  /// classic case of this is assume(x = y), which will attempt to determine +  /// bits in x from bits in y, which will attempt to determine bits in y from +  /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call +  /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo +  /// (all of which can call computeKnownBits), and so on. +  std::array<const Value *, MaxDepth> Excluded; + +  unsigned NumExcluded = 0; + +  Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, +        const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr) +      : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE) {} + +  Query(const Query &Q, const Value *NewExcl) +      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), +        NumExcluded(Q.NumExcluded) { +    Excluded = Q.Excluded; +    Excluded[NumExcluded++] = NewExcl; +    assert(NumExcluded <= Excluded.size()); +  } + +  bool isExcluded(const Value *Value) const { +    if (NumExcluded == 0) +      return false; +    auto End = Excluded.begin() + NumExcluded; +    return std::find(Excluded.begin(), End, Value) != End; +  } +}; + +} // end anonymous namespace + +// Given the provided Value and, potentially, a context instruction, return +// the preferred context instruction (if any). +static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { +  // If we've been provided with a context instruction, then use that (provided +  // it has been inserted). +  if (CxtI && CxtI->getParent()) +    return CxtI; + +  // If the value is really an already-inserted instruction, then use that. +  CxtI = dyn_cast<Instruction>(V); +  if (CxtI && CxtI->getParent()) +    return CxtI; + +  return nullptr; +} + +static void computeKnownBits(const Value *V, KnownBits &Known, +                             unsigned Depth, const Query &Q); + +void llvm::computeKnownBits(const Value *V, KnownBits &Known, +                            const DataLayout &DL, unsigned Depth, +                            AssumptionCache *AC, const Instruction *CxtI, +                            const DominatorTree *DT, +                            OptimizationRemarkEmitter *ORE) { +  ::computeKnownBits(V, Known, Depth, +                     Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); +} + +static KnownBits computeKnownBits(const Value *V, unsigned Depth, +                                  const Query &Q); + +KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, +                                 unsigned Depth, AssumptionCache *AC, +                                 const Instruction *CxtI, +                                 const DominatorTree *DT, +                                 OptimizationRemarkEmitter *ORE) { +  return ::computeKnownBits(V, Depth, +                            Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); +} + +bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, +                               const DataLayout &DL, +                               AssumptionCache *AC, const Instruction *CxtI, +                               const DominatorTree *DT) { +  assert(LHS->getType() == RHS->getType() && +         "LHS and RHS should have the same type"); +  assert(LHS->getType()->isIntOrIntVectorTy() && +         "LHS and RHS should be integers"); +  // Look for an inverted mask: (X & ~M) op (Y & M). 
+  Value *M;
+  if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
+      match(RHS, m_c_And(m_Specific(M), m_Value())))
+    return true;
+  if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) &&
+      match(LHS, m_c_And(m_Specific(M), m_Value())))
+    return true;
+  IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
+  KnownBits LHSKnown(IT->getBitWidth());
+  KnownBits RHSKnown(IT->getBitWidth());
+  computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT);
+  computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT);
+  return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue();
+}
+
+bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI) {
+  for (const User *U : CxtI->users()) {
+    if (const ICmpInst *IC = dyn_cast<ICmpInst>(U))
+      if (IC->isEquality())
+        if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
+          if (C->isNullValue())
+            continue;
+    return false;
+  }
+  return true;
+}
+
+static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
+                                   const Query &Q);
+
+bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
+                                  bool OrZero,
+                                  unsigned Depth, AssumptionCache *AC,
+                                  const Instruction *CxtI,
+                                  const DominatorTree *DT) {
+  return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth,
+                                  Query(DL, AC, safeCxtI(V, CxtI), DT));
+}
+
+static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q);
+
+bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth,
+                          AssumptionCache *AC, const Instruction *CxtI,
+                          const DominatorTree *DT) {
+  return ::isKnownNonZero(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT));
+}
+
+bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL,
+                              unsigned Depth,
+                              AssumptionCache *AC, const Instruction *CxtI,
+                              const DominatorTree *DT) {
+  KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT);
+  return Known.isNonNegative();
+}
+
+bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth,
+                           AssumptionCache *AC, const Instruction *CxtI,
+                           const DominatorTree *DT) {
+  if (auto *CI = dyn_cast<ConstantInt>(V))
+    return CI->getValue().isStrictlyPositive();
+
+  // TODO: We're doing two recursive queries here.  We should factor this such
+  // that only a single query is needed.
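The KnownBits fallback in haveNoCommonBitsSet above reduces the question to a single mask test: two values can have no common set bits only if every bit position is known to be zero on at least one side. A standalone sketch over fixed-width masks (plain C++, not the KnownBits class):

#include <cstdint>

// Known-zero masks for two 8-bit values: bit i set means "bit i is
// definitely 0 in that value". The values share no set bits if every
// position is covered by at least one of the masks.
static bool haveNoCommonBits8(uint8_t KnownZeroLHS, uint8_t KnownZeroRHS) {
  return static_cast<uint8_t>(KnownZeroLHS | KnownZeroRHS) == 0xFF;
}

// The inverted-mask idiom (X & 0xF0, Y & 0x0F) gives known-zero masks of
// 0x0F and 0xF0 respectively, so haveNoCommonBits8(0x0F, 0xF0) is true.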
+  return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT) && +    isKnownNonZero(V, DL, Depth, AC, CxtI, DT); +} + +bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth, +                           AssumptionCache *AC, const Instruction *CxtI, +                           const DominatorTree *DT) { +  KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); +  return Known.isNegative(); +} + +static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q); + +bool llvm::isKnownNonEqual(const Value *V1, const Value *V2, +                           const DataLayout &DL, +                           AssumptionCache *AC, const Instruction *CxtI, +                           const DominatorTree *DT) { +  return ::isKnownNonEqual(V1, V2, Query(DL, AC, +                                         safeCxtI(V1, safeCxtI(V2, CxtI)), +                                         DT)); +} + +static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, +                              const Query &Q); + +bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask, +                             const DataLayout &DL, +                             unsigned Depth, AssumptionCache *AC, +                             const Instruction *CxtI, const DominatorTree *DT) { +  return ::MaskedValueIsZero(V, Mask, Depth, +                             Query(DL, AC, safeCxtI(V, CxtI), DT)); +} + +static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, +                                   const Query &Q); + +unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL, +                                  unsigned Depth, AssumptionCache *AC, +                                  const Instruction *CxtI, +                                  const DominatorTree *DT) { +  return ::ComputeNumSignBits(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); +} + +static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, +                                   bool NSW, +                                   KnownBits &KnownOut, KnownBits &Known2, +                                   unsigned Depth, const Query &Q) { +  unsigned BitWidth = KnownOut.getBitWidth(); + +  // If an initial sequence of bits in the result is not needed, the +  // corresponding bits in the operands are not needed. +  KnownBits LHSKnown(BitWidth); +  computeKnownBits(Op0, LHSKnown, Depth + 1, Q); +  computeKnownBits(Op1, Known2, Depth + 1, Q); + +  KnownOut = KnownBits::computeForAddSub(Add, NSW, LHSKnown, Known2); +} + +static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, +                                KnownBits &Known, KnownBits &Known2, +                                unsigned Depth, const Query &Q) { +  unsigned BitWidth = Known.getBitWidth(); +  computeKnownBits(Op1, Known, Depth + 1, Q); +  computeKnownBits(Op0, Known2, Depth + 1, Q); + +  bool isKnownNegative = false; +  bool isKnownNonNegative = false; +  // If the multiplication is known not to overflow, compute the sign bit. +  if (NSW) { +    if (Op0 == Op1) { +      // The product of a number with itself is non-negative. +      isKnownNonNegative = true; +    } else { +      bool isKnownNonNegativeOp1 = Known.isNonNegative(); +      bool isKnownNonNegativeOp0 = Known2.isNonNegative(); +      bool isKnownNegativeOp1 = Known.isNegative(); +      bool isKnownNegativeOp0 = Known2.isNegative(); +      // The product of two numbers with the same sign is non-negative. 
+      isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) || +        (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); +      // The product of a negative number and a non-negative number is either +      // negative or zero. +      if (!isKnownNonNegative) +        isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && +                           isKnownNonZero(Op0, Depth, Q)) || +                          (isKnownNegativeOp0 && isKnownNonNegativeOp1 && +                           isKnownNonZero(Op1, Depth, Q)); +    } +  } + +  assert(!Known.hasConflict() && !Known2.hasConflict()); +  // Compute a conservative estimate for high known-0 bits. +  unsigned LeadZ =  std::max(Known.countMinLeadingZeros() + +                             Known2.countMinLeadingZeros(), +                             BitWidth) - BitWidth; +  LeadZ = std::min(LeadZ, BitWidth); + +  // The result of the bottom bits of an integer multiply can be +  // inferred by looking at the bottom bits of both operands and +  // multiplying them together. +  // We can infer at least the minimum number of known trailing bits +  // of both operands. Depending on number of trailing zeros, we can +  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming +  // a and b are divisible by m and n respectively. +  // We then calculate how many of those bits are inferrable and set +  // the output. For example, the i8 mul: +  //  a = XXXX1100 (12) +  //  b = XXXX1110 (14) +  // We know the bottom 3 bits are zero since the first can be divided by +  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). +  // Applying the multiplication to the trimmed arguments gets: +  //    XX11 (3) +  //    X111 (7) +  // ------- +  //    XX11 +  //   XX11 +  //  XX11 +  // XX11 +  // ------- +  // XXXXX01 +  // Which allows us to infer the 2 LSBs. Since we're multiplying the result +  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. +  // The proof for this can be described as: +  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && +  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) + +  //                    umin(countTrailingZeros(C2), C6) + +  //                    umin(C5 - umin(countTrailingZeros(C1), C5), +  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1) +  // %aa = shl i8 %a, C5 +  // %bb = shl i8 %b, C6 +  // %aaa = or i8 %aa, C1 +  // %bbb = or i8 %bb, C2 +  // %mul = mul i8 %aaa, %bbb +  // %mask = and i8 %mul, C7 +  //   => +  // %mask = i8 ((C1*C2)&C7) +  // Where C5, C6 describe the known bits of %a, %b +  // C1, C2 describe the known bottom bits of %a, %b. +  // C7 describes the mask of the known bits of the result. +  APInt Bottom0 = Known.One; +  APInt Bottom1 = Known2.One; + +  // How many times we'd be able to divide each argument by 2 (shr by 1). +  // This gives us the number of trailing zeros on the multiplication result. +  unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes(); +  unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes(); +  unsigned TrailZero0 = Known.countMinTrailingZeros(); +  unsigned TrailZero1 = Known2.countMinTrailingZeros(); +  unsigned TrailZ = TrailZero0 + TrailZero1; + +  // Figure out the fewest known-bits operand. 
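The worked i8 example above generalizes: if the low KnownA bits of one operand and the low KnownB bits of the other are fully known, then the low min(KnownA - tzA, KnownB - tzB) + tzA + tzB bits of the product are determined by multiplying just the known low parts. A standalone numeric sketch of that rule, mirroring the computation that follows (plain C++, hypothetical helper names):

#include <algorithm>
#include <cstdint>

struct LowBits { unsigned NumKnown; uint32_t Value; };

static unsigned countTrailingZeros32(uint32_t V) {
  unsigned N = 0;
  while (N < 32 && !((V >> N) & 1))
    ++N;
  return N;
}

// LowA/LowB hold the fully-known low KnownA/KnownB bits of each operand.
static LowBits knownLowBitsOfProduct(uint32_t LowA, unsigned KnownA,
                                     uint32_t LowB, unsigned KnownB) {
  unsigned TzA = std::min(countTrailingZeros32(LowA), KnownA);
  unsigned TzB = std::min(countTrailingZeros32(LowB), KnownB);
  unsigned TrailZ = TzA + TzB;
  unsigned Smallest = std::min(KnownA - TzA, KnownB - TzB);
  unsigned NumKnown = std::min(Smallest + TrailZ, 32u);
  uint32_t Product = LowA * LowB;
  uint32_t Mask = NumKnown >= 32 ? ~0u : ((1u << NumKnown) - 1);
  return {NumKnown, Product & Mask};
}

// The i8 case from the comment: knownLowBitsOfProduct(0xC, 4, 0xE, 4)
// yields 5 known low bits with value 0x08 (12 * 14 = 168 = 0b10101000).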
+  unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0, +                                      TrailBitsKnown1 - TrailZero1); +  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); + +  APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) * +                      Bottom1.getLoBits(TrailBitsKnown1); + +  Known.resetAll(); +  Known.Zero.setHighBits(LeadZ); +  Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); +  Known.One |= BottomKnown.getLoBits(ResultBitsKnown); + +  // Only make use of no-wrap flags if we failed to compute the sign bit +  // directly.  This matters if the multiplication always overflows, in +  // which case we prefer to follow the result of the direct computation, +  // though as the program is invoking undefined behaviour we can choose +  // whatever we like here. +  if (isKnownNonNegative && !Known.isNegative()) +    Known.makeNonNegative(); +  else if (isKnownNegative && !Known.isNonNegative()) +    Known.makeNegative(); +} + +void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, +                                             KnownBits &Known) { +  unsigned BitWidth = Known.getBitWidth(); +  unsigned NumRanges = Ranges.getNumOperands() / 2; +  assert(NumRanges >= 1); + +  Known.Zero.setAllBits(); +  Known.One.setAllBits(); + +  for (unsigned i = 0; i < NumRanges; ++i) { +    ConstantInt *Lower = +        mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0)); +    ConstantInt *Upper = +        mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1)); +    ConstantRange Range(Lower->getValue(), Upper->getValue()); + +    // The first CommonPrefixBits of all values in Range are equal. +    unsigned CommonPrefixBits = +        (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countLeadingZeros(); + +    APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits); +    Known.One &= Range.getUnsignedMax() & Mask; +    Known.Zero &= ~Range.getUnsignedMax() & Mask; +  } +} + +static bool isEphemeralValueOf(const Instruction *I, const Value *E) { +  SmallVector<const Value *, 16> WorkSet(1, I); +  SmallPtrSet<const Value *, 32> Visited; +  SmallPtrSet<const Value *, 16> EphValues; + +  // The instruction defining an assumption's condition itself is always +  // considered ephemeral to that assumption (even if it has other +  // non-ephemeral users). See r246696's test case for an example. +  if (is_contained(I->operands(), E)) +    return true; + +  while (!WorkSet.empty()) { +    const Value *V = WorkSet.pop_back_val(); +    if (!Visited.insert(V).second) +      continue; + +    // If all uses of this value are ephemeral, then so is this value. +    if (llvm::all_of(V->users(), [&](const User *U) { +                                   return EphValues.count(U); +                                 })) { +      if (V == E) +        return true; + +      if (V == I || isSafeToSpeculativelyExecute(V)) { +       EphValues.insert(V); +       if (const User *U = dyn_cast<User>(V)) +         for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); +              J != JE; ++J) +           WorkSet.push_back(*J); +      } +    } +  } + +  return false; +} + +// Is this an intrinsic that cannot be speculated but also cannot trap? +bool llvm::isAssumeLikeIntrinsic(const Instruction *I) { +  if (const CallInst *CI = dyn_cast<CallInst>(I)) +    if (Function *F = CI->getCalledFunction()) +      switch (F->getIntrinsicID()) { +      default: break; +      // FIXME: This list is repeated from NoTTI::getIntrinsicCost. 
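computeKnownBitsFromRangeMetadata above relies on a small fact: all values in a range agree on every bit above the highest bit in which the range's minimum and maximum differ, so counting leading zeros of min ^ max gives the number of known high bits. A standalone sketch for a 32-bit inclusive range (plain C++, hypothetical name):

#include <cstdint>

static unsigned commonPrefixBits32(uint32_t Lo, uint32_t HiInclusive) {
  uint32_t Diff = Lo ^ HiInclusive;
  unsigned N = 0;
  for (int Bit = 31; Bit >= 0 && !((Diff >> Bit) & 1); --Bit)
    ++N;  // count high bits on which Lo, Hi, and everything between agree
  return N;
}

// e.g. the range [0x80, 0x9F] has Lo ^ Hi = 0x1F, so the 27 high bits of
// every value in the range are known (24 zeros followed by 100).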
+      case Intrinsic::assume: +      case Intrinsic::sideeffect: +      case Intrinsic::dbg_declare: +      case Intrinsic::dbg_value: +      case Intrinsic::dbg_label: +      case Intrinsic::invariant_start: +      case Intrinsic::invariant_end: +      case Intrinsic::lifetime_start: +      case Intrinsic::lifetime_end: +      case Intrinsic::objectsize: +      case Intrinsic::ptr_annotation: +      case Intrinsic::var_annotation: +        return true; +      } + +  return false; +} + +bool llvm::isValidAssumeForContext(const Instruction *Inv, +                                   const Instruction *CxtI, +                                   const DominatorTree *DT) { +  // There are two restrictions on the use of an assume: +  //  1. The assume must dominate the context (or the control flow must +  //     reach the assume whenever it reaches the context). +  //  2. The context must not be in the assume's set of ephemeral values +  //     (otherwise we will use the assume to prove that the condition +  //     feeding the assume is trivially true, thus causing the removal of +  //     the assume). + +  if (DT) { +    if (DT->dominates(Inv, CxtI)) +      return true; +  } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor()) { +    // We don't have a DT, but this trivially dominates. +    return true; +  } + +  // With or without a DT, the only remaining case we will check is if the +  // instructions are in the same BB.  Give up if that is not the case. +  if (Inv->getParent() != CxtI->getParent()) +    return false; + +  // If we have a dom tree, then we now know that the assume doesn't dominate +  // the other instruction.  If we don't have a dom tree then we can check if +  // the assume is first in the BB. +  if (!DT) { +    // Search forward from the assume until we reach the context (or the end +    // of the block); the common case is that the assume will come first. +    for (auto I = std::next(BasicBlock::const_iterator(Inv)), +         IE = Inv->getParent()->end(); I != IE; ++I) +      if (&*I == CxtI) +        return true; +  } + +  // The context comes first, but they're both in the same block. Make sure +  // there is nothing in between that might interrupt the control flow. +  for (BasicBlock::const_iterator I = +         std::next(BasicBlock::const_iterator(CxtI)), IE(Inv); +       I != IE; ++I) +    if (!isSafeToSpeculativelyExecute(&*I) && !isAssumeLikeIntrinsic(&*I)) +      return false; + +  return !isEphemeralValueOf(Inv, CxtI); +} + +static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, +                                       unsigned Depth, const Query &Q) { +  // Use of assumptions is context-sensitive. If we don't have a context, we +  // cannot use them! +  if (!Q.AC || !Q.CxtI) +    return; + +  unsigned BitWidth = Known.getBitWidth(); + +  // Note that the patterns below need to be kept in sync with the code +  // in AssumptionCache::updateAffectedValues. + +  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) { +    if (!AssumeVH) +      continue; +    CallInst *I = cast<CallInst>(AssumeVH); +    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() && +           "Got assumption for the wrong function!"); +    if (Q.isExcluded(I)) +      continue; + +    // Warning: This loop can end up being somewhat performance sensitive. +    // We're running this loop for once for each value queried resulting in a +    // runtime of ~O(#assumes * #values). 
+ +    assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && +           "must be an assume intrinsic"); + +    Value *Arg = I->getArgOperand(0); + +    if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      assert(BitWidth == 1 && "assume operand is not i1?"); +      Known.setAllOnes(); +      return; +    } +    if (match(Arg, m_Not(m_Specific(V))) && +        isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      assert(BitWidth == 1 && "assume operand is not i1?"); +      Known.setAllZero(); +      return; +    } + +    // The remaining tests are all recursive, so bail out if we hit the limit. +    if (Depth == MaxDepth) +      continue; + +    Value *A, *B; +    auto m_V = m_CombineOr(m_Specific(V), +                           m_CombineOr(m_PtrToInt(m_Specific(V)), +                           m_BitCast(m_Specific(V)))); + +    CmpInst::Predicate Pred; +    uint64_t C; +    // assume(v = a) +    if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && +        Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      Known.Zero |= RHSKnown.Zero; +      Known.One  |= RHSKnown.One; +    // assume(v & b = a) +    } else if (match(Arg, +                     m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      KnownBits MaskKnown(BitWidth); +      computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + +      // For those bits in the mask that are known to be one, we can propagate +      // known bits from the RHS to V. +      Known.Zero |= RHSKnown.Zero & MaskKnown.One; +      Known.One  |= RHSKnown.One  & MaskKnown.One; +    // assume(~(v & b) = a) +    } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), +                                   m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      KnownBits MaskKnown(BitWidth); +      computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + +      // For those bits in the mask that are known to be one, we can propagate +      // inverted known bits from the RHS to V. +      Known.Zero |= RHSKnown.One  & MaskKnown.One; +      Known.One  |= RHSKnown.Zero & MaskKnown.One; +    // assume(v | b = a) +    } else if (match(Arg, +                     m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      KnownBits BKnown(BitWidth); +      computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + +      // For those bits in B that are known to be zero, we can propagate known +      // bits from the RHS to V. 
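The assume((v & b) = a) branch above, and the assume((v | b) = a) branch that continues below, implement the same transfer: wherever the mask pins the combined bit to V's own bit, A's known bits at those positions carry over to V. A standalone sketch over 8-bit known-bit masks (plain C++, not the KnownBits class):

#include <cstdint>

struct Known8 { uint8_t Zero, One; };  // known-zero / known-one masks

// From assume((V & B) == A): where B is known 1, (V & B) equals V's bit.
static void propagateAndEquality(Known8 &V, Known8 A, Known8 B) {
  V.Zero |= A.Zero & B.One;
  V.One  |= A.One  & B.One;
}

// From assume((V | B) == A): where B is known 0, (V | B) equals V's bit.
static void propagateOrEquality(Known8 &V, Known8 A, Known8 B) {
  V.Zero |= A.Zero & B.Zero;
  V.One  |= A.One  & B.Zero;
}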
+      Known.Zero |= RHSKnown.Zero & BKnown.Zero; +      Known.One  |= RHSKnown.One  & BKnown.Zero; +    // assume(~(v | b) = a) +    } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), +                                   m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      KnownBits BKnown(BitWidth); +      computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + +      // For those bits in B that are known to be zero, we can propagate +      // inverted known bits from the RHS to V. +      Known.Zero |= RHSKnown.One  & BKnown.Zero; +      Known.One  |= RHSKnown.Zero & BKnown.Zero; +    // assume(v ^ b = a) +    } else if (match(Arg, +                     m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      KnownBits BKnown(BitWidth); +      computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + +      // For those bits in B that are known to be zero, we can propagate known +      // bits from the RHS to V. For those bits in B that are known to be one, +      // we can propagate inverted known bits from the RHS to V. +      Known.Zero |= RHSKnown.Zero & BKnown.Zero; +      Known.One  |= RHSKnown.One  & BKnown.Zero; +      Known.Zero |= RHSKnown.One  & BKnown.One; +      Known.One  |= RHSKnown.Zero & BKnown.One; +    // assume(~(v ^ b) = a) +    } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), +                                   m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      KnownBits BKnown(BitWidth); +      computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + +      // For those bits in B that are known to be zero, we can propagate +      // inverted known bits from the RHS to V. For those bits in B that are +      // known to be one, we can propagate known bits from the RHS to V. +      Known.Zero |= RHSKnown.One  & BKnown.Zero; +      Known.One  |= RHSKnown.Zero & BKnown.Zero; +      Known.Zero |= RHSKnown.Zero & BKnown.One; +      Known.One  |= RHSKnown.One  & BKnown.One; +    // assume(v << c = a) +    } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), +                                   m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT) && +               C < BitWidth) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      // For those bits in RHS that are known, we can propagate them to known +      // bits in V shifted to the right by C. 
+      RHSKnown.Zero.lshrInPlace(C); +      Known.Zero |= RHSKnown.Zero; +      RHSKnown.One.lshrInPlace(C); +      Known.One  |= RHSKnown.One; +    // assume(~(v << c) = a) +    } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), +                                   m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT) && +               C < BitWidth) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      // For those bits in RHS that are known, we can propagate them inverted +      // to known bits in V shifted to the right by C. +      RHSKnown.One.lshrInPlace(C); +      Known.Zero |= RHSKnown.One; +      RHSKnown.Zero.lshrInPlace(C); +      Known.One  |= RHSKnown.Zero; +    // assume(v >> c = a) +    } else if (match(Arg, +                     m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), +                              m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT) && +               C < BitWidth) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      // For those bits in RHS that are known, we can propagate them to known +      // bits in V shifted to the right by C. +      Known.Zero |= RHSKnown.Zero << C; +      Known.One  |= RHSKnown.One  << C; +    // assume(~(v >> c) = a) +    } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), +                                   m_Value(A))) && +               Pred == ICmpInst::ICMP_EQ && +               isValidAssumeForContext(I, Q.CxtI, Q.DT) && +               C < BitWidth) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); +      // For those bits in RHS that are known, we can propagate them inverted +      // to known bits in V shifted to the right by C. +      Known.Zero |= RHSKnown.One  << C; +      Known.One  |= RHSKnown.Zero << C; +    // assume(v >=_s c) where c is non-negative +    } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && +               Pred == ICmpInst::ICMP_SGE && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + +      if (RHSKnown.isNonNegative()) { +        // We know that the sign bit is zero. +        Known.makeNonNegative(); +      } +    // assume(v >_s c) where c is at least -1. +    } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && +               Pred == ICmpInst::ICMP_SGT && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + +      if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { +        // We know that the sign bit is zero. +        Known.makeNonNegative(); +      } +    // assume(v <=_s c) where c is negative +    } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && +               Pred == ICmpInst::ICMP_SLE && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + +      if (RHSKnown.isNegative()) { +        // We know that the sign bit is one. 
+        Known.makeNegative(); +      } +    // assume(v <_s c) where c is non-positive +    } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && +               Pred == ICmpInst::ICMP_SLT && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + +      if (RHSKnown.isZero() || RHSKnown.isNegative()) { +        // We know that the sign bit is one. +        Known.makeNegative(); +      } +    // assume(v <=_u c) +    } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && +               Pred == ICmpInst::ICMP_ULE && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + +      // Whatever high bits in c are zero are known to be zero. +      Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); +      // assume(v <_u c) +    } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && +               Pred == ICmpInst::ICMP_ULT && +               isValidAssumeForContext(I, Q.CxtI, Q.DT)) { +      KnownBits RHSKnown(BitWidth); +      computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + +      // If the RHS is known zero, then this assumption must be wrong (nothing +      // is unsigned less than zero). Signal a conflict and get out of here. +      if (RHSKnown.isZero()) { +        Known.Zero.setAllBits(); +        Known.One.setAllBits(); +        break; +      } + +      // Whatever high bits in c are zero are known to be zero (if c is a power +      // of 2, then one more). +      if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) +        Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1); +      else +        Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); +    } +  } + +  // If assumptions conflict with each other or previous known bits, then we +  // have a logical fallacy. It's possible that the assumption is not reachable, +  // so this isn't a real bug. On the other hand, the program may have undefined +  // behavior, or we might have a bug in the compiler. We can't assert/crash, so +  // clear out the known bits, try to warn the user, and hope for the best. +  if (Known.Zero.intersects(Known.One)) { +    Known.resetAll(); + +    if (Q.ORE) +      Q.ORE->emit([&]() { +        auto *CxtI = const_cast<Instruction *>(Q.CxtI); +        return OptimizationRemarkAnalysis("value-tracking", "BadAssumption", +                                          CxtI) +               << "Detected conflicting code assumptions. Program may " +                  "have undefined behavior, or compiler may have " +                  "internal error."; +      }); +  } +} + +/// Compute known bits from a shift operator, including those with a +/// non-constant shift amount. Known is the output of this function. Known2 is a +/// pre-allocated temporary with the same bit width as Known. KZF and KOF are +/// operator-specific functions that, given the known-zero or known-one bits +/// respectively, and a shift amount, compute the implied known-zero or +/// known-one bits of the shift operator's result respectively for that shift +/// amount. The results from calling KZF and KOF are conservatively combined for +/// all permitted shift amounts. 
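Before the implementation, a standalone sketch of the combining step the comment describes: enumerate every shift amount compatible with the shift operand's known bits, apply the operator-specific transfer (here a left shift), and keep only the bits all candidates agree on. Plain C++ over 8-bit masks; the names are hypothetical, and the early exits and conflict handling of the real code are omitted.

#include <cstdint>

struct K8 { uint8_t Zero, One; };  // known-zero / known-one masks

static K8 knownShl8(K8 Val, K8 Amt) {
  K8 Result = {0xFF, 0xFF};          // start "everything known", then intersect
  for (unsigned S = 0; S < 8; ++S) {
    if (S & Amt.Zero)                // S has a 1 where the amount is known 0
      continue;
    if ((S | Amt.One) != S)          // S has a 0 where the amount is known 1
      continue;
    uint8_t KZ = static_cast<uint8_t>((Val.Zero << S) | ((1u << S) - 1));
    uint8_t KO = static_cast<uint8_t>(Val.One << S);
    Result.Zero &= KZ;               // keep only what every candidate implies
    Result.One &= KO;
  }
  return Result;
}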
+static void computeKnownBitsFromShiftOperator( +    const Operator *I, KnownBits &Known, KnownBits &Known2, +    unsigned Depth, const Query &Q, +    function_ref<APInt(const APInt &, unsigned)> KZF, +    function_ref<APInt(const APInt &, unsigned)> KOF) { +  unsigned BitWidth = Known.getBitWidth(); + +  if (auto *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { +    unsigned ShiftAmt = SA->getLimitedValue(BitWidth-1); + +    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +    Known.Zero = KZF(Known.Zero, ShiftAmt); +    Known.One  = KOF(Known.One, ShiftAmt); +    // If the known bits conflict, this must be an overflowing left shift, so +    // the shift result is poison. We can return anything we want. Choose 0 for +    // the best folding opportunity. +    if (Known.hasConflict()) +      Known.setAllZero(); + +    return; +  } + +  computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); + +  // If the shift amount could be greater than or equal to the bit-width of the +  // LHS, the value could be poison, but bail out because the check below is +  // expensive. TODO: Should we just carry on? +  if ((~Known.Zero).uge(BitWidth)) { +    Known.resetAll(); +    return; +  } + +  // Note: We cannot use Known.Zero.getLimitedValue() here, because if +  // BitWidth > 64 and any upper bits are known, we'll end up returning the +  // limit value (which implies all bits are known). +  uint64_t ShiftAmtKZ = Known.Zero.zextOrTrunc(64).getZExtValue(); +  uint64_t ShiftAmtKO = Known.One.zextOrTrunc(64).getZExtValue(); + +  // It would be more-clearly correct to use the two temporaries for this +  // calculation. Reusing the APInts here to prevent unnecessary allocations. +  Known.resetAll(); + +  // If we know the shifter operand is nonzero, we can sometimes infer more +  // known bits. However this is expensive to compute, so be lazy about it and +  // only compute it when absolutely necessary. +  Optional<bool> ShifterOperandIsNonZero; + +  // Early exit if we can't constrain any well-defined shift amount. +  if (!(ShiftAmtKZ & (PowerOf2Ceil(BitWidth) - 1)) && +      !(ShiftAmtKO & (PowerOf2Ceil(BitWidth) - 1))) { +    ShifterOperandIsNonZero = isKnownNonZero(I->getOperand(1), Depth + 1, Q); +    if (!*ShifterOperandIsNonZero) +      return; +  } + +  computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + +  Known.Zero.setAllBits(); +  Known.One.setAllBits(); +  for (unsigned ShiftAmt = 0; ShiftAmt < BitWidth; ++ShiftAmt) { +    // Combine the shifted known input bits only for those shift amounts +    // compatible with its known constraints. +    if ((ShiftAmt & ~ShiftAmtKZ) != ShiftAmt) +      continue; +    if ((ShiftAmt | ShiftAmtKO) != ShiftAmt) +      continue; +    // If we know the shifter is nonzero, we may be able to infer more known +    // bits. This check is sunk down as far as possible to avoid the expensive +    // call to isKnownNonZero if the cheaper checks above fail. +    if (ShiftAmt == 0) { +      if (!ShifterOperandIsNonZero.hasValue()) +        ShifterOperandIsNonZero = +            isKnownNonZero(I->getOperand(1), Depth + 1, Q); +      if (*ShifterOperandIsNonZero) +        continue; +    } + +    Known.Zero &= KZF(Known2.Zero, ShiftAmt); +    Known.One  &= KOF(Known2.One, ShiftAmt); +  } + +  // If the known bits conflict, the result is poison. Return a 0 and hope the +  // caller can further optimize that. 
+  if (Known.hasConflict()) +    Known.setAllZero(); +} + +static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, +                                         unsigned Depth, const Query &Q) { +  unsigned BitWidth = Known.getBitWidth(); + +  KnownBits Known2(Known); +  switch (I->getOpcode()) { +  default: break; +  case Instruction::Load: +    if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range)) +      computeKnownBitsFromRangeMetadata(*MD, Known); +    break; +  case Instruction::And: { +    // If either the LHS or the RHS are Zero, the result is zero. +    computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); +    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + +    // Output known-1 bits are only known if set in both the LHS & RHS. +    Known.One &= Known2.One; +    // Output known-0 are known to be clear if zero in either the LHS | RHS. +    Known.Zero |= Known2.Zero; + +    // and(x, add (x, -1)) is a common idiom that always clears the low bit; +    // here we handle the more general case of adding any odd number by +    // matching the form add(x, add(x, y)) where y is odd. +    // TODO: This could be generalized to clearing any bit set in y where the +    // following bit is known to be unset in y. +    Value *X = nullptr, *Y = nullptr; +    if (!Known.Zero[0] && !Known.One[0] && +        match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) { +      Known2.resetAll(); +      computeKnownBits(Y, Known2, Depth + 1, Q); +      if (Known2.countMinTrailingOnes() > 0) +        Known.Zero.setBit(0); +    } +    break; +  } +  case Instruction::Or: +    computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); +    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + +    // Output known-0 bits are only known if clear in both the LHS & RHS. +    Known.Zero &= Known2.Zero; +    // Output known-1 are known to be set if set in either the LHS | RHS. +    Known.One |= Known2.One; +    break; +  case Instruction::Xor: { +    computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); +    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + +    // Output known-0 bits are known if clear or set in both the LHS & RHS. +    APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One); +    // Output known-1 are known to be set if set in only one of the LHS, RHS. +    Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero); +    Known.Zero = std::move(KnownZeroOut); +    break; +  } +  case Instruction::Mul: { +    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); +    computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known, +                        Known2, Depth, Q); +    break; +  } +  case Instruction::UDiv: { +    // For the purposes of computing leading zeros we can conservatively +    // treat a udiv as a logical right shift by the power of 2 known to +    // be less than the denominator. 
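The comment above treats a udiv as a right shift by the largest power of two known to be at most the denominator; the code that follows turns this into arithmetic on leading-zero counts. A standalone sketch of that bound (plain C++, hypothetical name):

#include <algorithm>
#include <cstdint>

// NumMinLeadZ: minimum leading zeros of the numerator.
// DenMaxLeadZ: maximum possible leading zeros of the denominator; if it is
// less than BitWidth, some bit of the denominator is known to be one, so
// the denominator is at least 2^(BitWidth - DenMaxLeadZ - 1).
static unsigned knownLeadingZerosOfUDiv(unsigned BitWidth,
                                        unsigned NumMinLeadZ,
                                        unsigned DenMaxLeadZ) {
  unsigned LeadZ = NumMinLeadZ;
  if (DenMaxLeadZ != BitWidth)
    LeadZ = std::min(BitWidth, LeadZ + BitWidth - DenMaxLeadZ - 1);
  return LeadZ;
}

// e.g. a 32-bit udiv whose denominator is known to be at least 256
// (DenMaxLeadZ == 23) gains at least 8 leading zero bits.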
+    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +    unsigned LeadZ = Known2.countMinLeadingZeros(); + +    Known2.resetAll(); +    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); +    unsigned RHSMaxLeadingZeros = Known2.countMaxLeadingZeros(); +    if (RHSMaxLeadingZeros != BitWidth) +      LeadZ = std::min(BitWidth, LeadZ + BitWidth - RHSMaxLeadingZeros - 1); + +    Known.Zero.setHighBits(LeadZ); +    break; +  } +  case Instruction::Select: { +    const Value *LHS, *RHS; +    SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; +    if (SelectPatternResult::isMinOrMax(SPF)) { +      computeKnownBits(RHS, Known, Depth + 1, Q); +      computeKnownBits(LHS, Known2, Depth + 1, Q); +    } else { +      computeKnownBits(I->getOperand(2), Known, Depth + 1, Q); +      computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); +    } + +    unsigned MaxHighOnes = 0; +    unsigned MaxHighZeros = 0; +    if (SPF == SPF_SMAX) { +      // If both sides are negative, the result is negative. +      if (Known.isNegative() && Known2.isNegative()) +        // We can derive a lower bound on the result by taking the max of the +        // leading one bits. +        MaxHighOnes = +            std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); +      // If either side is non-negative, the result is non-negative. +      else if (Known.isNonNegative() || Known2.isNonNegative()) +        MaxHighZeros = 1; +    } else if (SPF == SPF_SMIN) { +      // If both sides are non-negative, the result is non-negative. +      if (Known.isNonNegative() && Known2.isNonNegative()) +        // We can derive an upper bound on the result by taking the max of the +        // leading zero bits. +        MaxHighZeros = std::max(Known.countMinLeadingZeros(), +                                Known2.countMinLeadingZeros()); +      // If either side is negative, the result is negative. +      else if (Known.isNegative() || Known2.isNegative()) +        MaxHighOnes = 1; +    } else if (SPF == SPF_UMAX) { +      // We can derive a lower bound on the result by taking the max of the +      // leading one bits. +      MaxHighOnes = +          std::max(Known.countMinLeadingOnes(), Known2.countMinLeadingOnes()); +    } else if (SPF == SPF_UMIN) { +      // We can derive an upper bound on the result by taking the max of the +      // leading zero bits. +      MaxHighZeros = +          std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); +    } else if (SPF == SPF_ABS) { +      // RHS from matchSelectPattern returns the negation part of abs pattern. +      // If the negate has an NSW flag we can assume the sign bit of the result +      // will be 0 because that makes abs(INT_MIN) undefined. +      if (cast<Instruction>(RHS)->hasNoSignedWrap()) +        MaxHighZeros = 1; +    } + +    // Only known if known in both the LHS and RHS. +    Known.One &= Known2.One; +    Known.Zero &= Known2.Zero; +    if (MaxHighOnes > 0) +      Known.One.setHighBits(MaxHighOnes); +    if (MaxHighZeros > 0) +      Known.Zero.setHighBits(MaxHighZeros); +    break; +  } +  case Instruction::FPTrunc: +  case Instruction::FPExt: +  case Instruction::FPToUI: +  case Instruction::FPToSI: +  case Instruction::SIToFP: +  case Instruction::UIToFP: +    break; // Can't work with floating point. +  case Instruction::PtrToInt: +  case Instruction::IntToPtr: +    // Fall through and handle them the same as zext/trunc. 
+    LLVM_FALLTHROUGH; +  case Instruction::ZExt: +  case Instruction::Trunc: { +    Type *SrcTy = I->getOperand(0)->getType(); + +    unsigned SrcBitWidth; +    // Note that we handle pointer operands here because of inttoptr/ptrtoint +    // which fall through here. +    Type *ScalarTy = SrcTy->getScalarType(); +    SrcBitWidth = ScalarTy->isPointerTy() ? +      Q.DL.getIndexTypeSizeInBits(ScalarTy) : +      Q.DL.getTypeSizeInBits(ScalarTy); + +    assert(SrcBitWidth && "SrcBitWidth can't be zero"); +    Known = Known.zextOrTrunc(SrcBitWidth); +    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +    Known = Known.zextOrTrunc(BitWidth); +    // Any top bits are known to be zero. +    if (BitWidth > SrcBitWidth) +      Known.Zero.setBitsFrom(SrcBitWidth); +    break; +  } +  case Instruction::BitCast: { +    Type *SrcTy = I->getOperand(0)->getType(); +    if (SrcTy->isIntOrPtrTy() && +        // TODO: For now, not handling conversions like: +        // (bitcast i64 %x to <2 x i32>) +        !I->getType()->isVectorTy()) { +      computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +      break; +    } +    break; +  } +  case Instruction::SExt: { +    // Compute the bits in the result that are not present in the input. +    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); + +    Known = Known.trunc(SrcBitWidth); +    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +    // If the sign bit of the input is known set or clear, then we know the +    // top bits of the result. +    Known = Known.sext(BitWidth); +    break; +  } +  case Instruction::Shl: { +    // (shl X, C1) & C2 == 0   iff   (X & C2 >>u C1) == 0 +    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); +    auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) { +      APInt KZResult = KnownZero << ShiftAmt; +      KZResult.setLowBits(ShiftAmt); // Low bits known 0. +      // If this shift has "nsw" keyword, then the result is either a poison +      // value or has the same sign bit as the first operand. +      if (NSW && KnownZero.isSignBitSet()) +        KZResult.setSignBit(); +      return KZResult; +    }; + +    auto KOF = [NSW](const APInt &KnownOne, unsigned ShiftAmt) { +      APInt KOResult = KnownOne << ShiftAmt; +      if (NSW && KnownOne.isSignBitSet()) +        KOResult.setSignBit(); +      return KOResult; +    }; + +    computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); +    break; +  } +  case Instruction::LShr: { +    // (lshr X, C1) & C2 == 0   iff  (-1 >> C1) & C2 == 0 +    auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { +      APInt KZResult = KnownZero.lshr(ShiftAmt); +      // High bits known zero. 
+      KZResult.setHighBits(ShiftAmt); +      return KZResult; +    }; + +    auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { +      return KnownOne.lshr(ShiftAmt); +    }; + +    computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); +    break; +  } +  case Instruction::AShr: { +    // (ashr X, C1) & C2 == 0   iff  (-1 >> C1) & C2 == 0 +    auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { +      return KnownZero.ashr(ShiftAmt); +    }; + +    auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) { +      return KnownOne.ashr(ShiftAmt); +    }; + +    computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF); +    break; +  } +  case Instruction::Sub: { +    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); +    computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, +                           Known, Known2, Depth, Q); +    break; +  } +  case Instruction::Add: { +    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); +    computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, +                           Known, Known2, Depth, Q); +    break; +  } +  case Instruction::SRem: +    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { +      APInt RA = Rem->getValue().abs(); +      if (RA.isPowerOf2()) { +        APInt LowBits = RA - 1; +        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + +        // The low bits of the first operand are unchanged by the srem. +        Known.Zero = Known2.Zero & LowBits; +        Known.One = Known2.One & LowBits; + +        // If the first operand is non-negative or has all low bits zero, then +        // the upper bits are all zero. +        if (Known2.isNonNegative() || LowBits.isSubsetOf(Known2.Zero)) +          Known.Zero |= ~LowBits; + +        // If the first operand is negative and not all low bits are zero, then +        // the upper bits are all one. +        if (Known2.isNegative() && LowBits.intersects(Known2.One)) +          Known.One |= ~LowBits; + +        assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); +        break; +      } +    } + +    // The sign bit is the LHS's sign bit, except when the result of the +    // remainder is zero. +    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +    // If it's known zero, our sign bit is also zero. +    if (Known2.isNonNegative()) +      Known.makeNonNegative(); + +    break; +  case Instruction::URem: { +    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { +      const APInt &RA = Rem->getValue(); +      if (RA.isPowerOf2()) { +        APInt LowBits = (RA - 1); +        computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +        Known.Zero |= ~LowBits; +        Known.One &= LowBits; +        break; +      } +    } + +    // Since the result is less than or equal to either operand, any leading +    // zero bits in either operand must also exist in the result. 
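For the power-of-two urem just above, the remainder is simply the low log2(RA) bits of the first operand, so every higher bit is known zero regardless of the operand. A standalone sketch (plain C++, hypothetical name):

#include <cstdint>

// Assumes RA is a nonzero power of two: x % RA == x & (RA - 1), so all
// bits above the low log2(RA) bits of the result are known zero.
static uint32_t knownZeroMaskForURemPow2(uint32_t RA) {
  uint32_t LowBits = RA - 1;  // bits that may survive the remainder
  return ~LowBits;            // bits known to be zero in x % RA
}

// e.g. x % 8 keeps only bits 2..0: knownZeroMaskForURemPow2(8) == ~7u.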
+    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +    computeKnownBits(I->getOperand(1), Known2, Depth + 1, Q); + +    unsigned Leaders = +        std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); +    Known.resetAll(); +    Known.Zero.setHighBits(Leaders); +    break; +  } + +  case Instruction::Alloca: { +    const AllocaInst *AI = cast<AllocaInst>(I); +    unsigned Align = AI->getAlignment(); +    if (Align == 0) +      Align = Q.DL.getABITypeAlignment(AI->getAllocatedType()); + +    if (Align > 0) +      Known.Zero.setLowBits(countTrailingZeros(Align)); +    break; +  } +  case Instruction::GetElementPtr: { +    // Analyze all of the subscripts of this getelementptr instruction +    // to determine if we can prove known low zero bits. +    KnownBits LocalKnown(BitWidth); +    computeKnownBits(I->getOperand(0), LocalKnown, Depth + 1, Q); +    unsigned TrailZ = LocalKnown.countMinTrailingZeros(); + +    gep_type_iterator GTI = gep_type_begin(I); +    for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) { +      Value *Index = I->getOperand(i); +      if (StructType *STy = GTI.getStructTypeOrNull()) { +        // Handle struct member offset arithmetic. + +        // Handle case when index is vector zeroinitializer +        Constant *CIndex = cast<Constant>(Index); +        if (CIndex->isZeroValue()) +          continue; + +        if (CIndex->getType()->isVectorTy()) +          Index = CIndex->getSplatValue(); + +        unsigned Idx = cast<ConstantInt>(Index)->getZExtValue(); +        const StructLayout *SL = Q.DL.getStructLayout(STy); +        uint64_t Offset = SL->getElementOffset(Idx); +        TrailZ = std::min<unsigned>(TrailZ, +                                    countTrailingZeros(Offset)); +      } else { +        // Handle array index arithmetic. +        Type *IndexedTy = GTI.getIndexedType(); +        if (!IndexedTy->isSized()) { +          TrailZ = 0; +          break; +        } +        unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); +        uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy); +        LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0); +        computeKnownBits(Index, LocalKnown, Depth + 1, Q); +        TrailZ = std::min(TrailZ, +                          unsigned(countTrailingZeros(TypeSize) + +                                   LocalKnown.countMinTrailingZeros())); +      } +    } + +    Known.Zero.setLowBits(TrailZ); +    break; +  } +  case Instruction::PHI: { +    const PHINode *P = cast<PHINode>(I); +    // Handle the case of a simple two-predecessor recurrence PHI. +    // There's a lot more that could theoretically be done here, but +    // this is sufficient to catch some interesting cases. +    if (P->getNumIncomingValues() == 2) { +      for (unsigned i = 0; i != 2; ++i) { +        Value *L = P->getIncomingValue(i); +        Value *R = P->getIncomingValue(!i); +        Operator *LU = dyn_cast<Operator>(L); +        if (!LU) +          continue; +        unsigned Opcode = LU->getOpcode(); +        // Check for operations that have the property that if +        // both their operands have low zero bits, the result +        // will have low zero bits. +        if (Opcode == Instruction::Add || +            Opcode == Instruction::Sub || +            Opcode == Instruction::And || +            Opcode == Instruction::Or || +            Opcode == Instruction::Mul) { +          Value *LL = LU->getOperand(0); +          Value *LR = LU->getOperand(1); +          // Find a recurrence. 
+          if (LL == I) +            L = LR; +          else if (LR == I) +            L = LL; +          else +            break; +          // Ok, we have a PHI of the form L op= R. Check for low +          // zero bits. +          computeKnownBits(R, Known2, Depth + 1, Q); + +          // We need to take the minimum number of known bits +          KnownBits Known3(Known); +          computeKnownBits(L, Known3, Depth + 1, Q); + +          Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(), +                                         Known3.countMinTrailingZeros())); + +          auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU); +          if (OverflowOp && OverflowOp->hasNoSignedWrap()) { +            // If initial value of recurrence is nonnegative, and we are adding +            // a nonnegative number with nsw, the result can only be nonnegative +            // or poison value regardless of the number of times we execute the +            // add in phi recurrence. If initial value is negative and we are +            // adding a negative number with nsw, the result can only be +            // negative or poison value. Similar arguments apply to sub and mul. +            // +            // (add non-negative, non-negative) --> non-negative +            // (add negative, negative) --> negative +            if (Opcode == Instruction::Add) { +              if (Known2.isNonNegative() && Known3.isNonNegative()) +                Known.makeNonNegative(); +              else if (Known2.isNegative() && Known3.isNegative()) +                Known.makeNegative(); +            } + +            // (sub nsw non-negative, negative) --> non-negative +            // (sub nsw negative, non-negative) --> negative +            else if (Opcode == Instruction::Sub && LL == I) { +              if (Known2.isNonNegative() && Known3.isNegative()) +                Known.makeNonNegative(); +              else if (Known2.isNegative() && Known3.isNonNegative()) +                Known.makeNegative(); +            } + +            // (mul nsw non-negative, non-negative) --> non-negative +            else if (Opcode == Instruction::Mul && Known2.isNonNegative() && +                     Known3.isNonNegative()) +              Known.makeNonNegative(); +          } + +          break; +        } +      } +    } + +    // Unreachable blocks may have zero-operand PHI nodes. +    if (P->getNumIncomingValues() == 0) +      break; + +    // Otherwise take the unions of the known bit sets of the operands, +    // taking conservative care to avoid excessive recursion. +    if (Depth < MaxDepth - 1 && !Known.Zero && !Known.One) { +      // Skip if every incoming value references to ourself. +      if (dyn_cast_or_null<UndefValue>(P->hasConstantValue())) +        break; + +      Known.Zero.setAllBits(); +      Known.One.setAllBits(); +      for (Value *IncValue : P->incoming_values()) { +        // Skip direct self references. +        if (IncValue == P) continue; + +        Known2 = KnownBits(BitWidth); +        // Recurse, but cap the recursion to one level, because we don't +        // want to waste time spinning around in loops. +        computeKnownBits(IncValue, Known2, MaxDepth - 1, Q); +        Known.Zero &= Known2.Zero; +        Known.One &= Known2.One; +        // If all bits have been ruled out, there's no need to check +        // more operands. 
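+        // (The intersection can only lose information, so once both masks
+        // are empty no remaining operand can change the result.)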
+        if (!Known.Zero && !Known.One) +          break; +      } +    } +    break; +  } +  case Instruction::Call: +  case Instruction::Invoke: +    // If range metadata is attached to this call, set known bits from that, +    // and then intersect with known bits based on other properties of the +    // function. +    if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range)) +      computeKnownBitsFromRangeMetadata(*MD, Known); +    if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) { +      computeKnownBits(RV, Known2, Depth + 1, Q); +      Known.Zero |= Known2.Zero; +      Known.One |= Known2.One; +    } +    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { +      switch (II->getIntrinsicID()) { +      default: break; +      case Intrinsic::bitreverse: +        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +        Known.Zero |= Known2.Zero.reverseBits(); +        Known.One |= Known2.One.reverseBits(); +        break; +      case Intrinsic::bswap: +        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +        Known.Zero |= Known2.Zero.byteSwap(); +        Known.One |= Known2.One.byteSwap(); +        break; +      case Intrinsic::ctlz: { +        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +        // If we have a known 1, its position is our upper bound. +        unsigned PossibleLZ = Known2.One.countLeadingZeros(); +        // If this call is undefined for 0, the result will be less than 2^n. +        if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) +          PossibleLZ = std::min(PossibleLZ, BitWidth - 1); +        unsigned LowBits = Log2_32(PossibleLZ)+1; +        Known.Zero.setBitsFrom(LowBits); +        break; +      } +      case Intrinsic::cttz: { +        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +        // If we have a known 1, its position is our upper bound. +        unsigned PossibleTZ = Known2.One.countTrailingZeros(); +        // If this call is undefined for 0, the result will be less than 2^n. +        if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext())) +          PossibleTZ = std::min(PossibleTZ, BitWidth - 1); +        unsigned LowBits = Log2_32(PossibleTZ)+1; +        Known.Zero.setBitsFrom(LowBits); +        break; +      } +      case Intrinsic::ctpop: { +        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); +        // We can bound the space the count needs.  Also, bits known to be zero +        // can't contribute to the population. +        unsigned BitsPossiblySet = Known2.countMaxPopulation(); +        unsigned LowBits = Log2_32(BitsPossiblySet)+1; +        Known.Zero.setBitsFrom(LowBits); +        // TODO: we could bound KnownOne using the lower bound on the number +        // of bits which might be set provided by popcnt KnownOne2. +        break; +      } +      case Intrinsic::x86_sse42_crc32_64_64: +        Known.Zero.setBitsFrom(32); +        break; +      } +    } +    break; +  case Instruction::ExtractElement: +    // Look through extract element. At the moment we keep this simple and skip +    // tracking the specific element. But at least we might find information +    // valid for all elements of the vector (for example if vector is sign +    // extended, shifted, etc). 
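+    // For example, if the vector operand is (and <4 x i32> %v, splat 255),
+    // every lane has its top 24 bits clear, so whichever element is
+    // extracted has them clear as well.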
+    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); +    break; +  case Instruction::ExtractValue: +    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) { +      const ExtractValueInst *EVI = cast<ExtractValueInst>(I); +      if (EVI->getNumIndices() != 1) break; +      if (EVI->getIndices()[0] == 0) { +        switch (II->getIntrinsicID()) { +        default: break; +        case Intrinsic::uadd_with_overflow: +        case Intrinsic::sadd_with_overflow: +          computeKnownBitsAddSub(true, II->getArgOperand(0), +                                 II->getArgOperand(1), false, Known, Known2, +                                 Depth, Q); +          break; +        case Intrinsic::usub_with_overflow: +        case Intrinsic::ssub_with_overflow: +          computeKnownBitsAddSub(false, II->getArgOperand(0), +                                 II->getArgOperand(1), false, Known, Known2, +                                 Depth, Q); +          break; +        case Intrinsic::umul_with_overflow: +        case Intrinsic::smul_with_overflow: +          computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, +                              Known, Known2, Depth, Q); +          break; +        } +      } +    } +  } +} + +/// Determine which bits of V are known to be either zero or one and return +/// them. +KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) { +  KnownBits Known(getBitWidth(V->getType(), Q.DL)); +  computeKnownBits(V, Known, Depth, Q); +  return Known; +} + +/// Determine which bits of V are known to be either zero or one and return +/// them in the Known bit set. +/// +/// NOTE: we cannot consider 'undef' to be "IsZero" here.  The problem is that +/// we cannot optimize based on the assumption that it is zero without changing +/// it to be an explicit zero.  If we don't change it to zero, other code could +/// optimized based on the contradictory assumption that it is non-zero. +/// Because instcombine aggressively folds operations with undef args anyway, +/// this won't lose us code quality. +/// +/// This function is defined on values with integer type, values with pointer +/// type, and vectors of integers.  In the case +/// where V is a vector, known zero, and known one values are the +/// same width as the vector element, and the bit is set only if it is true +/// for all of the elements in the vector. +void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, +                      const Query &Q) { +  assert(V && "No Value?"); +  assert(Depth <= MaxDepth && "Limit Search Depth"); +  unsigned BitWidth = Known.getBitWidth(); + +  assert((V->getType()->isIntOrIntVectorTy(BitWidth) || +          V->getType()->isPtrOrPtrVectorTy()) && +         "Not integer or pointer type!"); + +  Type *ScalarTy = V->getType()->getScalarType(); +  unsigned ExpectedWidth = ScalarTy->isPointerTy() ? +    Q.DL.getIndexTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy); +  assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth"); +  (void)BitWidth; +  (void)ExpectedWidth; + +  const APInt *C; +  if (match(V, m_APInt(C))) { +    // We know all of the bits for a scalar constant or a splat vector constant! +    Known.One = *C; +    Known.Zero = ~Known.One; +    return; +  } +  // Null and aggregate-zero are all-zeros. 
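+  // (Both are treated as an all-zero bit pattern here, independent of
+  // whether null is a valid address in the pointer's address space.)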
+  if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) { +    Known.setAllZero(); +    return; +  } +  // Handle a constant vector by taking the intersection of the known bits of +  // each element. +  if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) { +    // We know that CDS must be a vector of integers. Take the intersection of +    // each element. +    Known.Zero.setAllBits(); Known.One.setAllBits(); +    for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { +      APInt Elt = CDS->getElementAsAPInt(i); +      Known.Zero &= ~Elt; +      Known.One &= Elt; +    } +    return; +  } + +  if (const auto *CV = dyn_cast<ConstantVector>(V)) { +    // We know that CV must be a vector of integers. Take the intersection of +    // each element. +    Known.Zero.setAllBits(); Known.One.setAllBits(); +    for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { +      Constant *Element = CV->getAggregateElement(i); +      auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); +      if (!ElementCI) { +        Known.resetAll(); +        return; +      } +      const APInt &Elt = ElementCI->getValue(); +      Known.Zero &= ~Elt; +      Known.One &= Elt; +    } +    return; +  } + +  // Start out not knowing anything. +  Known.resetAll(); + +  // We can't imply anything about undefs. +  if (isa<UndefValue>(V)) +    return; + +  // There's no point in looking through other users of ConstantData for +  // assumptions.  Confirm that we've handled them all. +  assert(!isa<ConstantData>(V) && "Unhandled constant data!"); + +  // Limit search depth. +  // All recursive calls that increase depth must come after this. +  if (Depth == MaxDepth) +    return; + +  // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has +  // the bits of its aliasee. +  if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { +    if (!GA->isInterposable()) +      computeKnownBits(GA->getAliasee(), Known, Depth + 1, Q); +    return; +  } + +  if (const Operator *I = dyn_cast<Operator>(V)) +    computeKnownBitsFromOperator(I, Known, Depth, Q); + +  // Aligned pointers have trailing zeros - refine Known.Zero set +  if (V->getType()->isPointerTy()) { +    unsigned Align = V->getPointerAlignment(Q.DL); +    if (Align) +      Known.Zero.setLowBits(countTrailingZeros(Align)); +  } + +  // computeKnownBitsFromAssume strictly refines Known. +  // Therefore, we run them after computeKnownBitsFromOperator. + +  // Check whether a nearby assume intrinsic can determine some known bits. +  computeKnownBitsFromAssume(V, Known, Depth, Q); + +  assert((Known.Zero & Known.One) == 0 && "Bits known to be one AND zero?"); +} + +/// Return true if the given value is known to have exactly one +/// bit set when defined. For vectors return true if every element is known to +/// be a power of two when defined. Supports values with integer or pointer +/// types and vectors of integers. +bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, +                            const Query &Q) { +  assert(Depth <= MaxDepth && "Limit Search Depth"); + +  // Attempt to match against constants. +  if (OrZero && match(V, m_Power2OrZero())) +      return true; +  if (match(V, m_Power2())) +      return true; + +  // 1 << X is clearly a power of two if the one is not shifted off the end.  If +  // it is shifted off the end then the result is undefined. 
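+  // For example, in i8, 1 << 0 ... 1 << 7 give 1, 2, ..., 128, all powers of
+  // two; a shift amount of 8 or more is undefined, so it cannot invalidate
+  // the answer.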
+  if (match(V, m_Shl(m_One(), m_Value()))) +    return true; + +  // (signmask) >>l X is clearly a power of two if the one is not shifted off +  // the bottom.  If it is shifted off the bottom then the result is undefined. +  if (match(V, m_LShr(m_SignMask(), m_Value()))) +    return true; + +  // The remaining tests are all recursive, so bail out if we hit the limit. +  if (Depth++ == MaxDepth) +    return false; + +  Value *X = nullptr, *Y = nullptr; +  // A shift left or a logical shift right of a power of two is a power of two +  // or zero. +  if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) || +                 match(V, m_LShr(m_Value(X), m_Value())))) +    return isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q); + +  if (const ZExtInst *ZI = dyn_cast<ZExtInst>(V)) +    return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q); + +  if (const SelectInst *SI = dyn_cast<SelectInst>(V)) +    return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) && +           isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q); + +  if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) { +    // A power of two and'd with anything is a power of two or zero. +    if (isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q) || +        isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, Depth, Q)) +      return true; +    // X & (-X) is always a power of two or zero. +    if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X)))) +      return true; +    return false; +  } + +  // Adding a power-of-two or zero to the same power-of-two or zero yields +  // either the original power-of-two, a larger power-of-two or zero. +  if (match(V, m_Add(m_Value(X), m_Value(Y)))) { +    const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V); +    if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { +      if (match(X, m_And(m_Specific(Y), m_Value())) || +          match(X, m_And(m_Value(), m_Specific(Y)))) +        if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) +          return true; +      if (match(Y, m_And(m_Specific(X), m_Value())) || +          match(Y, m_And(m_Value(), m_Specific(X)))) +        if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q)) +          return true; + +      unsigned BitWidth = V->getType()->getScalarSizeInBits(); +      KnownBits LHSBits(BitWidth); +      computeKnownBits(X, LHSBits, Depth, Q); + +      KnownBits RHSBits(BitWidth); +      computeKnownBits(Y, RHSBits, Depth, Q); +      // If i8 V is a power of two or zero: +      //  ZeroBits: 1 1 1 0 1 1 1 1 +      // ~ZeroBits: 0 0 0 1 0 0 0 0 +      if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2()) +        // If OrZero isn't set, we cannot give back a zero result. +        // Make sure either the LHS or RHS has a bit set. +        if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue()) +          return true; +    } +  } + +  // An exact divide or right shift can only shift off zero bits, so the result +  // is a power of two only if the first operand is a power of two and not +  // copying a sign bit (sdiv int_min, 2). +  if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) || +      match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) { +    return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, +                                  Depth, Q); +  } + +  return false; +} + +/// Test whether a GEP's result is known to be non-null. +/// +/// Uses properties inherent in a GEP to try to determine whether it is known +/// to be non-null. 
+/// +/// Currently this routine does not support vector GEPs. +static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth, +                              const Query &Q) { +  const Function *F = nullptr; +  if (const Instruction *I = dyn_cast<Instruction>(GEP)) +    F = I->getFunction(); + +  if (!GEP->isInBounds() || +      NullPointerIsDefined(F, GEP->getPointerAddressSpace())) +    return false; + +  // FIXME: Support vector-GEPs. +  assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP"); + +  // If the base pointer is non-null, we cannot walk to a null address with an +  // inbounds GEP in address space zero. +  if (isKnownNonZero(GEP->getPointerOperand(), Depth, Q)) +    return true; + +  // Walk the GEP operands and see if any operand introduces a non-zero offset. +  // If so, then the GEP cannot produce a null pointer, as doing so would +  // inherently violate the inbounds contract within address space zero. +  for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP); +       GTI != GTE; ++GTI) { +    // Struct types are easy -- they must always be indexed by a constant. +    if (StructType *STy = GTI.getStructTypeOrNull()) { +      ConstantInt *OpC = cast<ConstantInt>(GTI.getOperand()); +      unsigned ElementIdx = OpC->getZExtValue(); +      const StructLayout *SL = Q.DL.getStructLayout(STy); +      uint64_t ElementOffset = SL->getElementOffset(ElementIdx); +      if (ElementOffset > 0) +        return true; +      continue; +    } + +    // If we have a zero-sized type, the index doesn't matter. Keep looping. +    if (Q.DL.getTypeAllocSize(GTI.getIndexedType()) == 0) +      continue; + +    // Fast path the constant operand case both for efficiency and so we don't +    // increment Depth when just zipping down an all-constant GEP. +    if (ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand())) { +      if (!OpC->isZero()) +        return true; +      continue; +    } + +    // We post-increment Depth here because while isKnownNonZero increments it +    // as well, when we pop back up that increment won't persist. We don't want +    // to recurse 10k times just because we have 10k GEP operands. We don't +    // bail completely out because we want to handle constant GEPs regardless +    // of depth. +    if (Depth++ >= MaxDepth) +      continue; + +    if (isKnownNonZero(GTI.getOperand(), Depth, Q)) +      return true; +  } + +  return false; +} + +static bool isKnownNonNullFromDominatingCondition(const Value *V, +                                                  const Instruction *CtxI, +                                                  const DominatorTree *DT) { +  assert(V->getType()->isPointerTy() && "V must be pointer type"); +  assert(!isa<ConstantData>(V) && "Did not expect ConstantPointerNull"); + +  if (!CtxI || !DT) +    return false; + +  unsigned NumUsesExplored = 0; +  for (auto *U : V->users()) { +    // Avoid massive lists +    if (NumUsesExplored >= DomConditionsMaxUses) +      break; +    NumUsesExplored++; + +    // If the value is used as an argument to a call or invoke, then argument +    // attributes may provide an answer about null-ness. 
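+    // For example, a dominating call that passes V to a parameter marked
+    // 'nonnull' lets us conclude that V is non-null at CtxI.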
+    if (auto CS = ImmutableCallSite(U)) +      if (auto *CalledFunc = CS.getCalledFunction()) +        for (const Argument &Arg : CalledFunc->args()) +          if (CS.getArgOperand(Arg.getArgNo()) == V && +              Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI)) +            return true; + +    // Consider only compare instructions uniquely controlling a branch +    CmpInst::Predicate Pred; +    if (!match(const_cast<User *>(U), +               m_c_ICmp(Pred, m_Specific(V), m_Zero())) || +        (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)) +      continue; + +    for (auto *CmpU : U->users()) { +      if (const BranchInst *BI = dyn_cast<BranchInst>(CmpU)) { +        assert(BI->isConditional() && "uses a comparison!"); + +        BasicBlock *NonNullSuccessor = +            BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0); +        BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); +        if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) +          return true; +      } else if (Pred == ICmpInst::ICMP_NE && +                 match(CmpU, m_Intrinsic<Intrinsic::experimental_guard>()) && +                 DT->dominates(cast<Instruction>(CmpU), CtxI)) { +        return true; +      } +    } +  } + +  return false; +} + +/// Does the 'Range' metadata (which must be a valid MD_range operand list) +/// ensure that the value it's attached to is never Value?  'RangeType' is +/// is the type of the value described by the range. +static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) { +  const unsigned NumRanges = Ranges->getNumOperands() / 2; +  assert(NumRanges >= 1); +  for (unsigned i = 0; i < NumRanges; ++i) { +    ConstantInt *Lower = +        mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0)); +    ConstantInt *Upper = +        mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1)); +    ConstantRange Range(Lower->getValue(), Upper->getValue()); +    if (Range.contains(Value)) +      return false; +  } +  return true; +} + +/// Return true if the given value is known to be non-zero when defined. For +/// vectors, return true if every element is known to be non-zero when +/// defined. For pointers, if the context instruction and dominator tree are +/// specified, perform context-sensitive analysis and return true if the +/// pointer couldn't possibly be null at the specified instruction. +/// Supports values with integer or pointer type and vectors of integers. +bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { +  if (auto *C = dyn_cast<Constant>(V)) { +    if (C->isNullValue()) +      return false; +    if (isa<ConstantInt>(C)) +      // Must be non-zero due to null test above. +      return true; + +    // For constant vectors, check that all elements are undefined or known +    // non-zero to determine that the whole vector is known non-zero. +    if (auto *VecTy = dyn_cast<VectorType>(C->getType())) { +      for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) { +        Constant *Elt = C->getAggregateElement(i); +        if (!Elt || Elt->isNullValue()) +          return false; +        if (!isa<UndefValue>(Elt) && !isa<ConstantInt>(Elt)) +          return false; +      } +      return true; +    } + +    // A global variable in address space 0 is non null unless extern weak +    // or an absolute symbol reference. Other address spaces may have null as a +    // valid address for a global, so we can't assume anything. 
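+    // For example, '@g = extern_weak global i32' may resolve to address zero
+    // if the symbol is missing at link time, so it is excluded below.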
+    if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) { +      if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() && +          GV->getType()->getAddressSpace() == 0) +        return true; +    } else +      return false; +  } + +  if (auto *I = dyn_cast<Instruction>(V)) { +    if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) { +      // If the possible ranges don't contain zero, then the value is +      // definitely non-zero. +      if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { +        const APInt ZeroValue(Ty->getBitWidth(), 0); +        if (rangeMetadataExcludesValue(Ranges, ZeroValue)) +          return true; +      } +    } +  } + +  // Some of the tests below are recursive, so bail out if we hit the limit. +  if (Depth++ >= MaxDepth) +    return false; + +  // Check for pointer simplifications. +  if (V->getType()->isPointerTy()) { +    // Alloca never returns null, malloc might. +    if (isa<AllocaInst>(V) && Q.DL.getAllocaAddrSpace() == 0) +      return true; + +    // A byval, inalloca, or nonnull argument is never null. +    if (const Argument *A = dyn_cast<Argument>(V)) +      if (A->hasByValOrInAllocaAttr() || A->hasNonNullAttr()) +        return true; + +    // A Load tagged with nonnull metadata is never null. +    if (const LoadInst *LI = dyn_cast<LoadInst>(V)) +      if (LI->getMetadata(LLVMContext::MD_nonnull)) +        return true; + +    if (auto CS = ImmutableCallSite(V)) { +      if (CS.isReturnNonNull()) +        return true; +      if (const auto *RP = getArgumentAliasingToReturnedPointer(CS)) +        return isKnownNonZero(RP, Depth, Q); +    } +  } + + +  // Check for recursive pointer simplifications. +  if (V->getType()->isPointerTy()) { +    if (isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT)) +      return true; + +    if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) +      if (isGEPKnownNonNull(GEP, Depth, Q)) +        return true; +  } + +  unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL); + +  // X | Y != 0 if X != 0 or Y != 0. +  Value *X = nullptr, *Y = nullptr; +  if (match(V, m_Or(m_Value(X), m_Value(Y)))) +    return isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q); + +  // ext X != 0 if X != 0. +  if (isa<SExtInst>(V) || isa<ZExtInst>(V)) +    return isKnownNonZero(cast<Instruction>(V)->getOperand(0), Depth, Q); + +  // shl X, Y != 0 if X is odd.  Note that the value of the shift is undefined +  // if the lowest bit is shifted off the end. +  if (match(V, m_Shl(m_Value(X), m_Value(Y)))) { +    // shl nuw can't remove any non-zero bits. +    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); +    if (BO->hasNoUnsignedWrap()) +      return isKnownNonZero(X, Depth, Q); + +    KnownBits Known(BitWidth); +    computeKnownBits(X, Known, Depth, Q); +    if (Known.One[0]) +      return true; +  } +  // shr X, Y != 0 if X is negative.  Note that the value of the shift is not +  // defined if the sign bit is shifted off the end. +  else if (match(V, m_Shr(m_Value(X), m_Value(Y)))) { +    // shr exact can only shift out zero bits. +    const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); +    if (BO->isExact()) +      return isKnownNonZero(X, Depth, Q); + +    KnownBits Known = computeKnownBits(X, Depth, Q); +    if (Known.isNegative()) +      return true; + +    // If the shifter operand is a constant, and all of the bits shifted +    // out are known to be zero, and X is known non-zero then at least one +    // non-zero bit must remain. 
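+    // For example, for (lshr i8 %x, 2) with %x known non-zero and its low
+    // two bits known zero, the set bit of %x cannot be shifted out, so the
+    // result is still non-zero.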
+    if (ConstantInt *Shift = dyn_cast<ConstantInt>(Y)) { +      auto ShiftVal = Shift->getLimitedValue(BitWidth - 1); +      // Is there a known one in the portion not shifted out? +      if (Known.countMaxLeadingZeros() < BitWidth - ShiftVal) +        return true; +      // Are all the bits to be shifted out known zero? +      if (Known.countMinTrailingZeros() >= ShiftVal) +        return isKnownNonZero(X, Depth, Q); +    } +  } +  // div exact can only produce a zero if the dividend is zero. +  else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) { +    return isKnownNonZero(X, Depth, Q); +  } +  // X + Y. +  else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { +    KnownBits XKnown = computeKnownBits(X, Depth, Q); +    KnownBits YKnown = computeKnownBits(Y, Depth, Q); + +    // If X and Y are both non-negative (as signed values) then their sum is not +    // zero unless both X and Y are zero. +    if (XKnown.isNonNegative() && YKnown.isNonNegative()) +      if (isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q)) +        return true; + +    // If X and Y are both negative (as signed values) then their sum is not +    // zero unless both X and Y equal INT_MIN. +    if (XKnown.isNegative() && YKnown.isNegative()) { +      APInt Mask = APInt::getSignedMaxValue(BitWidth); +      // The sign bit of X is set.  If some other bit is set then X is not equal +      // to INT_MIN. +      if (XKnown.One.intersects(Mask)) +        return true; +      // The sign bit of Y is set.  If some other bit is set then Y is not equal +      // to INT_MIN. +      if (YKnown.One.intersects(Mask)) +        return true; +    } + +    // The sum of a non-negative number and a power of two is not zero. +    if (XKnown.isNonNegative() && +        isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q)) +      return true; +    if (YKnown.isNonNegative() && +        isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q)) +      return true; +  } +  // X * Y. +  else if (match(V, m_Mul(m_Value(X), m_Value(Y)))) { +    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); +    // If X and Y are non-zero then so is X * Y as long as the multiplication +    // does not overflow. +    if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && +        isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q)) +      return true; +  } +  // (C ? X : Y) != 0 if X != 0 and Y != 0. +  else if (const SelectInst *SI = dyn_cast<SelectInst>(V)) { +    if (isKnownNonZero(SI->getTrueValue(), Depth, Q) && +        isKnownNonZero(SI->getFalseValue(), Depth, Q)) +      return true; +  } +  // PHI +  else if (const PHINode *PN = dyn_cast<PHINode>(V)) { +    // Try and detect a recurrence that monotonically increases from a +    // starting value, as these are common as induction variables. +    if (PN->getNumIncomingValues() == 2) { +      Value *Start = PN->getIncomingValue(0); +      Value *Induction = PN->getIncomingValue(1); +      if (isa<ConstantInt>(Induction) && !isa<ConstantInt>(Start)) +        std::swap(Start, Induction); +      if (ConstantInt *C = dyn_cast<ConstantInt>(Start)) { +        if (!C->isZero() && !C->isNegative()) { +          ConstantInt *X; +          if ((match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) || +               match(Induction, m_NUWAdd(m_Specific(PN), m_ConstantInt(X)))) && +              !X->isNegative()) +            return true; +        } +      } +    } +    // Check if all incoming values are non-zero constant. 
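+    // For example, %p = phi i32 [ 3, %a ], [ 7, %b ] is trivially non-zero.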
+    bool AllNonZeroConstants = llvm::all_of(PN->operands(), [](Value *V) { +      return isa<ConstantInt>(V) && !cast<ConstantInt>(V)->isZero(); +    }); +    if (AllNonZeroConstants) +      return true; +  } + +  KnownBits Known(BitWidth); +  computeKnownBits(V, Known, Depth, Q); +  return Known.One != 0; +} + +/// Return true if V2 == V1 + X, where X is known non-zero. +static bool isAddOfNonZero(const Value *V1, const Value *V2, const Query &Q) { +  const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1); +  if (!BO || BO->getOpcode() != Instruction::Add) +    return false; +  Value *Op = nullptr; +  if (V2 == BO->getOperand(0)) +    Op = BO->getOperand(1); +  else if (V2 == BO->getOperand(1)) +    Op = BO->getOperand(0); +  else +    return false; +  return isKnownNonZero(Op, 0, Q); +} + +/// Return true if it is known that V1 != V2. +static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q) { +  if (V1 == V2) +    return false; +  if (V1->getType() != V2->getType()) +    // We can't look through casts yet. +    return false; +  if (isAddOfNonZero(V1, V2, Q) || isAddOfNonZero(V2, V1, Q)) +    return true; + +  if (V1->getType()->isIntOrIntVectorTy()) { +    // Are any known bits in V1 contradictory to known bits in V2? If V1 +    // has a known zero where V2 has a known one, they must not be equal. +    KnownBits Known1 = computeKnownBits(V1, 0, Q); +    KnownBits Known2 = computeKnownBits(V2, 0, Q); + +    if (Known1.Zero.intersects(Known2.One) || +        Known2.Zero.intersects(Known1.One)) +      return true; +  } +  return false; +} + +/// Return true if 'V & Mask' is known to be zero.  We use this predicate to +/// simplify operations downstream. Mask is known to be zero for bits that V +/// cannot have. +/// +/// This function is defined on values with integer type, values with pointer +/// type, and vectors of integers.  In the case +/// where V is a vector, the mask, known zero, and known one values are the +/// same width as the vector element, and the bit is set only if it is true +/// for all of the elements in the vector. +bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, +                       const Query &Q) { +  KnownBits Known(Mask.getBitWidth()); +  computeKnownBits(V, Known, Depth, Q); +  return Mask.isSubsetOf(Known.Zero); +} + +/// For vector constants, loop over the elements and find the constant with the +/// minimum number of sign bits. Return 0 if the value is not a vector constant +/// or if any element was not analyzed; otherwise, return the count for the +/// element with the minimum number of sign bits. +static unsigned computeNumSignBitsVectorConstant(const Value *V, +                                                 unsigned TyBits) { +  const auto *CV = dyn_cast<Constant>(V); +  if (!CV || !CV->getType()->isVectorTy()) +    return 0; + +  unsigned MinSignBits = TyBits; +  unsigned NumElts = CV->getType()->getVectorNumElements(); +  for (unsigned i = 0; i != NumElts; ++i) { +    // If we find a non-ConstantInt, bail out. 
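+    // (Undef or constant-expression elements defeat the analysis; otherwise
+    // the answer is the per-element minimum, e.g. <2 x i8> <i8 1, i8 -1>
+    // yields min(7, 8) = 7 sign bits.)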
+    auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i)); +    if (!Elt) +      return 0; + +    MinSignBits = std::min(MinSignBits, Elt->getValue().getNumSignBits()); +  } + +  return MinSignBits; +} + +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, +                                       const Query &Q); + +static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, +                                   const Query &Q) { +  unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q); +  assert(Result > 0 && "At least one sign bit needs to be present!"); +  return Result; +} + +/// Return the number of times the sign bit of the register is replicated into +/// the other bits. We know that at least 1 bit is always equal to the sign bit +/// (itself), but other cases can give us information. For example, immediately +/// after an "ashr X, 2", we know that the top 3 bits are all equal to each +/// other, so we return 3. For vectors, return the number of sign bits for the +/// vector element with the minimum number of known sign bits. +static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, +                                       const Query &Q) { +  assert(Depth <= MaxDepth && "Limit Search Depth"); + +  // We return the minimum number of sign bits that are guaranteed to be present +  // in V, so for undef we have to conservatively return 1.  We don't have the +  // same behavior for poison though -- that's a FIXME today. + +  Type *ScalarTy = V->getType()->getScalarType(); +  unsigned TyBits = ScalarTy->isPointerTy() ? +    Q.DL.getIndexTypeSizeInBits(ScalarTy) : +    Q.DL.getTypeSizeInBits(ScalarTy); + +  unsigned Tmp, Tmp2; +  unsigned FirstAnswer = 1; + +  // Note that ConstantInt is handled by the general computeKnownBits case +  // below. + +  if (Depth == MaxDepth) +    return 1;  // Limit search depth. + +  const Operator *U = dyn_cast<Operator>(V); +  switch (Operator::getOpcode(V)) { +  default: break; +  case Instruction::SExt: +    Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); +    return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp; + +  case Instruction::SDiv: { +    const APInt *Denominator; +    // sdiv X, C -> adds log(C) sign bits. +    if (match(U->getOperand(1), m_APInt(Denominator))) { + +      // Ignore non-positive denominator. +      if (!Denominator->isStrictlyPositive()) +        break; + +      // Calculate the incoming numerator bits. +      unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + +      // Add floor(log(C)) bits to the numerator bits. +      return std::min(TyBits, NumBits + Denominator->logBase2()); +    } +    break; +  } + +  case Instruction::SRem: { +    const APInt *Denominator; +    // srem X, C -> we know that the result is within [-C+1,C) when C is a +    // positive constant.  This let us put a lower bound on the number of sign +    // bits. +    if (match(U->getOperand(1), m_APInt(Denominator))) { + +      // Ignore non-positive denominator. +      if (!Denominator->isStrictlyPositive()) +        break; + +      // Calculate the incoming numerator bits. SRem by a positive constant +      // can't lower the number of sign bits. +      unsigned NumrBits = +          ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + +      // Calculate the leading sign bit constraints by examining the +      // denominator.  Given that the denominator is positive, there are two +      // cases: +      // +      //  1. the numerator is positive.  
The result range is [0,C) and [0,C) u< +      //     (1 << ceilLogBase2(C)). +      // +      //  2. the numerator is negative.  Then the result range is (-C,0] and +      //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)). +      // +      // Thus a lower bound on the number of sign bits is `TyBits - +      // ceilLogBase2(C)`. + +      unsigned ResBits = TyBits - Denominator->ceilLogBase2(); +      return std::max(NumrBits, ResBits); +    } +    break; +  } + +  case Instruction::AShr: { +    Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +    // ashr X, C   -> adds C sign bits.  Vectors too. +    const APInt *ShAmt; +    if (match(U->getOperand(1), m_APInt(ShAmt))) { +      if (ShAmt->uge(TyBits)) +        break;  // Bad shift. +      unsigned ShAmtLimited = ShAmt->getZExtValue(); +      Tmp += ShAmtLimited; +      if (Tmp > TyBits) Tmp = TyBits; +    } +    return Tmp; +  } +  case Instruction::Shl: { +    const APInt *ShAmt; +    if (match(U->getOperand(1), m_APInt(ShAmt))) { +      // shl destroys sign bits. +      Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +      if (ShAmt->uge(TyBits) ||      // Bad shift. +          ShAmt->uge(Tmp)) break;    // Shifted all sign bits out. +      Tmp2 = ShAmt->getZExtValue(); +      return Tmp - Tmp2; +    } +    break; +  } +  case Instruction::And: +  case Instruction::Or: +  case Instruction::Xor:    // NOT is handled here. +    // Logical binary ops preserve the number of sign bits at the worst. +    Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +    if (Tmp != 1) { +      Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); +      FirstAnswer = std::min(Tmp, Tmp2); +      // We computed what we know about the sign bits as our first +      // answer. Now proceed to the generic code that uses +      // computeKnownBits, and pick whichever answer is better. +    } +    break; + +  case Instruction::Select: +    Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); +    if (Tmp == 1) break; +    Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); +    return std::min(Tmp, Tmp2); + +  case Instruction::Add: +    // Add can have at most one carry bit.  Thus we know that the output +    // is, at worst, one more bit than the inputs. +    Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +    if (Tmp == 1) break; + +    // Special case decrementing a value (ADD X, -1): +    if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1))) +      if (CRHS->isAllOnesValue()) { +        KnownBits Known(TyBits); +        computeKnownBits(U->getOperand(0), Known, Depth + 1, Q); + +        // If the input is known to be 0 or 1, the output is 0/-1, which is all +        // sign bits set. +        if ((Known.Zero | 1).isAllOnesValue()) +          return TyBits; + +        // If we are subtracting one from a positive number, there is no carry +        // out of the result. +        if (Known.isNonNegative()) +          return Tmp; +      } + +    Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); +    if (Tmp2 == 1) break; +    return std::min(Tmp, Tmp2)-1; + +  case Instruction::Sub: +    Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); +    if (Tmp2 == 1) break; + +    // Handle NEG. 
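+    // (sub 0, X) is a negation: if X is known to be 0 or 1 the result is
+    // 0 or -1, i.e. all sign bits, and a known non-negative X negates
+    // without losing any sign bits.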
+    if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0))) +      if (CLHS->isNullValue()) { +        KnownBits Known(TyBits); +        computeKnownBits(U->getOperand(1), Known, Depth + 1, Q); +        // If the input is known to be 0 or 1, the output is 0/-1, which is all +        // sign bits set. +        if ((Known.Zero | 1).isAllOnesValue()) +          return TyBits; + +        // If the input is known to be positive (the sign bit is known clear), +        // the output of the NEG has the same number of sign bits as the input. +        if (Known.isNonNegative()) +          return Tmp2; + +        // Otherwise, we treat this like a SUB. +      } + +    // Sub can have at most one carry bit.  Thus we know that the output +    // is, at worst, one more bit than the inputs. +    Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +    if (Tmp == 1) break; +    return std::min(Tmp, Tmp2)-1; + +  case Instruction::Mul: { +    // The output of the Mul can be at most twice the valid bits in the inputs. +    unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +    if (SignBitsOp0 == 1) break; +    unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); +    if (SignBitsOp1 == 1) break; +    unsigned OutValidBits = +        (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1); +    return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1; +  } + +  case Instruction::PHI: { +    const PHINode *PN = cast<PHINode>(U); +    unsigned NumIncomingValues = PN->getNumIncomingValues(); +    // Don't analyze large in-degree PHIs. +    if (NumIncomingValues > 4) break; +    // Unreachable blocks may have zero-operand PHI nodes. +    if (NumIncomingValues == 0) break; + +    // Take the minimum of all incoming values.  This can't infinitely loop +    // because of our depth threshold. +    Tmp = ComputeNumSignBits(PN->getIncomingValue(0), Depth + 1, Q); +    for (unsigned i = 1, e = NumIncomingValues; i != e; ++i) { +      if (Tmp == 1) return Tmp; +      Tmp = std::min( +          Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, Q)); +    } +    return Tmp; +  } + +  case Instruction::Trunc: +    // FIXME: it's tricky to do anything useful for this, but it is an important +    // case for targets like X86. +    break; + +  case Instruction::ExtractElement: +    // Look through extract element. At the moment we keep this simple and skip +    // tracking the specific element. But at least we might find information +    // valid for all elements of the vector (for example if vector is sign +    // extended, shifted, etc). +    return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); +  } + +  // Finally, if we can prove that the top bits of the result are 0's or 1's, +  // use this information. + +  // If we can examine all elements of a vector constant successfully, we're +  // done (we can't do any better than that). If not, keep trying. +  if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits)) +    return VecSignBits; + +  KnownBits Known(TyBits); +  computeKnownBits(V, Known, Depth, Q); + +  // If we know that the sign bit is either zero or one, determine the number of +  // identical bits in the top of the input value. +  return std::max(FirstAnswer, Known.countMinSignBits()); +} + +/// This function computes the integer multiple of Base that equals V. +/// If successful, it returns true and returns the multiple in +/// Multiple. If unsuccessful, it returns false. 
It looks +/// through SExt instructions only if LookThroughSExt is true. +bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, +                           bool LookThroughSExt, unsigned Depth) { +  const unsigned MaxDepth = 6; + +  assert(V && "No Value?"); +  assert(Depth <= MaxDepth && "Limit Search Depth"); +  assert(V->getType()->isIntegerTy() && "Not integer or pointer type!"); + +  Type *T = V->getType(); + +  ConstantInt *CI = dyn_cast<ConstantInt>(V); + +  if (Base == 0) +    return false; + +  if (Base == 1) { +    Multiple = V; +    return true; +  } + +  ConstantExpr *CO = dyn_cast<ConstantExpr>(V); +  Constant *BaseVal = ConstantInt::get(T, Base); +  if (CO && CO == BaseVal) { +    // Multiple is 1. +    Multiple = ConstantInt::get(T, 1); +    return true; +  } + +  if (CI && CI->getZExtValue() % Base == 0) { +    Multiple = ConstantInt::get(T, CI->getZExtValue() / Base); +    return true; +  } + +  if (Depth == MaxDepth) return false;  // Limit search depth. + +  Operator *I = dyn_cast<Operator>(V); +  if (!I) return false; + +  switch (I->getOpcode()) { +  default: break; +  case Instruction::SExt: +    if (!LookThroughSExt) return false; +    // otherwise fall through to ZExt +    LLVM_FALLTHROUGH; +  case Instruction::ZExt: +    return ComputeMultiple(I->getOperand(0), Base, Multiple, +                           LookThroughSExt, Depth+1); +  case Instruction::Shl: +  case Instruction::Mul: { +    Value *Op0 = I->getOperand(0); +    Value *Op1 = I->getOperand(1); + +    if (I->getOpcode() == Instruction::Shl) { +      ConstantInt *Op1CI = dyn_cast<ConstantInt>(Op1); +      if (!Op1CI) return false; +      // Turn Op0 << Op1 into Op0 * 2^Op1 +      APInt Op1Int = Op1CI->getValue(); +      uint64_t BitToSet = Op1Int.getLimitedValue(Op1Int.getBitWidth() - 1); +      APInt API(Op1Int.getBitWidth(), 0); +      API.setBit(BitToSet); +      Op1 = ConstantInt::get(V->getContext(), API); +    } + +    Value *Mul0 = nullptr; +    if (ComputeMultiple(Op0, Base, Mul0, LookThroughSExt, Depth+1)) { +      if (Constant *Op1C = dyn_cast<Constant>(Op1)) +        if (Constant *MulC = dyn_cast<Constant>(Mul0)) { +          if (Op1C->getType()->getPrimitiveSizeInBits() < +              MulC->getType()->getPrimitiveSizeInBits()) +            Op1C = ConstantExpr::getZExt(Op1C, MulC->getType()); +          if (Op1C->getType()->getPrimitiveSizeInBits() > +              MulC->getType()->getPrimitiveSizeInBits()) +            MulC = ConstantExpr::getZExt(MulC, Op1C->getType()); + +          // V == Base * (Mul0 * Op1), so return (Mul0 * Op1) +          Multiple = ConstantExpr::getMul(MulC, Op1C); +          return true; +        } + +      if (ConstantInt *Mul0CI = dyn_cast<ConstantInt>(Mul0)) +        if (Mul0CI->getValue() == 1) { +          // V == Base * Op1, so return Op1 +          Multiple = Op1; +          return true; +        } +    } + +    Value *Mul1 = nullptr; +    if (ComputeMultiple(Op1, Base, Mul1, LookThroughSExt, Depth+1)) { +      if (Constant *Op0C = dyn_cast<Constant>(Op0)) +        if (Constant *MulC = dyn_cast<Constant>(Mul1)) { +          if (Op0C->getType()->getPrimitiveSizeInBits() < +              MulC->getType()->getPrimitiveSizeInBits()) +            Op0C = ConstantExpr::getZExt(Op0C, MulC->getType()); +          if (Op0C->getType()->getPrimitiveSizeInBits() > +              MulC->getType()->getPrimitiveSizeInBits()) +            MulC = ConstantExpr::getZExt(MulC, Op0C->getType()); + +          // V == Base * (Mul1 * Op0), so return (Mul1 * Op0) +        
  Multiple = ConstantExpr::getMul(MulC, Op0C); +          return true; +        } + +      if (ConstantInt *Mul1CI = dyn_cast<ConstantInt>(Mul1)) +        if (Mul1CI->getValue() == 1) { +          // V == Base * Op0, so return Op0 +          Multiple = Op0; +          return true; +        } +    } +  } +  } + +  // We could not determine if V is a multiple of Base. +  return false; +} + +Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, +                                            const TargetLibraryInfo *TLI) { +  const Function *F = ICS.getCalledFunction(); +  if (!F) +    return Intrinsic::not_intrinsic; + +  if (F->isIntrinsic()) +    return F->getIntrinsicID(); + +  if (!TLI) +    return Intrinsic::not_intrinsic; + +  LibFunc Func; +  // We're going to make assumptions on the semantics of the functions, check +  // that the target knows that it's available in this environment and it does +  // not have local linkage. +  if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(*F, Func)) +    return Intrinsic::not_intrinsic; + +  if (!ICS.onlyReadsMemory()) +    return Intrinsic::not_intrinsic; + +  // Otherwise check if we have a call to a function that can be turned into a +  // vector intrinsic. +  switch (Func) { +  default: +    break; +  case LibFunc_sin: +  case LibFunc_sinf: +  case LibFunc_sinl: +    return Intrinsic::sin; +  case LibFunc_cos: +  case LibFunc_cosf: +  case LibFunc_cosl: +    return Intrinsic::cos; +  case LibFunc_exp: +  case LibFunc_expf: +  case LibFunc_expl: +    return Intrinsic::exp; +  case LibFunc_exp2: +  case LibFunc_exp2f: +  case LibFunc_exp2l: +    return Intrinsic::exp2; +  case LibFunc_log: +  case LibFunc_logf: +  case LibFunc_logl: +    return Intrinsic::log; +  case LibFunc_log10: +  case LibFunc_log10f: +  case LibFunc_log10l: +    return Intrinsic::log10; +  case LibFunc_log2: +  case LibFunc_log2f: +  case LibFunc_log2l: +    return Intrinsic::log2; +  case LibFunc_fabs: +  case LibFunc_fabsf: +  case LibFunc_fabsl: +    return Intrinsic::fabs; +  case LibFunc_fmin: +  case LibFunc_fminf: +  case LibFunc_fminl: +    return Intrinsic::minnum; +  case LibFunc_fmax: +  case LibFunc_fmaxf: +  case LibFunc_fmaxl: +    return Intrinsic::maxnum; +  case LibFunc_copysign: +  case LibFunc_copysignf: +  case LibFunc_copysignl: +    return Intrinsic::copysign; +  case LibFunc_floor: +  case LibFunc_floorf: +  case LibFunc_floorl: +    return Intrinsic::floor; +  case LibFunc_ceil: +  case LibFunc_ceilf: +  case LibFunc_ceill: +    return Intrinsic::ceil; +  case LibFunc_trunc: +  case LibFunc_truncf: +  case LibFunc_truncl: +    return Intrinsic::trunc; +  case LibFunc_rint: +  case LibFunc_rintf: +  case LibFunc_rintl: +    return Intrinsic::rint; +  case LibFunc_nearbyint: +  case LibFunc_nearbyintf: +  case LibFunc_nearbyintl: +    return Intrinsic::nearbyint; +  case LibFunc_round: +  case LibFunc_roundf: +  case LibFunc_roundl: +    return Intrinsic::round; +  case LibFunc_pow: +  case LibFunc_powf: +  case LibFunc_powl: +    return Intrinsic::pow; +  case LibFunc_sqrt: +  case LibFunc_sqrtf: +  case LibFunc_sqrtl: +    return Intrinsic::sqrt; +  } + +  return Intrinsic::not_intrinsic; +} + +/// Return true if we can prove that the specified FP value is never equal to +/// -0.0. +/// +/// NOTE: this function will need to be revisited when we support non-default +/// rounding modes! 
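+/// For example, (fadd %x, +0.0) can never produce -0.0 under the default
+/// rounding mode, whereas (fsub -0.0, %x) produces -0.0 whenever %x is +0.0.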
+bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, +                                unsigned Depth) { +  if (auto *CFP = dyn_cast<ConstantFP>(V)) +    return !CFP->getValueAPF().isNegZero(); + +  // Limit search depth. +  if (Depth == MaxDepth) +    return false; + +  auto *Op = dyn_cast<Operator>(V); +  if (!Op) +    return false; + +  // Check if the nsz fast-math flag is set. +  if (auto *FPO = dyn_cast<FPMathOperator>(Op)) +    if (FPO->hasNoSignedZeros()) +      return true; + +  // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0. +  if (match(Op, m_FAdd(m_Value(), m_PosZeroFP()))) +    return true; + +  // sitofp and uitofp turn into +0.0 for zero. +  if (isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) +    return true; + +  if (auto *Call = dyn_cast<CallInst>(Op)) { +    Intrinsic::ID IID = getIntrinsicForCallSite(Call, TLI); +    switch (IID) { +    default: +      break; +    // sqrt(-0.0) = -0.0, no other negative results are possible. +    case Intrinsic::sqrt: +      return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); +    // fabs(x) != -0.0 +    case Intrinsic::fabs: +      return true; +    } +  } + +  return false; +} + +/// If \p SignBitOnly is true, test for a known 0 sign bit rather than a +/// standard ordered compare. e.g. make -0.0 olt 0.0 be true because of the sign +/// bit despite comparing equal. +static bool cannotBeOrderedLessThanZeroImpl(const Value *V, +                                            const TargetLibraryInfo *TLI, +                                            bool SignBitOnly, +                                            unsigned Depth) { +  // TODO: This function does not do the right thing when SignBitOnly is true +  // and we're lowering to a hypothetical IEEE 754-compliant-but-evil platform +  // which flips the sign bits of NaNs.  See +  // https://llvm.org/bugs/show_bug.cgi?id=31702. + +  if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { +    return !CFP->getValueAPF().isNegative() || +           (!SignBitOnly && CFP->getValueAPF().isZero()); +  } + +  // Handle vector of constants. +  if (auto *CV = dyn_cast<Constant>(V)) { +    if (CV->getType()->isVectorTy()) { +      unsigned NumElts = CV->getType()->getVectorNumElements(); +      for (unsigned i = 0; i != NumElts; ++i) { +        auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); +        if (!CFP) +          return false; +        if (CFP->getValueAPF().isNegative() && +            (SignBitOnly || !CFP->getValueAPF().isZero())) +          return false; +      } + +      // All non-negative ConstantFPs. +      return true; +    } +  } + +  if (Depth == MaxDepth) +    return false; // Limit search depth. + +  const Operator *I = dyn_cast<Operator>(V); +  if (!I) +    return false; + +  switch (I->getOpcode()) { +  default: +    break; +  // Unsigned integers are always nonnegative. +  case Instruction::UIToFP: +    return true; +  case Instruction::FMul: +    // x*x is always non-negative or a NaN. 
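+    // (When SignBitOnly is set we additionally require nnan: x*x can be a
+    // NaN whose sign bit happens to be set, even though it never compares
+    // ordered-less-than zero.)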
+    if (I->getOperand(0) == I->getOperand(1) && +        (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs())) +      return true; + +    LLVM_FALLTHROUGH; +  case Instruction::FAdd: +  case Instruction::FDiv: +  case Instruction::FRem: +    return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, +                                           Depth + 1) && +           cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, +                                           Depth + 1); +  case Instruction::Select: +    return cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, +                                           Depth + 1) && +           cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly, +                                           Depth + 1); +  case Instruction::FPExt: +  case Instruction::FPTrunc: +    // Widening/narrowing never change sign. +    return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, +                                           Depth + 1); +  case Instruction::ExtractElement: +    // Look through extract element. At the moment we keep this simple and skip +    // tracking the specific element. But at least we might find information +    // valid for all elements of the vector. +    return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, +                                           Depth + 1); +  case Instruction::Call: +    const auto *CI = cast<CallInst>(I); +    Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); +    switch (IID) { +    default: +      break; +    case Intrinsic::maxnum: +      return (isKnownNeverNaN(I->getOperand(0)) && +              cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, +                                              SignBitOnly, Depth + 1)) || +             (isKnownNeverNaN(I->getOperand(1)) && +              cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, +                                              SignBitOnly, Depth + 1)); + +    case Intrinsic::minnum: +      return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, +                                             Depth + 1) && +             cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, +                                             Depth + 1); +    case Intrinsic::exp: +    case Intrinsic::exp2: +    case Intrinsic::fabs: +      return true; + +    case Intrinsic::sqrt: +      // sqrt(x) is always >= -0 or NaN.  Moreover, sqrt(x) == -0 iff x == -0. +      if (!SignBitOnly) +        return true; +      return CI->hasNoNaNs() && (CI->hasNoSignedZeros() || +                                 CannotBeNegativeZero(CI->getOperand(0), TLI)); + +    case Intrinsic::powi: +      if (ConstantInt *Exponent = dyn_cast<ConstantInt>(I->getOperand(1))) { +        // powi(x,n) is non-negative if n is even. +        if (Exponent->getBitWidth() <= 64 && Exponent->getSExtValue() % 2u == 0) +          return true; +      } +      // TODO: This is not correct.  Given that exp is an integer, here are the +      // ways that pow can return a negative value: +      // +      //   pow(x, exp)    --> negative if exp is odd and x is negative. +      //   pow(-0, exp)   --> -inf if exp is negative odd. +      //   pow(-0, exp)   --> -0 if exp is positive odd. +      //   pow(-inf, exp) --> -0 if exp is negative odd. +      //   pow(-inf, exp) --> -inf if exp is positive odd. 
+      // +      // Therefore, if !SignBitOnly, we can return true if x >= +0 or x is NaN, +      // but we must return false if x == -0.  Unfortunately we do not currently +      // have a way of expressing this constraint.  See details in +      // https://llvm.org/bugs/show_bug.cgi?id=31702. +      return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, +                                             Depth + 1); + +    case Intrinsic::fma: +    case Intrinsic::fmuladd: +      // x*x+y is non-negative if y is non-negative. +      return I->getOperand(0) == I->getOperand(1) && +             (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs()) && +             cannotBeOrderedLessThanZeroImpl(I->getOperand(2), TLI, SignBitOnly, +                                             Depth + 1); +    } +    break; +  } +  return false; +} + +bool llvm::CannotBeOrderedLessThanZero(const Value *V, +                                       const TargetLibraryInfo *TLI) { +  return cannotBeOrderedLessThanZeroImpl(V, TLI, false, 0); +} + +bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) { +  return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0); +} + +bool llvm::isKnownNeverNaN(const Value *V) { +  assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type"); + +  // If we're told that NaNs won't happen, assume they won't. +  if (auto *FPMathOp = dyn_cast<FPMathOperator>(V)) +    if (FPMathOp->hasNoNaNs()) +      return true; + +  // TODO: Handle instructions and potentially recurse like other 'isKnown' +  // functions. For example, the result of sitofp is never NaN. + +  // Handle scalar constants. +  if (auto *CFP = dyn_cast<ConstantFP>(V)) +    return !CFP->isNaN(); + +  // Bail out for constant expressions, but try to handle vector constants. +  if (!V->getType()->isVectorTy() || !isa<Constant>(V)) +    return false; + +  // For vectors, verify that each element is not NaN. +  unsigned NumElts = V->getType()->getVectorNumElements(); +  for (unsigned i = 0; i != NumElts; ++i) { +    Constant *Elt = cast<Constant>(V)->getAggregateElement(i); +    if (!Elt) +      return false; +    if (isa<UndefValue>(Elt)) +      continue; +    auto *CElt = dyn_cast<ConstantFP>(Elt); +    if (!CElt || CElt->isNaN()) +      return false; +  } +  // All elements were confirmed not-NaN or undefined. +  return true; +} + +/// If the specified value can be set by repeating the same byte in memory, +/// return the i8 value that it is represented with.  This is +/// true for all i8 values obviously, but is also true for i32 0, i32 -1, +/// i16 0xF0F0, double 0.0 etc.  If the value can't be handled with a repeated +/// byte store (e.g. i16 0x1234), return null. +Value *llvm::isBytewiseValue(Value *V) { +  // All byte-wide stores are splatable, even of arbitrary variables. +  if (V->getType()->isIntegerTy(8)) return V; + +  // Handle 'null' ConstantArrayZero etc. +  if (Constant *C = dyn_cast<Constant>(V)) +    if (C->isNullValue()) +      return Constant::getNullValue(Type::getInt8Ty(V->getContext())); + +  // Constant float and double values can be handled as integer values if the +  // corresponding integer value is "byteable".  An important case is 0.0. 
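// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// isBytewiseValue bit-casts float/double constants to integers and then checks
// APInt::isSplat(8).  The plain-C++ analogue below shows why 0.0 is "byteable"
// while 0.5 (0x3FE0000000000000) is not; isByteSplat64 is a hypothetical
// helper used only for illustration.
#include <cstdint>
#include <cstring>

static bool isByteSplat64(uint64_t Bits) {
  uint8_t First = Bits & 0xff;
  for (int I = 1; I < 8; ++I)
    if (((Bits >> (I * 8)) & 0xff) != First)
      return false;
  return true;
}

int main() {
  double Zero = 0.0, Half = 0.5;
  uint64_t ZeroBits, HalfBits;
  std::memcpy(&ZeroBits, &Zero, sizeof(ZeroBits));
  std::memcpy(&HalfBits, &Half, sizeof(HalfBits));
  return (isByteSplat64(ZeroBits) && !isByteSplat64(HalfBits)) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//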
+  if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { +    if (CFP->getType()->isFloatTy()) +      V = ConstantExpr::getBitCast(CFP, Type::getInt32Ty(V->getContext())); +    if (CFP->getType()->isDoubleTy()) +      V = ConstantExpr::getBitCast(CFP, Type::getInt64Ty(V->getContext())); +    // Don't handle long double formats, which have strange constraints. +  } + +  // We can handle constant integers that are multiple of 8 bits. +  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { +    if (CI->getBitWidth() % 8 == 0) { +      assert(CI->getBitWidth() > 8 && "8 bits should be handled above!"); + +      if (!CI->getValue().isSplat(8)) +        return nullptr; +      return ConstantInt::get(V->getContext(), CI->getValue().trunc(8)); +    } +  } + +  // A ConstantDataArray/Vector is splatable if all its members are equal and +  // also splatable. +  if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(V)) { +    Value *Elt = CA->getElementAsConstant(0); +    Value *Val = isBytewiseValue(Elt); +    if (!Val) +      return nullptr; + +    for (unsigned I = 1, E = CA->getNumElements(); I != E; ++I) +      if (CA->getElementAsConstant(I) != Elt) +        return nullptr; + +    return Val; +  } + +  // Conceptually, we could handle things like: +  //   %a = zext i8 %X to i16 +  //   %b = shl i16 %a, 8 +  //   %c = or i16 %a, %b +  // but until there is an example that actually needs this, it doesn't seem +  // worth worrying about. +  return nullptr; +} + +// This is the recursive version of BuildSubAggregate. It takes a few different +// arguments. Idxs is the index within the nested struct From that we are +// looking at now (which is of type IndexedType). IdxSkip is the number of +// indices from Idxs that should be left out when inserting into the resulting +// struct. To is the result struct built so far, new insertvalue instructions +// build on that. +static Value *BuildSubAggregate(Value *From, Value* To, Type *IndexedType, +                                SmallVectorImpl<unsigned> &Idxs, +                                unsigned IdxSkip, +                                Instruction *InsertBefore) { +  StructType *STy = dyn_cast<StructType>(IndexedType); +  if (STy) { +    // Save the original To argument so we can modify it +    Value *OrigTo = To; +    // General case, the type indexed by Idxs is a struct +    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { +      // Process each struct element recursively +      Idxs.push_back(i); +      Value *PrevTo = To; +      To = BuildSubAggregate(From, To, STy->getElementType(i), Idxs, IdxSkip, +                             InsertBefore); +      Idxs.pop_back(); +      if (!To) { +        // Couldn't find any inserted value for this index? Cleanup +        while (PrevTo != OrigTo) { +          InsertValueInst* Del = cast<InsertValueInst>(PrevTo); +          PrevTo = Del->getAggregateOperand(); +          Del->eraseFromParent(); +        } +        // Stop processing elements +        break; +      } +    } +    // If we successfully found a value for each of our subaggregates +    if (To) +      return To; +  } +  // Base case, the type indexed by SourceIdxs is not a struct, or not all of +  // the struct's elements had a value that was inserted directly. In the latter +  // case, perhaps we can't determine each of the subelements individually, but +  // we might be able to find the complete struct somewhere. 
+ +  // Find the value that is at that particular spot +  Value *V = FindInsertedValue(From, Idxs); + +  if (!V) +    return nullptr; + +  // Insert the value in the new (sub) aggregate +  return InsertValueInst::Create(To, V, makeArrayRef(Idxs).slice(IdxSkip), +                                 "tmp", InsertBefore); +} + +// This helper takes a nested struct and extracts a part of it (which is again a +// struct) into a new value. For example, given the struct: +// { a, { b, { c, d }, e } } +// and the indices "1, 1" this returns +// { c, d }. +// +// It does this by inserting an insertvalue for each element in the resulting +// struct, as opposed to just inserting a single struct. This will only work if +// each of the elements of the substruct are known (ie, inserted into From by an +// insertvalue instruction somewhere). +// +// All inserted insertvalue instructions are inserted before InsertBefore +static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range, +                                Instruction *InsertBefore) { +  assert(InsertBefore && "Must have someplace to insert!"); +  Type *IndexedType = ExtractValueInst::getIndexedType(From->getType(), +                                                             idx_range); +  Value *To = UndefValue::get(IndexedType); +  SmallVector<unsigned, 10> Idxs(idx_range.begin(), idx_range.end()); +  unsigned IdxSkip = Idxs.size(); + +  return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore); +} + +/// Given an aggregate and a sequence of indices, see if the scalar value +/// indexed is already around as a register, for example if it was inserted +/// directly into the aggregate. +/// +/// If InsertBefore is not null, this function will duplicate (modified) +/// insertvalues when a part of a nested struct is extracted. +Value *llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range, +                               Instruction *InsertBefore) { +  // Nothing to index? Just return V then (this is useful at the end of our +  // recursion). +  if (idx_range.empty()) +    return V; +  // We have indices, so V should have an indexable type. +  assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) && +         "Not looking at a struct or array?"); +  assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) && +         "Invalid indices for type?"); + +  if (Constant *C = dyn_cast<Constant>(V)) { +    C = C->getAggregateElement(idx_range[0]); +    if (!C) return nullptr; +    return FindInsertedValue(C, idx_range.slice(1), InsertBefore); +  } + +  if (InsertValueInst *I = dyn_cast<InsertValueInst>(V)) { +    // Loop the indices for the insertvalue instruction in parallel with the +    // requested indices +    const unsigned *req_idx = idx_range.begin(); +    for (const unsigned *i = I->idx_begin(), *e = I->idx_end(); +         i != e; ++i, ++req_idx) { +      if (req_idx == idx_range.end()) { +        // We can't handle this without inserting insertvalues +        if (!InsertBefore) +          return nullptr; + +        // The requested index identifies a part of a nested aggregate. Handle +        // this specially. 
For example, +        // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0 +        // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1 +        // %C = extractvalue {i32, { i32, i32 } } %B, 1 +        // This can be changed into +        // %A = insertvalue {i32, i32 } undef, i32 10, 0 +        // %C = insertvalue {i32, i32 } %A, i32 11, 1 +        // which allows the unused 0,0 element from the nested struct to be +        // removed. +        return BuildSubAggregate(V, makeArrayRef(idx_range.begin(), req_idx), +                                 InsertBefore); +      } + +      // This insert value inserts something else than what we are looking for. +      // See if the (aggregate) value inserted into has the value we are +      // looking for, then. +      if (*req_idx != *i) +        return FindInsertedValue(I->getAggregateOperand(), idx_range, +                                 InsertBefore); +    } +    // If we end up here, the indices of the insertvalue match with those +    // requested (though possibly only partially). Now we recursively look at +    // the inserted value, passing any remaining indices. +    return FindInsertedValue(I->getInsertedValueOperand(), +                             makeArrayRef(req_idx, idx_range.end()), +                             InsertBefore); +  } + +  if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(V)) { +    // If we're extracting a value from an aggregate that was extracted from +    // something else, we can extract from that something else directly instead. +    // However, we will need to chain I's indices with the requested indices. + +    // Calculate the number of indices required +    unsigned size = I->getNumIndices() + idx_range.size(); +    // Allocate some space to put the new indices in +    SmallVector<unsigned, 5> Idxs; +    Idxs.reserve(size); +    // Add indices from the extract value instruction +    Idxs.append(I->idx_begin(), I->idx_end()); + +    // Add requested indices +    Idxs.append(idx_range.begin(), idx_range.end()); + +    assert(Idxs.size() == size +           && "Number of indices added not correct?"); + +    return FindInsertedValue(I->getAggregateOperand(), Idxs, InsertBefore); +  } +  // Otherwise, we don't know (such as, extracting from a function return value +  // or load instruction) +  return nullptr; +} + +/// Analyze the specified pointer to see if it can be expressed as a base +/// pointer plus a constant offset. Return the base and offset to the caller. +Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, +                                              const DataLayout &DL) { +  unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType()); +  APInt ByteOffset(BitWidth, 0); + +  // We walk up the defs but use a visited set to handle unreachable code. In +  // that case, we stop after accumulating the cycle once (not that it +  // matters). +  SmallPtrSet<Value *, 16> Visited; +  while (Visited.insert(Ptr).second) { +    if (Ptr->getType()->isVectorTy()) +      break; + +    if (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) { +      // If one of the values we have visited is an addrspacecast, then +      // the pointer type of this GEP may be different from the type +      // of the Ptr parameter which was passed to this function.  This +      // means when we construct GEPOffset, we need to use the size +      // of GEP's pointer type rather than the size of the original +      // pointer type. 
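// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// GetPointerBaseWithConstantOffset decomposes a pointer into "base + constant
// byte offset" by accumulating constant GEP offsets and stripping bitcasts and
// addrspacecasts.  The C++ analogue below recovers the same decomposition for
// a field address; `Pair` is a hypothetical type used only for illustration.
#include <cstddef>
#include <cstdint>

struct Pair { int32_t A; int32_t B; };

int main() {
  Pair P = {1, 2};
  // &P.B is the base &P plus a constant byte offset (4 on typical targets).
  std::size_t Off = static_cast<std::size_t>(
      reinterpret_cast<char *>(&P.B) - reinterpret_cast<char *>(&P));
  return Off == offsetof(Pair, B) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//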
+      APInt GEPOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0); +      if (!GEP->accumulateConstantOffset(DL, GEPOffset)) +        break; + +      ByteOffset += GEPOffset.getSExtValue(); + +      Ptr = GEP->getPointerOperand(); +    } else if (Operator::getOpcode(Ptr) == Instruction::BitCast || +               Operator::getOpcode(Ptr) == Instruction::AddrSpaceCast) { +      Ptr = cast<Operator>(Ptr)->getOperand(0); +    } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) { +      if (GA->isInterposable()) +        break; +      Ptr = GA->getAliasee(); +    } else { +      break; +    } +  } +  Offset = ByteOffset.getSExtValue(); +  return Ptr; +} + +bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP, +                                       unsigned CharSize) { +  // Make sure the GEP has exactly three arguments. +  if (GEP->getNumOperands() != 3) +    return false; + +  // Make sure the index-ee is a pointer to array of \p CharSize integers. +  // CharSize. +  ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType()); +  if (!AT || !AT->getElementType()->isIntegerTy(CharSize)) +    return false; + +  // Check to make sure that the first operand of the GEP is an integer and +  // has value 0 so that we are sure we're indexing into the initializer. +  const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1)); +  if (!FirstIdx || !FirstIdx->isZero()) +    return false; + +  return true; +} + +bool llvm::getConstantDataArrayInfo(const Value *V, +                                    ConstantDataArraySlice &Slice, +                                    unsigned ElementSize, uint64_t Offset) { +  assert(V); + +  // Look through bitcast instructions and geps. +  V = V->stripPointerCasts(); + +  // If the value is a GEP instruction or constant expression, treat it as an +  // offset. +  if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { +    // The GEP operator should be based on a pointer to string constant, and is +    // indexing into the string constant. +    if (!isGEPBasedOnPointerToString(GEP, ElementSize)) +      return false; + +    // If the second index isn't a ConstantInt, then this is a variable index +    // into the array.  If this occurs, we can't say anything meaningful about +    // the string. +    uint64_t StartIdx = 0; +    if (const ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(2))) +      StartIdx = CI->getZExtValue(); +    else +      return false; +    return getConstantDataArrayInfo(GEP->getOperand(0), Slice, ElementSize, +                                    StartIdx + Offset); +  } + +  // The GEP instruction, constant or instruction, must reference a global +  // variable that is a constant and is initialized. The referenced constant +  // initializer is the array that we'll use for optimization. +  const GlobalVariable *GV = dyn_cast<GlobalVariable>(V); +  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) +    return false; + +  const ConstantDataArray *Array; +  ArrayType *ArrayTy; +  if (GV->getInitializer()->isNullValue()) { +    Type *GVTy = GV->getValueType(); +    if ( (ArrayTy = dyn_cast<ArrayType>(GVTy)) ) { +      // A zeroinitializer for the array; there is no ConstantDataArray. 
+      Array = nullptr; +    } else { +      const DataLayout &DL = GV->getParent()->getDataLayout(); +      uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy); +      uint64_t Length = SizeInBytes / (ElementSize / 8); +      if (Length <= Offset) +        return false; + +      Slice.Array = nullptr; +      Slice.Offset = 0; +      Slice.Length = Length - Offset; +      return true; +    } +  } else { +    // This must be a ConstantDataArray. +    Array = dyn_cast<ConstantDataArray>(GV->getInitializer()); +    if (!Array) +      return false; +    ArrayTy = Array->getType(); +  } +  if (!ArrayTy->getElementType()->isIntegerTy(ElementSize)) +    return false; + +  uint64_t NumElts = ArrayTy->getArrayNumElements(); +  if (Offset > NumElts) +    return false; + +  Slice.Array = Array; +  Slice.Offset = Offset; +  Slice.Length = NumElts - Offset; +  return true; +} + +/// This function computes the length of a null-terminated C string pointed to +/// by V. If successful, it returns true and returns the string in Str. +/// If unsuccessful, it returns false. +bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, +                                 uint64_t Offset, bool TrimAtNul) { +  ConstantDataArraySlice Slice; +  if (!getConstantDataArrayInfo(V, Slice, 8, Offset)) +    return false; + +  if (Slice.Array == nullptr) { +    if (TrimAtNul) { +      Str = StringRef(); +      return true; +    } +    if (Slice.Length == 1) { +      Str = StringRef("", 1); +      return true; +    } +    // We cannot instantiate a StringRef as we do not have an appropriate string +    // of 0s at hand. +    return false; +  } + +  // Start out with the entire array in the StringRef. +  Str = Slice.Array->getAsString(); +  // Skip over 'offset' bytes. +  Str = Str.substr(Slice.Offset); + +  if (TrimAtNul) { +    // Trim off the \0 and anything after it.  If the array is not nul +    // terminated, we just return the whole end of string.  The client may know +    // some other way that the string is length-bound. +    Str = Str.substr(0, Str.find('\0')); +  } +  return true; +} + +// These next two are very similar to the above, but also look through PHI +// nodes. +// TODO: See if we can integrate these two together. + +/// If we can compute the length of the string pointed to by +/// the specified pointer, return 'len+1'.  If we can't, return 0. +static uint64_t GetStringLengthH(const Value *V, +                                 SmallPtrSetImpl<const PHINode*> &PHIs, +                                 unsigned CharSize) { +  // Look through noop bitcast instructions. +  V = V->stripPointerCasts(); + +  // If this is a PHI node, there are two cases: either we have already seen it +  // or we haven't. +  if (const PHINode *PN = dyn_cast<PHINode>(V)) { +    if (!PHIs.insert(PN).second) +      return ~0ULL;  // already in the set. + +    // If it was new, see if all the input strings are the same length. +    uint64_t LenSoFar = ~0ULL; +    for (Value *IncValue : PN->incoming_values()) { +      uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize); +      if (Len == 0) return 0; // Unknown length -> unknown. + +      if (Len == ~0ULL) continue; + +      if (Len != LenSoFar && LenSoFar != ~0ULL) +        return 0;    // Disagree -> unknown. +      LenSoFar = Len; +    } + +    // Success, all agree. 
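// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The PHI handling above folds the per-operand results with a small lattice:
// 0 means "unknown length", ~0ULL means "no information yet", and two known
// lengths that disagree collapse to "unknown".  mergeLen is a hypothetical
// helper that mirrors that loop on concrete values.
#include <cstdint>

static uint64_t mergeLen(uint64_t SoFar, uint64_t Len) {
  if (Len == 0)
    return 0;                            // unknown operand -> unknown result
  if (Len == ~0ULL)
    return SoFar;                        // operand adds no information
  if (SoFar != ~0ULL && SoFar != Len)
    return 0;                            // known lengths disagree -> unknown
  return Len;
}

int main() {
  uint64_t Agree = mergeLen(mergeLen(~0ULL, 4), 4); // e.g. "abc" vs "xyz": both report len+1 == 4
  uint64_t Clash = mergeLen(mergeLen(~0ULL, 4), 3); // e.g. "abc" vs "de": lengths disagree
  return (Agree == 4 && Clash == 0) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//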
+    return LenSoFar; +  } + +  // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y) +  if (const SelectInst *SI = dyn_cast<SelectInst>(V)) { +    uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize); +    if (Len1 == 0) return 0; +    uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize); +    if (Len2 == 0) return 0; +    if (Len1 == ~0ULL) return Len2; +    if (Len2 == ~0ULL) return Len1; +    if (Len1 != Len2) return 0; +    return Len1; +  } + +  // Otherwise, see if we can read the string. +  ConstantDataArraySlice Slice; +  if (!getConstantDataArrayInfo(V, Slice, CharSize)) +    return 0; + +  if (Slice.Array == nullptr) +    return 1; + +  // Search for nul characters +  unsigned NullIndex = 0; +  for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) { +    if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0) +      break; +  } + +  return NullIndex + 1; +} + +/// If we can compute the length of the string pointed to by +/// the specified pointer, return 'len+1'.  If we can't, return 0. +uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { +  if (!V->getType()->isPointerTy()) +    return 0; + +  SmallPtrSet<const PHINode*, 32> PHIs; +  uint64_t Len = GetStringLengthH(V, PHIs, CharSize); +  // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return +  // an empty string as a length. +  return Len == ~0ULL ? 1 : Len; +} + +const Value *llvm::getArgumentAliasingToReturnedPointer(ImmutableCallSite CS) { +  assert(CS && +         "getArgumentAliasingToReturnedPointer only works on nonnull CallSite"); +  if (const Value *RV = CS.getReturnedArgOperand()) +    return RV; +  // This can be used only as a aliasing property. +  if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CS)) +    return CS.getArgOperand(0); +  return nullptr; +} + +bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( +    ImmutableCallSite CS) { +  return CS.getIntrinsicID() == Intrinsic::launder_invariant_group || +         CS.getIntrinsicID() == Intrinsic::strip_invariant_group; +} + +/// \p PN defines a loop-variant pointer to an object.  Check if the +/// previous iteration of the loop was referring to the same object as \p PN. +static bool isSameUnderlyingObjectInLoop(const PHINode *PN, +                                         const LoopInfo *LI) { +  // Find the loop-defined value. +  Loop *L = LI->getLoopFor(PN->getParent()); +  if (PN->getNumIncomingValues() != 2) +    return true; + +  // Find the value from previous iteration. +  auto *PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(0)); +  if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L) +    PrevValue = dyn_cast<Instruction>(PN->getIncomingValue(1)); +  if (!PrevValue || LI->getLoopFor(PrevValue->getParent()) != L) +    return true; + +  // If a new pointer is loaded in the loop, the pointer references a different +  // object in every iteration.  E.g.: +  //    for (i) +  //       int *p = a[i]; +  //       ... 
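// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The two C++ loops below illustrate the distinction drawn above: in the first
// loop the pointer names a different object on every iteration, in the second
// it always names the same one.
static int *Table[4];
static int Shared;

static void loops() {
  for (int I = 0; I != 4; ++I) { int *P = Table[I]; (void)P; } // varies per iteration
  for (int I = 0; I != 4; ++I) { int *Q = &Shared;  (void)Q; } // loop-invariant object
}

int main() { loops(); return 0; }
// ===----------------------------------------------------------------------===//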
+  if (auto *Load = dyn_cast<LoadInst>(PrevValue)) +    if (!L->isLoopInvariant(Load->getPointerOperand())) +      return false; +  return true; +} + +Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, +                                 unsigned MaxLookup) { +  if (!V->getType()->isPointerTy()) +    return V; +  for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) { +    if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { +      V = GEP->getPointerOperand(); +    } else if (Operator::getOpcode(V) == Instruction::BitCast || +               Operator::getOpcode(V) == Instruction::AddrSpaceCast) { +      V = cast<Operator>(V)->getOperand(0); +    } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { +      if (GA->isInterposable()) +        return V; +      V = GA->getAliasee(); +    } else if (isa<AllocaInst>(V)) { +      // An alloca can't be further simplified. +      return V; +    } else { +      if (auto CS = CallSite(V)) { +        // CaptureTracking can know about special capturing properties of some +        // intrinsics like launder.invariant.group, that can't be expressed with +        // the attributes, but have properties like returning aliasing pointer. +        // Because some analysis may assume that nocaptured pointer is not +        // returned from some special intrinsic (because function would have to +        // be marked with returns attribute), it is crucial to use this function +        // because it should be in sync with CaptureTracking. Not using it may +        // cause weird miscompilations where 2 aliasing pointers are assumed to +        // noalias. +        if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) { +          V = RP; +          continue; +        } +      } + +      // See if InstructionSimplify knows any relevant tricks. +      if (Instruction *I = dyn_cast<Instruction>(V)) +        // TODO: Acquire a DominatorTree and AssumptionCache and use them. +        if (Value *Simplified = SimplifyInstruction(I, {DL, I})) { +          V = Simplified; +          continue; +        } + +      return V; +    } +    assert(V->getType()->isPointerTy() && "Unexpected operand type!"); +  } +  return V; +} + +void llvm::GetUnderlyingObjects(Value *V, SmallVectorImpl<Value *> &Objects, +                                const DataLayout &DL, LoopInfo *LI, +                                unsigned MaxLookup) { +  SmallPtrSet<Value *, 4> Visited; +  SmallVector<Value *, 4> Worklist; +  Worklist.push_back(V); +  do { +    Value *P = Worklist.pop_back_val(); +    P = GetUnderlyingObject(P, DL, MaxLookup); + +    if (!Visited.insert(P).second) +      continue; + +    if (SelectInst *SI = dyn_cast<SelectInst>(P)) { +      Worklist.push_back(SI->getTrueValue()); +      Worklist.push_back(SI->getFalseValue()); +      continue; +    } + +    if (PHINode *PN = dyn_cast<PHINode>(P)) { +      // If this PHI changes the underlying object in every iteration of the +      // loop, don't look through it.  Consider: +      //   int **A; +      //   for (i) { +      //     Prev = Curr;     // Prev = PHI (Prev_0, Curr) +      //     Curr = A[i]; +      //     *Prev, *Curr; +      // +      // Prev is tracking Curr one iteration behind so they refer to different +      // underlying objects. 
+      if (!LI || !LI->isLoopHeader(PN->getParent()) || +          isSameUnderlyingObjectInLoop(PN, LI)) +        for (Value *IncValue : PN->incoming_values()) +          Worklist.push_back(IncValue); +      continue; +    } + +    Objects.push_back(P); +  } while (!Worklist.empty()); +} + +/// This is the function that does the work of looking through basic +/// ptrtoint+arithmetic+inttoptr sequences. +static const Value *getUnderlyingObjectFromInt(const Value *V) { +  do { +    if (const Operator *U = dyn_cast<Operator>(V)) { +      // If we find a ptrtoint, we can transfer control back to the +      // regular getUnderlyingObjectFromInt. +      if (U->getOpcode() == Instruction::PtrToInt) +        return U->getOperand(0); +      // If we find an add of a constant, a multiplied value, or a phi, it's +      // likely that the other operand will lead us to the base +      // object. We don't have to worry about the case where the +      // object address is somehow being computed by the multiply, +      // because our callers only care when the result is an +      // identifiable object. +      if (U->getOpcode() != Instruction::Add || +          (!isa<ConstantInt>(U->getOperand(1)) && +           Operator::getOpcode(U->getOperand(1)) != Instruction::Mul && +           !isa<PHINode>(U->getOperand(1)))) +        return V; +      V = U->getOperand(0); +    } else { +      return V; +    } +    assert(V->getType()->isIntegerTy() && "Unexpected operand type!"); +  } while (true); +} + +/// This is a wrapper around GetUnderlyingObjects and adds support for basic +/// ptrtoint+arithmetic+inttoptr sequences. +/// It returns false if unidentified object is found in GetUnderlyingObjects. +bool llvm::getUnderlyingObjectsForCodeGen(const Value *V, +                          SmallVectorImpl<Value *> &Objects, +                          const DataLayout &DL) { +  SmallPtrSet<const Value *, 16> Visited; +  SmallVector<const Value *, 4> Working(1, V); +  do { +    V = Working.pop_back_val(); + +    SmallVector<Value *, 4> Objs; +    GetUnderlyingObjects(const_cast<Value *>(V), Objs, DL); + +    for (Value *V : Objs) { +      if (!Visited.insert(V).second) +        continue; +      if (Operator::getOpcode(V) == Instruction::IntToPtr) { +        const Value *O = +          getUnderlyingObjectFromInt(cast<User>(V)->getOperand(0)); +        if (O->getType()->isPointerTy()) { +          Working.push_back(O); +          continue; +        } +      } +      // If GetUnderlyingObjects fails to find an identifiable object, +      // getUnderlyingObjectsForCodeGen also fails for safety. +      if (!isIdentifiedObject(V)) { +        Objects.clear(); +        return false; +      } +      Objects.push_back(const_cast<Value *>(V)); +    } +  } while (!Working.empty()); +  return true; +} + +/// Return true if the only users of this pointer are lifetime markers. 
+bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { +  for (const User *U : V->users()) { +    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); +    if (!II) return false; + +    if (II->getIntrinsicID() != Intrinsic::lifetime_start && +        II->getIntrinsicID() != Intrinsic::lifetime_end) +      return false; +  } +  return true; +} + +bool llvm::isSafeToSpeculativelyExecute(const Value *V, +                                        const Instruction *CtxI, +                                        const DominatorTree *DT) { +  const Operator *Inst = dyn_cast<Operator>(V); +  if (!Inst) +    return false; + +  for (unsigned i = 0, e = Inst->getNumOperands(); i != e; ++i) +    if (Constant *C = dyn_cast<Constant>(Inst->getOperand(i))) +      if (C->canTrap()) +        return false; + +  switch (Inst->getOpcode()) { +  default: +    return true; +  case Instruction::UDiv: +  case Instruction::URem: { +    // x / y is undefined if y == 0. +    const APInt *V; +    if (match(Inst->getOperand(1), m_APInt(V))) +      return *V != 0; +    return false; +  } +  case Instruction::SDiv: +  case Instruction::SRem: { +    // x / y is undefined if y == 0 or x == INT_MIN and y == -1 +    const APInt *Numerator, *Denominator; +    if (!match(Inst->getOperand(1), m_APInt(Denominator))) +      return false; +    // We cannot hoist this division if the denominator is 0. +    if (*Denominator == 0) +      return false; +    // It's safe to hoist if the denominator is not 0 or -1. +    if (*Denominator != -1) +      return true; +    // At this point we know that the denominator is -1.  It is safe to hoist as +    // long we know that the numerator is not INT_MIN. +    if (match(Inst->getOperand(0), m_APInt(Numerator))) +      return !Numerator->isMinSignedValue(); +    // The numerator *might* be MinSignedValue. +    return false; +  } +  case Instruction::Load: { +    const LoadInst *LI = cast<LoadInst>(Inst); +    if (!LI->isUnordered() || +        // Speculative load may create a race that did not exist in the source. +        LI->getFunction()->hasFnAttribute(Attribute::SanitizeThread) || +        // Speculative load may load data from dirty regions. +        LI->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || +        LI->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) +      return false; +    const DataLayout &DL = LI->getModule()->getDataLayout(); +    return isDereferenceableAndAlignedPointer(LI->getPointerOperand(), +                                              LI->getAlignment(), DL, CtxI, DT); +  } +  case Instruction::Call: { +    auto *CI = cast<const CallInst>(Inst); +    const Function *Callee = CI->getCalledFunction(); + +    // The called function could have undefined behavior or side-effects, even +    // if marked readnone nounwind. 
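// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The SDiv/SRem rule above, restated on concrete values: speculation is safe
// when the denominator is known non-zero and, if it can be -1, the numerator
// is known not to be INT_MIN.  safeToSpeculateSDiv is a hypothetical helper;
// the real analysis reasons about constants/known bits, not runtime values.
#include <cstdint>
#include <limits>

static bool safeToSpeculateSDiv(int32_t Num, int32_t Den) {
  if (Den == 0)
    return false;                                            // divide by zero traps
  if (Den == -1 && Num == std::numeric_limits<int32_t>::min())
    return false;                                            // INT_MIN / -1 overflows
  return true;
}

int main() {
  return (safeToSpeculateSDiv(10, 2) &&
          !safeToSpeculateSDiv(10, 0) &&
          !safeToSpeculateSDiv(std::numeric_limits<int32_t>::min(), -1)) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//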
+    return Callee && Callee->isSpeculatable(); +  } +  case Instruction::VAArg: +  case Instruction::Alloca: +  case Instruction::Invoke: +  case Instruction::PHI: +  case Instruction::Store: +  case Instruction::Ret: +  case Instruction::Br: +  case Instruction::IndirectBr: +  case Instruction::Switch: +  case Instruction::Unreachable: +  case Instruction::Fence: +  case Instruction::AtomicRMW: +  case Instruction::AtomicCmpXchg: +  case Instruction::LandingPad: +  case Instruction::Resume: +  case Instruction::CatchSwitch: +  case Instruction::CatchPad: +  case Instruction::CatchRet: +  case Instruction::CleanupPad: +  case Instruction::CleanupRet: +    return false; // Misc instructions which have effects +  } +} + +bool llvm::mayBeMemoryDependent(const Instruction &I) { +  return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I); +} + +OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, +                                                   const Value *RHS, +                                                   const DataLayout &DL, +                                                   AssumptionCache *AC, +                                                   const Instruction *CxtI, +                                                   const DominatorTree *DT) { +  // Multiplying n * m significant bits yields a result of n + m significant +  // bits. If the total number of significant bits does not exceed the +  // result bit width (minus 1), there is no overflow. +  // This means if we have enough leading zero bits in the operands +  // we can guarantee that the result does not overflow. +  // Ref: "Hacker's Delight" by Henry Warren +  unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); +  KnownBits LHSKnown(BitWidth); +  KnownBits RHSKnown(BitWidth); +  computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); +  computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); +  // Note that underestimating the number of zero bits gives a more +  // conservative answer. +  unsigned ZeroBits = LHSKnown.countMinLeadingZeros() + +                      RHSKnown.countMinLeadingZeros(); +  // First handle the easy case: if we have enough zero bits there's +  // definitely no overflow. +  if (ZeroBits >= BitWidth) +    return OverflowResult::NeverOverflows; + +  // Get the largest possible values for each operand. +  APInt LHSMax = ~LHSKnown.Zero; +  APInt RHSMax = ~RHSKnown.Zero; + +  // We know the multiply operation doesn't overflow if the maximum values for +  // each operand will not overflow after we multiply them together. +  bool MaxOverflow; +  (void)LHSMax.umul_ov(RHSMax, MaxOverflow); +  if (!MaxOverflow) +    return OverflowResult::NeverOverflows; + +  // We know it always overflows if multiplying the smallest possible values for +  // the operands also results in overflow. 
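// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The leading-zero rule above on concrete 32-bit values: an n-bit value times
// an m-bit value needs at most n + m bits, so clz(a) + clz(b) >= 32 guarantees
// no unsigned wrap.  Assumes a GCC/Clang-style __builtin_clz; the second test
// shows the rule is conservative (that product fits even though the rule
// answers "may overflow").
#include <cstdint>

static bool umulCannotOverflow32(uint32_t A, uint32_t B) {
  if (A == 0 || B == 0)
    return true;                                  // __builtin_clz(0) is undefined
  return __builtin_clz(A) + __builtin_clz(B) >= 32;
}

int main() {
  return (umulCannotOverflow32(1000u, 1000u) &&       // 22 + 22 leading zeros
          !umulCannotOverflow32(0xFFFFu, 0x10000u))   // 16 + 15 < 32, yet 0xFFFF0000 fits
             ? 0 : 1;
}
// ===----------------------------------------------------------------------===//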
+  bool MinOverflow; +  (void)LHSKnown.One.umul_ov(RHSKnown.One, MinOverflow); +  if (MinOverflow) +    return OverflowResult::AlwaysOverflows; + +  return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, +                                                 const Value *RHS, +                                                 const DataLayout &DL, +                                                 AssumptionCache *AC, +                                                 const Instruction *CxtI, +                                                 const DominatorTree *DT) { +  // Multiplying n * m significant bits yields a result of n + m significant +  // bits. If the total number of significant bits does not exceed the +  // result bit width (minus 1), there is no overflow. +  // This means if we have enough leading sign bits in the operands +  // we can guarantee that the result does not overflow. +  // Ref: "Hacker's Delight" by Henry Warren +  unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + +  // Note that underestimating the number of sign bits gives a more +  // conservative answer. +  unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) + +                      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT); + +  // First handle the easy case: if we have enough sign bits there's +  // definitely no overflow. +  if (SignBits > BitWidth + 1) +    return OverflowResult::NeverOverflows; + +  // There are two ambiguous cases where there can be no overflow: +  //   SignBits == BitWidth + 1    and +  //   SignBits == BitWidth +  // The second case is difficult to check, therefore we only handle the +  // first case. +  if (SignBits == BitWidth + 1) { +    // It overflows only when both arguments are negative and the true +    // product is exactly the minimum negative number. +    // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 +    // For simplicity we just check if at least one side is not negative. +    KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); +    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); +    if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) +      return OverflowResult::NeverOverflows; +  } +  return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, +                                                   const Value *RHS, +                                                   const DataLayout &DL, +                                                   AssumptionCache *AC, +                                                   const Instruction *CxtI, +                                                   const DominatorTree *DT) { +  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); +  if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) { +    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + +    if (LHSKnown.isNegative() && RHSKnown.isNegative()) { +      // The sign bit is set in both cases: this MUST overflow. +      // Create a simple add instruction, and insert it into the struct. +      return OverflowResult::AlwaysOverflows; +    } + +    if (LHSKnown.isNonNegative() && RHSKnown.isNonNegative()) { +      // The sign bit is clear in both cases: this CANNOT overflow. +      // Create a simple add instruction, and insert it into the struct. 
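// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The unsigned-add rule above, stated on the top bit of concrete values: two
// operands with the top bit set sum to >= 2^32 and must wrap, two with it
// clear sum to < 2^32 and cannot wrap, and mixed top bits depend on the rest.
// The real analysis uses known sign bits rather than runtime values.
#include <cstdint>

enum class Overflow { Never, Always, Maybe };

static Overflow uaddOverflow32(uint32_t A, uint32_t B) {
  bool ATop = (A >> 31) != 0, BTop = (B >> 31) != 0;
  if (ATop && BTop)
    return Overflow::Always;
  if (!ATop && !BTop)
    return Overflow::Never;
  return Overflow::Maybe;
}

int main() {
  return (uaddOverflow32(0x80000000u, 0x80000000u) == Overflow::Always &&
          uaddOverflow32(1u, 2u) == Overflow::Never &&
          uaddOverflow32(0x80000000u, 1u) == Overflow::Maybe) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//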
+      return OverflowResult::NeverOverflows; +    } +  } + +  return OverflowResult::MayOverflow; +} + +/// Return true if we can prove that adding the two values of the +/// knownbits will not overflow. +/// Otherwise return false. +static bool checkRippleForSignedAdd(const KnownBits &LHSKnown, +                                    const KnownBits &RHSKnown) { +  // Addition of two 2's complement numbers having opposite signs will never +  // overflow. +  if ((LHSKnown.isNegative() && RHSKnown.isNonNegative()) || +      (LHSKnown.isNonNegative() && RHSKnown.isNegative())) +    return true; + +  // If either of the values is known to be non-negative, adding them can only +  // overflow if the second is also non-negative, so we can assume that. +  // Two non-negative numbers will only overflow if there is a carry to the +  // sign bit, so we can check if even when the values are as big as possible +  // there is no overflow to the sign bit. +  if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) { +    APInt MaxLHS = ~LHSKnown.Zero; +    MaxLHS.clearSignBit(); +    APInt MaxRHS = ~RHSKnown.Zero; +    MaxRHS.clearSignBit(); +    APInt Result = std::move(MaxLHS) + std::move(MaxRHS); +    return Result.isSignBitClear(); +  } + +  // If either of the values is known to be negative, adding them can only +  // overflow if the second is also negative, so we can assume that. +  // Two negative number will only overflow if there is no carry to the sign +  // bit, so we can check if even when the values are as small as possible +  // there is overflow to the sign bit. +  if (LHSKnown.isNegative() || RHSKnown.isNegative()) { +    APInt MinLHS = LHSKnown.One; +    MinLHS.clearSignBit(); +    APInt MinRHS = RHSKnown.One; +    MinRHS.clearSignBit(); +    APInt Result = std::move(MinLHS) + std::move(MinRHS); +    return Result.isSignBitSet(); +  } + +  // If we reached here it means that we know nothing about the sign bits. +  // In this case we can't know if there will be an overflow, since by +  // changing the sign bits any two values can be made to overflow. +  return false; +} + +static OverflowResult computeOverflowForSignedAdd(const Value *LHS, +                                                  const Value *RHS, +                                                  const AddOperator *Add, +                                                  const DataLayout &DL, +                                                  AssumptionCache *AC, +                                                  const Instruction *CxtI, +                                                  const DominatorTree *DT) { +  if (Add && Add->hasNoSignedWrap()) { +    return OverflowResult::NeverOverflows; +  } + +  // If LHS and RHS each have at least two sign bits, the addition will look +  // like +  // +  // XX..... + +  // YY..... +  // +  // If the carry into the most significant position is 0, X and Y can't both +  // be 1 and therefore the carry out of the addition is also 0. +  // +  // If the carry into the most significant position is 1, X and Y can't both +  // be 0 and therefore the carry out of the addition is also 1. +  // +  // Since the carry into the most significant position is always equal to +  // the carry out of the addition, there is no signed overflow. 
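// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The "at least two sign bits" case above for i32: both operands then lie in
// [-2^30, 2^30 - 1], so their sum lies in [-2^31, 2^31 - 2] and cannot
// overflow.  hasTwoSignBits is a hypothetical helper over concrete values.
#include <cstdint>

static bool hasTwoSignBits(int32_t V) {
  return V >= -(1 << 30) && V <= (1 << 30) - 1;
}

int main() {
  int32_t A = (1 << 30) - 1, B = (1 << 30) - 1;   // the extreme case
  int64_t Wide = int64_t(A) + int64_t(B);         // 2^31 - 2, still representable in i32
  return (hasTwoSignBits(A) && hasTwoSignBits(B) &&
          Wide >= INT32_MIN && Wide <= INT32_MAX) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//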
+  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 && +      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1) +    return OverflowResult::NeverOverflows; + +  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); +  KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + +  if (checkRippleForSignedAdd(LHSKnown, RHSKnown)) +    return OverflowResult::NeverOverflows; + +  // The remaining code needs Add to be available. Early returns if not so. +  if (!Add) +    return OverflowResult::MayOverflow; + +  // If the sign of Add is the same as at least one of the operands, this add +  // CANNOT overflow. This is particularly useful when the sum is +  // @llvm.assume'ed non-negative rather than proved so from analyzing its +  // operands. +  bool LHSOrRHSKnownNonNegative = +      (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()); +  bool LHSOrRHSKnownNegative = +      (LHSKnown.isNegative() || RHSKnown.isNegative()); +  if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) { +    KnownBits AddKnown = computeKnownBits(Add, DL, /*Depth=*/0, AC, CxtI, DT); +    if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) || +        (AddKnown.isNegative() && LHSOrRHSKnownNegative)) { +      return OverflowResult::NeverOverflows; +    } +  } + +  return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS, +                                                   const Value *RHS, +                                                   const DataLayout &DL, +                                                   AssumptionCache *AC, +                                                   const Instruction *CxtI, +                                                   const DominatorTree *DT) { +  // If the LHS is negative and the RHS is non-negative, no unsigned wrap. +  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); +  KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); +  if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) +    return OverflowResult::NeverOverflows; + +  return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS, +                                                 const Value *RHS, +                                                 const DataLayout &DL, +                                                 AssumptionCache *AC, +                                                 const Instruction *CxtI, +                                                 const DominatorTree *DT) { +  // If LHS and RHS each have at least two sign bits, the subtraction +  // cannot overflow. +  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 && +      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1) +    return OverflowResult::NeverOverflows; + +  KnownBits LHSKnown = computeKnownBits(LHS, DL, 0, AC, CxtI, DT); + +  KnownBits RHSKnown = computeKnownBits(RHS, DL, 0, AC, CxtI, DT); + +  // Subtraction of two 2's complement numbers having identical signs will +  // never overflow. 
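// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The identical-signs rule above for i32: when both operands are negative, or
// both non-negative, A - B lies within [-(2^31 - 1), 2^31 - 1], so the signed
// subtraction cannot overflow.  Checked here on one concrete pair.
#include <cstdint>

int main() {
  int32_t A = -5, B = -2000000000;           // both negative
  int64_t Wide = int64_t(A) - int64_t(B);    // 1999999995, representable in i32
  return (Wide >= INT32_MIN && Wide <= INT32_MAX) ? 0 : 1;
}
// ===----------------------------------------------------------------------===//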
+  if ((LHSKnown.isNegative() && RHSKnown.isNegative()) || +      (LHSKnown.isNonNegative() && RHSKnown.isNonNegative())) +    return OverflowResult::NeverOverflows; + +  // TODO: implement logic similar to checkRippleForAdd +  return OverflowResult::MayOverflow; +} + +bool llvm::isOverflowIntrinsicNoWrap(const IntrinsicInst *II, +                                     const DominatorTree &DT) { +#ifndef NDEBUG +  auto IID = II->getIntrinsicID(); +  assert((IID == Intrinsic::sadd_with_overflow || +          IID == Intrinsic::uadd_with_overflow || +          IID == Intrinsic::ssub_with_overflow || +          IID == Intrinsic::usub_with_overflow || +          IID == Intrinsic::smul_with_overflow || +          IID == Intrinsic::umul_with_overflow) && +         "Not an overflow intrinsic!"); +#endif + +  SmallVector<const BranchInst *, 2> GuardingBranches; +  SmallVector<const ExtractValueInst *, 2> Results; + +  for (const User *U : II->users()) { +    if (const auto *EVI = dyn_cast<ExtractValueInst>(U)) { +      assert(EVI->getNumIndices() == 1 && "Obvious from CI's type"); + +      if (EVI->getIndices()[0] == 0) +        Results.push_back(EVI); +      else { +        assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type"); + +        for (const auto *U : EVI->users()) +          if (const auto *B = dyn_cast<BranchInst>(U)) { +            assert(B->isConditional() && "How else is it using an i1?"); +            GuardingBranches.push_back(B); +          } +      } +    } else { +      // We are using the aggregate directly in a way we don't want to analyze +      // here (storing it to a global, say). +      return false; +    } +  } + +  auto AllUsesGuardedByBranch = [&](const BranchInst *BI) { +    BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(1)); +    if (!NoWrapEdge.isSingleEdge()) +      return false; + +    // Check if all users of the add are provably no-wrap. +    for (const auto *Result : Results) { +      // If the extractvalue itself is not executed on overflow, the we don't +      // need to check each use separately, since domination is transitive. 
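// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// C++ analogue of the guarded-overflow pattern analysed above, assuming the
// GCC/Clang builtin __builtin_add_overflow: the overflow flag feeds the
// branch, so every use of the extracted sum on the fall-through path is
// dominated by the no-wrap edge.  checkedAdd is a hypothetical helper.
#include <cstdint>
#include <cstdlib>

static int32_t checkedAdd(int32_t A, int32_t B) {
  int32_t Sum;
  if (__builtin_add_overflow(A, B, &Sum))   // the i1 "did it overflow?" result
    std::abort();                           // overflow edge
  return Sum;                               // only reached on the no-wrap edge
}

int main() { return checkedAdd(40, 2) == 42 ? 0 : 1; }
// ===----------------------------------------------------------------------===//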
+      if (DT.dominates(NoWrapEdge, Result->getParent())) +        continue; + +      for (auto &RU : Result->uses()) +        if (!DT.dominates(NoWrapEdge, RU)) +          return false; +    } + +    return true; +  }; + +  return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch); +} + + +OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add, +                                                 const DataLayout &DL, +                                                 AssumptionCache *AC, +                                                 const Instruction *CxtI, +                                                 const DominatorTree *DT) { +  return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1), +                                       Add, DL, AC, CxtI, DT); +} + +OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS, +                                                 const Value *RHS, +                                                 const DataLayout &DL, +                                                 AssumptionCache *AC, +                                                 const Instruction *CxtI, +                                                 const DominatorTree *DT) { +  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, DL, AC, CxtI, DT); +} + +bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { +  // A memory operation returns normally if it isn't volatile. A volatile +  // operation is allowed to trap. +  // +  // An atomic operation isn't guaranteed to return in a reasonable amount of +  // time because it's possible for another thread to interfere with it for an +  // arbitrary length of time, but programs aren't allowed to rely on that. +  if (const LoadInst *LI = dyn_cast<LoadInst>(I)) +    return !LI->isVolatile(); +  if (const StoreInst *SI = dyn_cast<StoreInst>(I)) +    return !SI->isVolatile(); +  if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(I)) +    return !CXI->isVolatile(); +  if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) +    return !RMWI->isVolatile(); +  if (const MemIntrinsic *MII = dyn_cast<MemIntrinsic>(I)) +    return !MII->isVolatile(); + +  // If there is no successor, then execution can't transfer to it. +  if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) +    return !CRI->unwindsToCaller(); +  if (const auto *CatchSwitch = dyn_cast<CatchSwitchInst>(I)) +    return !CatchSwitch->unwindsToCaller(); +  if (isa<ResumeInst>(I)) +    return false; +  if (isa<ReturnInst>(I)) +    return false; +  if (isa<UnreachableInst>(I)) +    return false; + +  // Calls can throw, or contain an infinite loop, or kill the process. +  if (auto CS = ImmutableCallSite(I)) { +    // Call sites that throw have implicit non-local control flow. +    if (!CS.doesNotThrow()) +      return false; + +    // Non-throwing call sites can loop infinitely, call exit/pthread_exit +    // etc. and thus not return.  However, LLVM already assumes that +    // +    //  - Thread exiting actions are modeled as writes to memory invisible to +    //    the program. +    // +    //  - Loops that don't have side effects (side effects are volatile/atomic +    //    stores and IO) always terminate (see http://llvm.org/PR965). +    //    Furthermore IO itself is also modeled as writes to memory invisible to +    //    the program. +    // +    // We rely on those assumptions here, and use the memory effects of the call +    // target as a proxy for checking that it always returns. 
+ +    // FIXME: This isn't aggressive enough; a call which only writes to a global +    // is guaranteed to return. +    return CS.onlyReadsMemory() || CS.onlyAccessesArgMemory() || +           match(I, m_Intrinsic<Intrinsic::assume>()) || +           match(I, m_Intrinsic<Intrinsic::sideeffect>()); +  } + +  // Other instructions return normally. +  return true; +} + +bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) { +  // TODO: This is slightly consdervative for invoke instruction since exiting +  // via an exception *is* normal control for them. +  for (auto I = BB->begin(), E = BB->end(); I != E; ++I) +    if (!isGuaranteedToTransferExecutionToSuccessor(&*I)) +      return false; +  return true; +} + +bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I, +                                                  const Loop *L) { +  // The loop header is guaranteed to be executed for every iteration. +  // +  // FIXME: Relax this constraint to cover all basic blocks that are +  // guaranteed to be executed at every iteration. +  if (I->getParent() != L->getHeader()) return false; + +  for (const Instruction &LI : *L->getHeader()) { +    if (&LI == I) return true; +    if (!isGuaranteedToTransferExecutionToSuccessor(&LI)) return false; +  } +  llvm_unreachable("Instruction not contained in its own parent basic block."); +} + +bool llvm::propagatesFullPoison(const Instruction *I) { +  switch (I->getOpcode()) { +  case Instruction::Add: +  case Instruction::Sub: +  case Instruction::Xor: +  case Instruction::Trunc: +  case Instruction::BitCast: +  case Instruction::AddrSpaceCast: +  case Instruction::Mul: +  case Instruction::Shl: +  case Instruction::GetElementPtr: +    // These operations all propagate poison unconditionally. Note that poison +    // is not any particular value, so xor or subtraction of poison with +    // itself still yields poison, not zero. +    return true; + +  case Instruction::AShr: +  case Instruction::SExt: +    // For these operations, one bit of the input is replicated across +    // multiple output bits. A replicated poison bit is still poison. +    return true; + +  case Instruction::ICmp: +    // Comparing poison with any value yields poison.  This is why, for +    // instance, x s< (x +nsw 1) can be folded to true. +    return true; + +  default: +    return false; +  } +} + +const Value *llvm::getGuaranteedNonFullPoisonOp(const Instruction *I) { +  switch (I->getOpcode()) { +    case Instruction::Store: +      return cast<StoreInst>(I)->getPointerOperand(); + +    case Instruction::Load: +      return cast<LoadInst>(I)->getPointerOperand(); + +    case Instruction::AtomicCmpXchg: +      return cast<AtomicCmpXchgInst>(I)->getPointerOperand(); + +    case Instruction::AtomicRMW: +      return cast<AtomicRMWInst>(I)->getPointerOperand(); + +    case Instruction::UDiv: +    case Instruction::SDiv: +    case Instruction::URem: +    case Instruction::SRem: +      return I->getOperand(1); + +    default: +      return nullptr; +  } +} + +bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) { +  // We currently only look for uses of poison values within the same basic +  // block, as that makes it easier to guarantee that the uses will be +  // executed given that PoisonI is executed. +  // +  // FIXME: Expand this to consider uses beyond the same basic block. To do +  // this, look out for the distinction between post-dominance and strong +  // post-dominance. 
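// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// C++ analogue of the poison/ICmp note above: signed overflow is undefined in
// C++, so a compiler may fold `X + 1 > X` to true for int X, just as
// `x s< (x +nsw 1)` folds to true in IR.
static bool alwaysTrue(int X) { return X + 1 > X; }  // typically folded to `return true;`

int main() { return alwaysTrue(41) ? 0 : 1; }
// ===----------------------------------------------------------------------===//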
+  const BasicBlock *BB = PoisonI->getParent(); + +  // Set of instructions that we have proved will yield poison if PoisonI +  // does. +  SmallSet<const Value *, 16> YieldsPoison; +  SmallSet<const BasicBlock *, 4> Visited; +  YieldsPoison.insert(PoisonI); +  Visited.insert(PoisonI->getParent()); + +  BasicBlock::const_iterator Begin = PoisonI->getIterator(), End = BB->end(); + +  unsigned Iter = 0; +  while (Iter++ < MaxDepth) { +    for (auto &I : make_range(Begin, End)) { +      if (&I != PoisonI) { +        const Value *NotPoison = getGuaranteedNonFullPoisonOp(&I); +        if (NotPoison != nullptr && YieldsPoison.count(NotPoison)) +          return true; +        if (!isGuaranteedToTransferExecutionToSuccessor(&I)) +          return false; +      } + +      // Mark poison that propagates from I through uses of I. +      if (YieldsPoison.count(&I)) { +        for (const User *User : I.users()) { +          const Instruction *UserI = cast<Instruction>(User); +          if (propagatesFullPoison(UserI)) +            YieldsPoison.insert(User); +        } +      } +    } + +    if (auto *NextBB = BB->getSingleSuccessor()) { +      if (Visited.insert(NextBB).second) { +        BB = NextBB; +        Begin = BB->getFirstNonPHI()->getIterator(); +        End = BB->end(); +        continue; +      } +    } + +    break; +  } +  return false; +} + +static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) { +  if (FMF.noNaNs()) +    return true; + +  if (auto *C = dyn_cast<ConstantFP>(V)) +    return !C->isNaN(); +  return false; +} + +static bool isKnownNonZero(const Value *V) { +  if (auto *C = dyn_cast<ConstantFP>(V)) +    return !C->isZero(); +  return false; +} + +/// Match clamp pattern for float types without care about NaNs or signed zeros. +/// Given non-min/max outer cmp/select from the clamp pattern this +/// function recognizes if it can be substitued by a "canonical" min/max +/// pattern. +static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred, +                                               Value *CmpLHS, Value *CmpRHS, +                                               Value *TrueVal, Value *FalseVal, +                                               Value *&LHS, Value *&RHS) { +  // Try to match +  //   X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2)) +  //   X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2)) +  // and return description of the outer Max/Min. + +  // First, check if select has inverse order: +  if (CmpRHS == FalseVal) { +    std::swap(TrueVal, FalseVal); +    Pred = CmpInst::getInversePredicate(Pred); +  } + +  // Assume success now. If there's no match, callers should not use these anyway. 
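// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The fast-float clamp matched above, written directly in C++: for finite
// constants C1 < C2, `x < C1 ? C1 : min(x, C2)` equals `max(C1, min(x, C2))`.
// Like the matcher, this deliberately ignores NaNs and the sign of zero.
#include <algorithm>
#include <initializer_list>

static float clampSelect(float X, float C1, float C2) {
  return X < C1 ? C1 : std::min(X, C2);
}

static float clampMinMax(float X, float C1, float C2) {
  return std::max(C1, std::min(X, C2));
}

int main() {
  for (float X : {-5.0f, 0.5f, 99.0f})
    if (clampSelect(X, 0.0f, 10.0f) != clampMinMax(X, 0.0f, 10.0f))
      return 1;
  return 0;
}
// ===----------------------------------------------------------------------===//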
+  LHS = TrueVal; +  RHS = FalseVal; + +  const APFloat *FC1; +  if (CmpRHS != TrueVal || !match(CmpRHS, m_APFloat(FC1)) || !FC1->isFinite()) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  const APFloat *FC2; +  switch (Pred) { +  case CmpInst::FCMP_OLT: +  case CmpInst::FCMP_OLE: +  case CmpInst::FCMP_ULT: +  case CmpInst::FCMP_ULE: +    if (match(FalseVal, +              m_CombineOr(m_OrdFMin(m_Specific(CmpLHS), m_APFloat(FC2)), +                          m_UnordFMin(m_Specific(CmpLHS), m_APFloat(FC2)))) && +        FC1->compare(*FC2) == APFloat::cmpResult::cmpLessThan) +      return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false}; +    break; +  case CmpInst::FCMP_OGT: +  case CmpInst::FCMP_OGE: +  case CmpInst::FCMP_UGT: +  case CmpInst::FCMP_UGE: +    if (match(FalseVal, +              m_CombineOr(m_OrdFMax(m_Specific(CmpLHS), m_APFloat(FC2)), +                          m_UnordFMax(m_Specific(CmpLHS), m_APFloat(FC2)))) && +        FC1->compare(*FC2) == APFloat::cmpResult::cmpGreaterThan) +      return {SPF_FMINNUM, SPNB_RETURNS_ANY, false}; +    break; +  default: +    break; +  } + +  return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Recognize variations of: +///   CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v))) +static SelectPatternResult matchClamp(CmpInst::Predicate Pred, +                                      Value *CmpLHS, Value *CmpRHS, +                                      Value *TrueVal, Value *FalseVal) { +  // Swap the select operands and predicate to match the patterns below. +  if (CmpRHS != TrueVal) { +    Pred = ICmpInst::getSwappedPredicate(Pred); +    std::swap(TrueVal, FalseVal); +  } +  const APInt *C1; +  if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) { +    const APInt *C2; +    // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1) +    if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) && +        C1->slt(*C2) && Pred == CmpInst::ICMP_SLT) +      return {SPF_SMAX, SPNB_NA, false}; + +    // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1) +    if (match(FalseVal, m_SMax(m_Specific(CmpLHS), m_APInt(C2))) && +        C1->sgt(*C2) && Pred == CmpInst::ICMP_SGT) +      return {SPF_SMIN, SPNB_NA, false}; + +    // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1) +    if (match(FalseVal, m_UMin(m_Specific(CmpLHS), m_APInt(C2))) && +        C1->ult(*C2) && Pred == CmpInst::ICMP_ULT) +      return {SPF_UMAX, SPNB_NA, false}; + +    // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1) +    if (match(FalseVal, m_UMax(m_Specific(CmpLHS), m_APInt(C2))) && +        C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT) +      return {SPF_UMIN, SPNB_NA, false}; +  } +  return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Recognize variations of: +///   a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c)) +static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred, +                                               Value *CmpLHS, Value *CmpRHS, +                                               Value *TVal, Value *FVal, +                                               unsigned Depth) { +  // TODO: Allow FP min/max with nnan/nsz. 
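// ===-- Editor's aside: standalone sketch, not part of this commit --------===//
// The min-of-min pattern described just above, checked on concrete integers:
// `a < c ? min(a, b) : min(c, b)` is just `min(min(a, b), min(c, b))`, i.e.
// the minimum of a, b and c.  Verified here over a few value combinations.
#include <algorithm>
#include <cstdint>
#include <initializer_list>

int main() {
  for (int32_t A : {1, 5, 9})
    for (int32_t B : {0, 4, 8})
      for (int32_t C : {2, 6}) {
        int32_t Sel = A < C ? std::min(A, B) : std::min(C, B);
        int32_t Min = std::min(std::min(A, B), std::min(C, B));
        if (Sel != Min)
          return 1;
      }
  return 0;
}
// ===----------------------------------------------------------------------===//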
+  assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison"); + +  Value *A, *B; +  SelectPatternResult L = matchSelectPattern(TVal, A, B, nullptr, Depth + 1); +  if (!SelectPatternResult::isMinOrMax(L.Flavor)) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  Value *C, *D; +  SelectPatternResult R = matchSelectPattern(FVal, C, D, nullptr, Depth + 1); +  if (L.Flavor != R.Flavor) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  // We have something like: x Pred y ? min(a, b) : min(c, d). +  // Try to match the compare to the min/max operations of the select operands. +  // First, make sure we have the right compare predicate. +  switch (L.Flavor) { +  case SPF_SMIN: +    if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) { +      Pred = ICmpInst::getSwappedPredicate(Pred); +      std::swap(CmpLHS, CmpRHS); +    } +    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) +      break; +    return {SPF_UNKNOWN, SPNB_NA, false}; +  case SPF_SMAX: +    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) { +      Pred = ICmpInst::getSwappedPredicate(Pred); +      std::swap(CmpLHS, CmpRHS); +    } +    if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) +      break; +    return {SPF_UNKNOWN, SPNB_NA, false}; +  case SPF_UMIN: +    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { +      Pred = ICmpInst::getSwappedPredicate(Pred); +      std::swap(CmpLHS, CmpRHS); +    } +    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) +      break; +    return {SPF_UNKNOWN, SPNB_NA, false}; +  case SPF_UMAX: +    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { +      Pred = ICmpInst::getSwappedPredicate(Pred); +      std::swap(CmpLHS, CmpRHS); +    } +    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) +      break; +    return {SPF_UNKNOWN, SPNB_NA, false}; +  default: +    return {SPF_UNKNOWN, SPNB_NA, false}; +  } + +  // If there is a common operand in the already matched min/max and the other +  // min/max operands match the compare operands (either directly or inverted), +  // then this is min/max of the same flavor. + +  // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b)) +  // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b)) +  if (D == B) { +    if ((CmpLHS == A && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) && +                                         match(A, m_Not(m_Specific(CmpRHS))))) +      return {L.Flavor, SPNB_NA, false}; +  } +  // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d)) +  // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d)) +  if (C == B) { +    if ((CmpLHS == A && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) && +                                         match(A, m_Not(m_Specific(CmpRHS))))) +      return {L.Flavor, SPNB_NA, false}; +  } +  // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a)) +  // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a)) +  if (D == A) { +    if ((CmpLHS == B && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) && +                                         match(B, m_Not(m_Specific(CmpRHS))))) +      return {L.Flavor, SPNB_NA, false}; +  } +  // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d)) +  // ~d pred ~b ? 
m(a, b) : m(a, d) --> m(m(a, b), m(a, d)) +  if (C == A) { +    if ((CmpLHS == B && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) && +                                         match(B, m_Not(m_Specific(CmpRHS))))) +      return {L.Flavor, SPNB_NA, false}; +  } + +  return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Match non-obvious integer minimum and maximum sequences. +static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, +                                       Value *CmpLHS, Value *CmpRHS, +                                       Value *TrueVal, Value *FalseVal, +                                       Value *&LHS, Value *&RHS, +                                       unsigned Depth) { +  // Assume success. If there's no match, callers should not use these anyway. +  LHS = TrueVal; +  RHS = FalseVal; + +  SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal); +  if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) +    return SPR; + +  SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, Depth); +  if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) +    return SPR; + +  if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  // Z = X -nsw Y +  // (X >s Y) ? 0 : Z ==> (Z >s 0) ? 0 : Z ==> SMIN(Z, 0) +  // (X <s Y) ? 0 : Z ==> (Z <s 0) ? 0 : Z ==> SMAX(Z, 0) +  if (match(TrueVal, m_Zero()) && +      match(FalseVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) +    return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; + +  // Z = X -nsw Y +  // (X >s Y) ? Z : 0 ==> (Z >s 0) ? Z : 0 ==> SMAX(Z, 0) +  // (X <s Y) ? Z : 0 ==> (Z <s 0) ? Z : 0 ==> SMIN(Z, 0) +  if (match(FalseVal, m_Zero()) && +      match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) +    return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; + +  const APInt *C1; +  if (!match(CmpRHS, m_APInt(C1))) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  // An unsigned min/max can be written with a signed compare. +  const APInt *C2; +  if ((CmpLHS == TrueVal && match(FalseVal, m_APInt(C2))) || +      (CmpLHS == FalseVal && match(TrueVal, m_APInt(C2)))) { +    // Is the sign bit set? +    // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX +    // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN +    if (Pred == CmpInst::ICMP_SLT && C1->isNullValue() && +        C2->isMaxSignedValue()) +      return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; + +    // Is the sign bit clear? +    // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX +    // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN +    if (Pred == CmpInst::ICMP_SGT && C1->isAllOnesValue() && +        C2->isMinSignedValue()) +      return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; +  } + +  // Look through 'not' ops to find disguised signed min/max. +  // (X >s C) ? ~X : ~C ==> (~X <s ~C) ? ~X : ~C ==> SMIN(~X, ~C) +  // (X <s C) ? ~X : ~C ==> (~X >s ~C) ? ~X : ~C ==> SMAX(~X, ~C) +  if (match(TrueVal, m_Not(m_Specific(CmpLHS))) && +      match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) +    return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; + +  // (X >s C) ? ~C : ~X ==> (~X <s ~C) ? ~C : ~X ==> SMAX(~C, ~X) +  // (X <s C) ? ~C : ~X ==> (~X >s ~C) ? 
~C : ~X ==> SMIN(~C, ~X) +  if (match(FalseVal, m_Not(m_Specific(CmpLHS))) && +      match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) +    return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; + +  return {SPF_UNKNOWN, SPNB_NA, false}; +} + +bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW) { +  assert(X && Y && "Invalid operand"); + +  // X = sub (0, Y) || X = sub nsw (0, Y) +  if ((!NeedNSW && match(X, m_Sub(m_ZeroInt(), m_Specific(Y)))) || +      (NeedNSW && match(X, m_NSWSub(m_ZeroInt(), m_Specific(Y))))) +    return true; + +  // Y = sub (0, X) || Y = sub nsw (0, X) +  if ((!NeedNSW && match(Y, m_Sub(m_ZeroInt(), m_Specific(X)))) || +      (NeedNSW && match(Y, m_NSWSub(m_ZeroInt(), m_Specific(X))))) +    return true; + +  // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A) +  Value *A, *B; +  return (!NeedNSW && (match(X, m_Sub(m_Value(A), m_Value(B))) && +                        match(Y, m_Sub(m_Specific(B), m_Specific(A))))) || +         (NeedNSW && (match(X, m_NSWSub(m_Value(A), m_Value(B))) && +                       match(Y, m_NSWSub(m_Specific(B), m_Specific(A))))); +} + +static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, +                                              FastMathFlags FMF, +                                              Value *CmpLHS, Value *CmpRHS, +                                              Value *TrueVal, Value *FalseVal, +                                              Value *&LHS, Value *&RHS, +                                              unsigned Depth) { +  LHS = CmpLHS; +  RHS = CmpRHS; + +  // Signed zero may return inconsistent results between implementations. +  //  (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0 +  //  minNum(0.0, -0.0)          // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1) +  // Therefore, we behave conservatively and only proceed if at least one of the +  // operands is known to not be zero or if we don't care about signed zero. +  switch (Pred) { +  default: break; +  // FIXME: Include OGT/OLT/UGT/ULT. +  case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE: +  case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE: +    if (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) && +        !isKnownNonZero(CmpRHS)) +      return {SPF_UNKNOWN, SPNB_NA, false}; +  } + +  SelectPatternNaNBehavior NaNBehavior = SPNB_NA; +  bool Ordered = false; + +  // When given one NaN and one non-NaN input: +  //   - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input. +  //   - A simple C99 (a < b ? a : b) construction will return 'b' (as the +  //     ordered comparison fails), which could be NaN or non-NaN. +  // so here we discover exactly what NaN behavior is required/accepted. +  if (CmpInst::isFPPredicate(Pred)) { +    bool LHSSafe = isKnownNonNaN(CmpLHS, FMF); +    bool RHSSafe = isKnownNonNaN(CmpRHS, FMF); + +    if (LHSSafe && RHSSafe) { +      // Both operands are known non-NaN. +      NaNBehavior = SPNB_RETURNS_ANY; +    } else if (CmpInst::isOrdered(Pred)) { +      // An ordered comparison will return false when given a NaN, so it +      // returns the RHS. +      Ordered = true; +      if (LHSSafe) +        // LHS is non-NaN, so if RHS is NaN then NaN will be returned. +        NaNBehavior = SPNB_RETURNS_NAN; +      else if (RHSSafe) +        NaNBehavior = SPNB_RETURNS_OTHER; +      else +        // Completely unsafe. 
+        return {SPF_UNKNOWN, SPNB_NA, false}; +    } else { +      Ordered = false; +      // An unordered comparison will return true when given a NaN, so it +      // returns the LHS. +      if (LHSSafe) +        // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned. +        NaNBehavior = SPNB_RETURNS_OTHER; +      else if (RHSSafe) +        NaNBehavior = SPNB_RETURNS_NAN; +      else +        // Completely unsafe. +        return {SPF_UNKNOWN, SPNB_NA, false}; +    } +  } + +  if (TrueVal == CmpRHS && FalseVal == CmpLHS) { +    std::swap(CmpLHS, CmpRHS); +    Pred = CmpInst::getSwappedPredicate(Pred); +    if (NaNBehavior == SPNB_RETURNS_NAN) +      NaNBehavior = SPNB_RETURNS_OTHER; +    else if (NaNBehavior == SPNB_RETURNS_OTHER) +      NaNBehavior = SPNB_RETURNS_NAN; +    Ordered = !Ordered; +  } + +  // ([if]cmp X, Y) ? X : Y +  if (TrueVal == CmpLHS && FalseVal == CmpRHS) { +    switch (Pred) { +    default: return {SPF_UNKNOWN, SPNB_NA, false}; // Equality. +    case ICmpInst::ICMP_UGT: +    case ICmpInst::ICMP_UGE: return {SPF_UMAX, SPNB_NA, false}; +    case ICmpInst::ICMP_SGT: +    case ICmpInst::ICMP_SGE: return {SPF_SMAX, SPNB_NA, false}; +    case ICmpInst::ICMP_ULT: +    case ICmpInst::ICMP_ULE: return {SPF_UMIN, SPNB_NA, false}; +    case ICmpInst::ICMP_SLT: +    case ICmpInst::ICMP_SLE: return {SPF_SMIN, SPNB_NA, false}; +    case FCmpInst::FCMP_UGT: +    case FCmpInst::FCMP_UGE: +    case FCmpInst::FCMP_OGT: +    case FCmpInst::FCMP_OGE: return {SPF_FMAXNUM, NaNBehavior, Ordered}; +    case FCmpInst::FCMP_ULT: +    case FCmpInst::FCMP_ULE: +    case FCmpInst::FCMP_OLT: +    case FCmpInst::FCMP_OLE: return {SPF_FMINNUM, NaNBehavior, Ordered}; +    } +  } + +  if (isKnownNegation(TrueVal, FalseVal)) { +    // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can +    // match against either LHS or sext(LHS). +    auto MaybeSExtCmpLHS = +        m_CombineOr(m_Specific(CmpLHS), m_SExt(m_Specific(CmpLHS))); +    auto ZeroOrAllOnes = m_CombineOr(m_ZeroInt(), m_AllOnes()); +    auto ZeroOrOne = m_CombineOr(m_ZeroInt(), m_One()); +    if (match(TrueVal, MaybeSExtCmpLHS)) { +      // Set the return values. If the compare uses the negated value (-X >s 0), +      // swap the return values because the negated value is always 'RHS'. +      LHS = TrueVal; +      RHS = FalseVal; +      if (match(CmpLHS, m_Neg(m_Specific(FalseVal)))) +        std::swap(LHS, RHS); + +      // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X) +      // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X) +      if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes)) +        return {SPF_ABS, SPNB_NA, false}; + +      // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X) +      // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X) +      if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne)) +        return {SPF_NABS, SPNB_NA, false}; +    } +    else if (match(FalseVal, MaybeSExtCmpLHS)) { +      // Set the return values. If the compare uses the negated value (-X >s 0), +      // swap the return values because the negated value is always 'RHS'. +      LHS = FalseVal; +      RHS = TrueVal; +      if (match(CmpLHS, m_Neg(m_Specific(TrueVal)))) +        std::swap(LHS, RHS); + +      // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X) +      // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X) +      if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes)) +        return {SPF_NABS, SPNB_NA, false}; + +      // (X <s 0) ? -X : X or (X <s 1) ? 
-X : X --> ABS(X)
+      // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
+      if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne))
+        return {SPF_ABS, SPNB_NA, false};
+    }
+  }
+
+  if (CmpInst::isIntPredicate(Pred))
+    return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);
+
+  // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar
+  // may return either -0.0 or 0.0, so the fcmp/select pair has stricter
+  // semantics than minNum. Be conservative in such a case.
+  if (NaNBehavior != SPNB_RETURNS_ANY ||
+      (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) &&
+       !isKnownNonZero(CmpRHS)))
+    return {SPF_UNKNOWN, SPNB_NA, false};
+
+  return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
+}
+
+/// Helps to match a select pattern in case of a type mismatch.
+///
+/// The function handles the case where the type of the true and false values
+/// of a select instruction differs from the type of the cmp instruction's
+/// operands because of a cast instruction. The function checks whether it is
+/// legal to move the cast operation after the "select". If so, it returns the
+/// new second value of the "select" (with the assumption that the cast has
+/// been moved):
+/// 1. As the operand of the cast instruction when both values of the "select"
+/// are the same cast instruction.
+/// 2. As the restored constant (by applying the reverse cast operation) when
+/// the first value of the "select" is a cast operation and the second value
+/// is a constant.
+/// NOTE: We return only the new second value because the first value can be
+/// accessed as the operand of the cast instruction.
+static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
+                              Instruction::CastOps *CastOp) {
+  auto *Cast1 = dyn_cast<CastInst>(V1);
+  if (!Cast1)
+    return nullptr;
+
+  *CastOp = Cast1->getOpcode();
+  Type *SrcTy = Cast1->getSrcTy();
+  if (auto *Cast2 = dyn_cast<CastInst>(V2)) {
+    // If V1 and V2 are both the same cast from the same type, look through V1.
+    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
+      return Cast2->getOperand(0);
+    return nullptr;
+  }
+
+  auto *C = dyn_cast<Constant>(V2);
+  if (!C)
+    return nullptr;
+
+  Constant *CastedTo = nullptr;
+  switch (*CastOp) {
+  case Instruction::ZExt:
+    if (CmpI->isUnsigned())
+      CastedTo = ConstantExpr::getTrunc(C, SrcTy);
+    break;
+  case Instruction::SExt:
+    if (CmpI->isSigned())
+      CastedTo = ConstantExpr::getTrunc(C, SrcTy, true);
+    break;
+  case Instruction::Trunc:
+    Constant *CmpConst;
+    if (match(CmpI->getOperand(1), m_Constant(CmpConst)) &&
+        CmpConst->getType() == SrcTy) {
+      // Here we have the following case:
+      //
+      //   %cond = cmp iN %x, CmpConst
+      //   %tr = trunc iN %x to iK
+      //   %narrowsel = select i1 %cond, iK %tr, iK C
+      //
+      // We can always move the trunc after the select operation:
+      //
+      //   %cond = cmp iN %x, CmpConst
+      //   %widesel = select i1 %cond, iN %x, iN CmpConst
+      //   %tr = trunc iN %widesel to iK
+      //
+      // Note that C could be extended in any way because we don't care about
+      // the upper bits after truncation. It can't be an abs pattern, because
+      // that would look like:
+      //
+      //   select i1 %cond, x, -x.
+      //
+      // So only a min/max pattern can be matched. Such a match requires the
+      // widened C to equal CmpConst. That is why we set the widened C to
+      // CmpConst here; the condition trunc(CmpConst) == C is checked below.
+      CastedTo = CmpConst; +    } else { +      CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned()); +    } +    break; +  case Instruction::FPTrunc: +    CastedTo = ConstantExpr::getFPExtend(C, SrcTy, true); +    break; +  case Instruction::FPExt: +    CastedTo = ConstantExpr::getFPTrunc(C, SrcTy, true); +    break; +  case Instruction::FPToUI: +    CastedTo = ConstantExpr::getUIToFP(C, SrcTy, true); +    break; +  case Instruction::FPToSI: +    CastedTo = ConstantExpr::getSIToFP(C, SrcTy, true); +    break; +  case Instruction::UIToFP: +    CastedTo = ConstantExpr::getFPToUI(C, SrcTy, true); +    break; +  case Instruction::SIToFP: +    CastedTo = ConstantExpr::getFPToSI(C, SrcTy, true); +    break; +  default: +    break; +  } + +  if (!CastedTo) +    return nullptr; + +  // Make sure the cast doesn't lose any information. +  Constant *CastedBack = +      ConstantExpr::getCast(*CastOp, CastedTo, C->getType(), true); +  if (CastedBack != C) +    return nullptr; + +  return CastedTo; +} + +SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, +                                             Instruction::CastOps *CastOp, +                                             unsigned Depth) { +  if (Depth >= MaxDepth) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  SelectInst *SI = dyn_cast<SelectInst>(V); +  if (!SI) return {SPF_UNKNOWN, SPNB_NA, false}; + +  CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition()); +  if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false}; + +  CmpInst::Predicate Pred = CmpI->getPredicate(); +  Value *CmpLHS = CmpI->getOperand(0); +  Value *CmpRHS = CmpI->getOperand(1); +  Value *TrueVal = SI->getTrueValue(); +  Value *FalseVal = SI->getFalseValue(); +  FastMathFlags FMF; +  if (isa<FPMathOperator>(CmpI)) +    FMF = CmpI->getFastMathFlags(); + +  // Bail out early. +  if (CmpI->isEquality()) +    return {SPF_UNKNOWN, SPNB_NA, false}; + +  // Deal with type mismatches. +  if (CastOp && CmpLHS->getType() != TrueVal->getType()) { +    if (Value *C = lookThroughCast(CmpI, TrueVal, FalseVal, CastOp)) { +      // If this is a potential fmin/fmax with a cast to integer, then ignore +      // -0.0 because there is no corresponding integer value. +      if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI) +        FMF.setNoSignedZeros(); +      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, +                                  cast<CastInst>(TrueVal)->getOperand(0), C, +                                  LHS, RHS, Depth); +    } +    if (Value *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp)) { +      // If this is a potential fmin/fmax with a cast to integer, then ignore +      // -0.0 because there is no corresponding integer value. +      if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI) +        FMF.setNoSignedZeros(); +      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, +                                  C, cast<CastInst>(FalseVal)->getOperand(0), +                                  LHS, RHS, Depth); +    } +  } +  return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal, +                              LHS, RHS, Depth); +} + +CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) { +  if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT; +  if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT; +  if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT; +  if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT; +  if (SPF == SPF_FMINNUM) +    return Ordered ? 
FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; +  if (SPF == SPF_FMAXNUM) +    return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; +  llvm_unreachable("unhandled!"); +} + +SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) { +  if (SPF == SPF_SMIN) return SPF_SMAX; +  if (SPF == SPF_UMIN) return SPF_UMAX; +  if (SPF == SPF_SMAX) return SPF_SMIN; +  if (SPF == SPF_UMAX) return SPF_UMIN; +  llvm_unreachable("unhandled!"); +} + +CmpInst::Predicate llvm::getInverseMinMaxPred(SelectPatternFlavor SPF) { +  return getMinMaxPred(getInverseMinMaxFlavor(SPF)); +} + +/// Return true if "icmp Pred LHS RHS" is always true. +static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, +                            const Value *RHS, const DataLayout &DL, +                            unsigned Depth) { +  assert(!LHS->getType()->isVectorTy() && "TODO: extend to handle vectors!"); +  if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS) +    return true; + +  switch (Pred) { +  default: +    return false; + +  case CmpInst::ICMP_SLE: { +    const APInt *C; + +    // LHS s<= LHS +_{nsw} C   if C >= 0 +    if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C)))) +      return !C->isNegative(); +    return false; +  } + +  case CmpInst::ICMP_ULE: { +    const APInt *C; + +    // LHS u<= LHS +_{nuw} C   for any C +    if (match(RHS, m_NUWAdd(m_Specific(LHS), m_APInt(C)))) +      return true; + +    // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB) +    auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B, +                                       const Value *&X, +                                       const APInt *&CA, const APInt *&CB) { +      if (match(A, m_NUWAdd(m_Value(X), m_APInt(CA))) && +          match(B, m_NUWAdd(m_Specific(X), m_APInt(CB)))) +        return true; + +      // If X & C == 0 then (X | C) == X +_{nuw} C +      if (match(A, m_Or(m_Value(X), m_APInt(CA))) && +          match(B, m_Or(m_Specific(X), m_APInt(CB)))) { +        KnownBits Known(CA->getBitWidth()); +        computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr, +                         /*CxtI*/ nullptr, /*DT*/ nullptr); +        if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero)) +          return true; +      } + +      return false; +    }; + +    const Value *X; +    const APInt *CLHS, *CRHS; +    if (MatchNUWAddsToSameValue(LHS, RHS, X, CLHS, CRHS)) +      return CLHS->ule(*CRHS); + +    return false; +  } +  } +} + +/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred +/// ALHS ARHS" is true.  Otherwise, return None. +static Optional<bool> +isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, +                      const Value *ARHS, const Value *BLHS, const Value *BRHS, +                      const DataLayout &DL, unsigned Depth) { +  switch (Pred) { +  default: +    return None; + +  case CmpInst::ICMP_SLT: +  case CmpInst::ICMP_SLE: +    if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) && +        isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth)) +      return true; +    return None; + +  case CmpInst::ICMP_ULT: +  case CmpInst::ICMP_ULE: +    if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) && +        isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth)) +      return true; +    return None; +  } +} + +/// Return true if the operands of the two compares match.  IsSwappedOps is true +/// when the operands match, but are swapped. 
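+/// For example, (icmp sgt %x, %y) and (icmp slt %y, %x) have matching but
+/// swapped operands.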
+static bool isMatchingOps(const Value *ALHS, const Value *ARHS, +                          const Value *BLHS, const Value *BRHS, +                          bool &IsSwappedOps) { + +  bool IsMatchingOps = (ALHS == BLHS && ARHS == BRHS); +  IsSwappedOps = (ALHS == BRHS && ARHS == BLHS); +  return IsMatchingOps || IsSwappedOps; +} + +/// Return true if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS BRHS" is +/// true.  Return false if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS +/// BRHS" is false.  Otherwise, return None if we can't infer anything. +static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, +                                                    const Value *ALHS, +                                                    const Value *ARHS, +                                                    CmpInst::Predicate BPred, +                                                    const Value *BLHS, +                                                    const Value *BRHS, +                                                    bool IsSwappedOps) { +  // Canonicalize the operands so they're matching. +  if (IsSwappedOps) { +    std::swap(BLHS, BRHS); +    BPred = ICmpInst::getSwappedPredicate(BPred); +  } +  if (CmpInst::isImpliedTrueByMatchingCmp(APred, BPred)) +    return true; +  if (CmpInst::isImpliedFalseByMatchingCmp(APred, BPred)) +    return false; + +  return None; +} + +/// Return true if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS C2" is +/// true.  Return false if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS +/// C2" is false.  Otherwise, return None if we can't infer anything. +static Optional<bool> +isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, const Value *ALHS, +                                 const ConstantInt *C1, +                                 CmpInst::Predicate BPred, +                                 const Value *BLHS, const ConstantInt *C2) { +  assert(ALHS == BLHS && "LHS operands must match."); +  ConstantRange DomCR = +      ConstantRange::makeExactICmpRegion(APred, C1->getValue()); +  ConstantRange CR = +      ConstantRange::makeAllowedICmpRegion(BPred, C2->getValue()); +  ConstantRange Intersection = DomCR.intersectWith(CR); +  ConstantRange Difference = DomCR.difference(CR); +  if (Intersection.isEmptySet()) +    return false; +  if (Difference.isEmptySet()) +    return true; +  return None; +} + +/// Return true if LHS implies RHS is true.  Return false if LHS implies RHS is +/// false.  Otherwise, return None if we can't infer anything. +static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, +                                         const ICmpInst *RHS, +                                         const DataLayout &DL, bool LHSIsTrue, +                                         unsigned Depth) { +  Value *ALHS = LHS->getOperand(0); +  Value *ARHS = LHS->getOperand(1); +  // The rest of the logic assumes the LHS condition is true.  If that's not the +  // case, invert the predicate to make it so. +  ICmpInst::Predicate APred = +      LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate(); + +  Value *BLHS = RHS->getOperand(0); +  Value *BRHS = RHS->getOperand(1); +  ICmpInst::Predicate BPred = RHS->getPredicate(); + +  // Can we infer anything when the two compares have matching operands? 
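+  // For example, a known-true (icmp ugt %a, %b) implies that (icmp uge %a, %b)
+  // and (icmp ne %a, %b) are true and that (icmp ult %a, %b) is false.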
+  bool IsSwappedOps;
+  if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, IsSwappedOps)) {
+    if (Optional<bool> Implication = isImpliedCondMatchingOperands(
+            APred, ALHS, ARHS, BPred, BLHS, BRHS, IsSwappedOps))
+      return Implication;
+    // No amount of additional analysis will infer the second condition, so
+    // early exit.
+    return None;
+  }
+
+  // Can we infer anything when the LHS operands match and the RHS operands are
+  // constants (not necessarily matching)?
+  if (ALHS == BLHS && isa<ConstantInt>(ARHS) && isa<ConstantInt>(BRHS)) {
+    if (Optional<bool> Implication = isImpliedCondMatchingImmOperands(
+            APred, ALHS, cast<ConstantInt>(ARHS), BPred, BLHS,
+            cast<ConstantInt>(BRHS)))
+      return Implication;
+    // No amount of additional analysis will infer the second condition, so
+    // early exit.
+    return None;
+  }
+
+  if (APred == BPred)
+    return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS, DL, Depth);
+  return None;
+}
+
+/// Return true if LHS implies RHS is true.  Return false if LHS implies RHS is
+/// false.  Otherwise, return None if we can't infer anything.  We expect the
+/// RHS to be an icmp and the LHS to be an 'and' or an 'or' instruction.
+static Optional<bool> isImpliedCondAndOr(const BinaryOperator *LHS,
+                                         const ICmpInst *RHS,
+                                         const DataLayout &DL, bool LHSIsTrue,
+                                         unsigned Depth) {
+  // The LHS must be an 'or' or an 'and' instruction.
+  assert((LHS->getOpcode() == Instruction::And ||
+          LHS->getOpcode() == Instruction::Or) &&
+         "Expected LHS to be 'and' or 'or'.");
+
+  assert(Depth <= MaxDepth && "Hit recursion limit");
+
+  // If the result of an 'or' is false, then we know both legs of the 'or' are
+  // false.  Similarly, if the result of an 'and' is true, then we know both
+  // legs of the 'and' are true.
+  Value *ALHS, *ARHS;
+  if ((!LHSIsTrue && match(LHS, m_Or(m_Value(ALHS), m_Value(ARHS)))) ||
+      (LHSIsTrue && match(LHS, m_And(m_Value(ALHS), m_Value(ARHS))))) {
+    // FIXME: Make this non-recursive.
+    if (Optional<bool> Implication =
+            isImpliedCondition(ALHS, RHS, DL, LHSIsTrue, Depth + 1))
+      return Implication;
+    if (Optional<bool> Implication =
+            isImpliedCondition(ARHS, RHS, DL, LHSIsTrue, Depth + 1))
+      return Implication;
+    return None;
+  }
+  return None;
+}
+
+Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
+                                        const DataLayout &DL, bool LHSIsTrue,
+                                        unsigned Depth) {
+  // Bail out when we hit the limit.
+  if (Depth == MaxDepth)
+    return None;
+
+  // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
+  // example.
+  if (LHS->getType() != RHS->getType())
+    return None;
+
+  Type *OpTy = LHS->getType();
+  assert(OpTy->isIntOrIntVectorTy(1) && "Expected integer type only!");
+
+  // LHS ==> RHS by definition
+  if (LHS == RHS)
+    return LHSIsTrue;
+
+  // FIXME: Extend the code below to handle vectors.
+  if (OpTy->isVectorTy())
+    return None;
+
+  assert(OpTy->isIntegerTy(1) && "implied by above");
+
+  // Both LHS and RHS are icmps.
+  const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS); +  const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS); +  if (LHSCmp && RHSCmp) +    return isImpliedCondICmps(LHSCmp, RHSCmp, DL, LHSIsTrue, Depth); + +  // The LHS should be an 'or' or an 'and' instruction.  We expect the RHS to be +  // an icmp. FIXME: Add support for and/or on the RHS. +  const BinaryOperator *LHSBO = dyn_cast<BinaryOperator>(LHS); +  if (LHSBO && RHSCmp) { +    if ((LHSBO->getOpcode() == Instruction::And || +         LHSBO->getOpcode() == Instruction::Or)) +      return isImpliedCondAndOr(LHSBO, RHSCmp, DL, LHSIsTrue, Depth); +  } +  return None; +} diff --git a/contrib/llvm/lib/Analysis/VectorUtils.cpp b/contrib/llvm/lib/Analysis/VectorUtils.cpp new file mode 100644 index 000000000000..d73d24736439 --- /dev/null +++ b/contrib/llvm/lib/Analysis/VectorUtils.cpp @@ -0,0 +1,577 @@ +//===----------- VectorUtils.cpp - Vectorizer utility functions -----------===// +// +//                     The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines vectorizer utilities. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Value.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +/// Identify if the intrinsic is trivially vectorizable. +/// This method returns true if the intrinsic's argument types are all +/// scalars for the scalar form of the intrinsic and all vectors for +/// the vector form of the intrinsic. +bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { +  switch (ID) { +  case Intrinsic::sqrt: +  case Intrinsic::sin: +  case Intrinsic::cos: +  case Intrinsic::exp: +  case Intrinsic::exp2: +  case Intrinsic::log: +  case Intrinsic::log10: +  case Intrinsic::log2: +  case Intrinsic::fabs: +  case Intrinsic::minnum: +  case Intrinsic::maxnum: +  case Intrinsic::copysign: +  case Intrinsic::floor: +  case Intrinsic::ceil: +  case Intrinsic::trunc: +  case Intrinsic::rint: +  case Intrinsic::nearbyint: +  case Intrinsic::round: +  case Intrinsic::bswap: +  case Intrinsic::bitreverse: +  case Intrinsic::ctpop: +  case Intrinsic::pow: +  case Intrinsic::fma: +  case Intrinsic::fmuladd: +  case Intrinsic::ctlz: +  case Intrinsic::cttz: +  case Intrinsic::powi: +    return true; +  default: +    return false; +  } +} + +/// Identifies if the intrinsic has a scalar operand. It check for +/// ctlz,cttz and powi special intrinsics whose argument is scalar. +bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID, +                                        unsigned ScalarOpdIdx) { +  switch (ID) { +  case Intrinsic::ctlz: +  case Intrinsic::cttz: +  case Intrinsic::powi: +    return (ScalarOpdIdx == 1); +  default: +    return false; +  } +} + +/// Returns intrinsic ID for call. 
+/// For the input call instruction it finds the corresponding intrinsic and
+/// returns its ID; if none is found, it returns not_intrinsic.
+Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI,
+                                                const TargetLibraryInfo *TLI) {
+  Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI);
+  if (ID == Intrinsic::not_intrinsic)
+    return Intrinsic::not_intrinsic;
+
+  if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start ||
+      ID == Intrinsic::lifetime_end || ID == Intrinsic::assume ||
+      ID == Intrinsic::sideeffect)
+    return ID;
+  return Intrinsic::not_intrinsic;
+}
+
+/// Find the operand of the GEP that should be checked for consecutive
+/// stores. This ignores trailing indices that have no effect on the final
+/// pointer.
+unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) {
+  const DataLayout &DL = Gep->getModule()->getDataLayout();
+  unsigned LastOperand = Gep->getNumOperands() - 1;
+  unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType());
+
+  // Walk backwards and try to peel off zeros.
+  while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) {
+    // Find the type we're currently indexing into.
+    gep_type_iterator GEPTI = gep_type_begin(Gep);
+    std::advance(GEPTI, LastOperand - 2);
+
+    // If it's a type with the same allocation size as the result of the GEP we
+    // can peel off the zero index.
+    if (DL.getTypeAllocSize(GEPTI.getIndexedType()) != GEPAllocSize)
+      break;
+    --LastOperand;
+  }
+
+  return LastOperand;
+}
+
+/// If the argument is a GEP, then returns the operand identified by
+/// getGEPInductionOperand. However, if there is some other non-loop-invariant
+/// operand, it returns that instead.
+Value *llvm::stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return Ptr;
+
+  unsigned InductionOperand = getGEPInductionOperand(GEP);
+
+  // Check that all of the gep indices are uniform except for our induction
+  // operand.
+  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
+    if (i != InductionOperand &&
+        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
+      return Ptr;
+  return GEP->getOperand(InductionOperand);
+}
+
+/// If a value has only one user that is a CastInst, return it.
+Value *llvm::getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
+  Value *UniqueCast = nullptr;
+  for (User *U : Ptr->users()) {
+    CastInst *CI = dyn_cast<CastInst>(U);
+    if (CI && CI->getType() == Ty) {
+      if (!UniqueCast)
+        UniqueCast = CI;
+      else
+        return nullptr;
+    }
+  }
+  return UniqueCast;
+}
+
+/// Get the stride of a pointer access in a loop. Looks for symbolic
+/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
+Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) {
+  auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
+  if (!PtrTy || PtrTy->isAggregateType())
+    return nullptr;
+
+  // Try to remove a gep instruction to make the pointer (actually the index at
+  // this point) easier to analyze. If OrigPtr is equal to Ptr we are analyzing
+  // the pointer; otherwise, we are analyzing the index.
+  Value *OrigPtr = Ptr;
+
+  // The size of the pointer access.
+  int64_t PtrAccessSize = 1; + +  Ptr = stripGetElementPtr(Ptr, SE, Lp); +  const SCEV *V = SE->getSCEV(Ptr); + +  if (Ptr != OrigPtr) +    // Strip off casts. +    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) +      V = C->getOperand(); + +  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V); +  if (!S) +    return nullptr; + +  V = S->getStepRecurrence(*SE); +  if (!V) +    return nullptr; + +  // Strip off the size of access multiplication if we are still analyzing the +  // pointer. +  if (OrigPtr == Ptr) { +    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { +      if (M->getOperand(0)->getSCEVType() != scConstant) +        return nullptr; + +      const APInt &APStepVal = cast<SCEVConstant>(M->getOperand(0))->getAPInt(); + +      // Huge step value - give up. +      if (APStepVal.getBitWidth() > 64) +        return nullptr; + +      int64_t StepVal = APStepVal.getSExtValue(); +      if (PtrAccessSize != StepVal) +        return nullptr; +      V = M->getOperand(1); +    } +  } + +  // Strip off casts. +  Type *StripedOffRecurrenceCast = nullptr; +  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) { +    StripedOffRecurrenceCast = C->getType(); +    V = C->getOperand(); +  } + +  // Look for the loop invariant symbolic value. +  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V); +  if (!U) +    return nullptr; + +  Value *Stride = U->getValue(); +  if (!Lp->isLoopInvariant(Stride)) +    return nullptr; + +  // If we have stripped off the recurrence cast we have to make sure that we +  // return the value that is used in this loop so that we can replace it later. +  if (StripedOffRecurrenceCast) +    Stride = getUniqueCastUse(Stride, Lp, StripedOffRecurrenceCast); + +  return Stride; +} + +/// Given a vector and an element number, see if the scalar value is +/// already around as a register, for example if it were inserted then extracted +/// from the vector. +Value *llvm::findScalarElement(Value *V, unsigned EltNo) { +  assert(V->getType()->isVectorTy() && "Not looking at a vector?"); +  VectorType *VTy = cast<VectorType>(V->getType()); +  unsigned Width = VTy->getNumElements(); +  if (EltNo >= Width)  // Out of range access. +    return UndefValue::get(VTy->getElementType()); + +  if (Constant *C = dyn_cast<Constant>(V)) +    return C->getAggregateElement(EltNo); + +  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) { +    // If this is an insert to a variable element, we don't know what it is. +    if (!isa<ConstantInt>(III->getOperand(2))) +      return nullptr; +    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue(); + +    // If this is an insert to the element we are looking for, return the +    // inserted value. +    if (EltNo == IIElt) +      return III->getOperand(1); + +    // Otherwise, the insertelement doesn't modify the value, recurse on its +    // vector input. +    return findScalarElement(III->getOperand(0), EltNo); +  } + +  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) { +    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements(); +    int InEl = SVI->getMaskValue(EltNo); +    if (InEl < 0) +      return UndefValue::get(VTy->getElementType()); +    if (InEl < (int)LHSWidth) +      return findScalarElement(SVI->getOperand(0), InEl); +    return findScalarElement(SVI->getOperand(1), InEl - LHSWidth); +  } + +  // Extract a value from a vector add operation with a constant zero. 
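+  // For example, element 0 of (add %v, <i32 0, i32 7, i32 0, i32 0>) is just
+  // element 0 of %v, because adding the zero constant in that lane does not
+  // change the value.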
+  Value *Val = nullptr; Constant *Con = nullptr; +  if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) +    if (Constant *Elt = Con->getAggregateElement(EltNo)) +      if (Elt->isNullValue()) +        return findScalarElement(Val, EltNo); + +  // Otherwise, we don't know. +  return nullptr; +} + +/// Get splat value if the input is a splat vector or return nullptr. +/// This function is not fully general. It checks only 2 cases: +/// the input value is (1) a splat constants vector or (2) a sequence +/// of instructions that broadcast a single value into a vector. +/// +const llvm::Value *llvm::getSplatValue(const Value *V) { + +  if (auto *C = dyn_cast<Constant>(V)) +    if (isa<VectorType>(V->getType())) +      return C->getSplatValue(); + +  auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V); +  if (!ShuffleInst) +    return nullptr; +  // All-zero (or undef) shuffle mask elements. +  for (int MaskElt : ShuffleInst->getShuffleMask()) +    if (MaskElt != 0 && MaskElt != -1) +      return nullptr; +  // The first shuffle source is 'insertelement' with index 0. +  auto *InsertEltInst = +    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0)); +  if (!InsertEltInst || !isa<ConstantInt>(InsertEltInst->getOperand(2)) || +      !cast<ConstantInt>(InsertEltInst->getOperand(2))->isZero()) +    return nullptr; + +  return InsertEltInst->getOperand(1); +} + +MapVector<Instruction *, uint64_t> +llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, +                               const TargetTransformInfo *TTI) { + +  // DemandedBits will give us every value's live-out bits. But we want +  // to ensure no extra casts would need to be inserted, so every DAG +  // of connected values must have the same minimum bitwidth. +  EquivalenceClasses<Value *> ECs; +  SmallVector<Value *, 16> Worklist; +  SmallPtrSet<Value *, 4> Roots; +  SmallPtrSet<Value *, 16> Visited; +  DenseMap<Value *, uint64_t> DBits; +  SmallPtrSet<Instruction *, 4> InstructionSet; +  MapVector<Instruction *, uint64_t> MinBWs; + +  // Determine the roots. We work bottom-up, from truncs or icmps. +  bool SeenExtFromIllegalType = false; +  for (auto *BB : Blocks) +    for (auto &I : *BB) { +      InstructionSet.insert(&I); + +      if (TTI && (isa<ZExtInst>(&I) || isa<SExtInst>(&I)) && +          !TTI->isTypeLegal(I.getOperand(0)->getType())) +        SeenExtFromIllegalType = true; + +      // Only deal with non-vector integers up to 64-bits wide. +      if ((isa<TruncInst>(&I) || isa<ICmpInst>(&I)) && +          !I.getType()->isVectorTy() && +          I.getOperand(0)->getType()->getScalarSizeInBits() <= 64) { +        // Don't make work for ourselves. If we know the loaded type is legal, +        // don't add it to the worklist. +        if (TTI && isa<TruncInst>(&I) && TTI->isTypeLegal(I.getType())) +          continue; + +        Worklist.push_back(&I); +        Roots.insert(&I); +      } +    } +  // Early exit. +  if (Worklist.empty() || (TTI && !SeenExtFromIllegalType)) +    return MinBWs; + +  // Now proceed breadth-first, unioning values together. +  while (!Worklist.empty()) { +    Value *Val = Worklist.pop_back_val(); +    Value *Leader = ECs.getOrInsertLeaderValue(Val); + +    if (Visited.count(Val)) +      continue; +    Visited.insert(Val); + +    // Non-instructions terminate a chain successfully. 
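+    // (For example, function arguments and constants: they join the
+    // equivalence class but contribute no demanded bits, unlike the unsafe
+    // casts handled below.)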
+    if (!isa<Instruction>(Val)) +      continue; +    Instruction *I = cast<Instruction>(Val); + +    // If we encounter a type that is larger than 64 bits, we can't represent +    // it so bail out. +    if (DB.getDemandedBits(I).getBitWidth() > 64) +      return MapVector<Instruction *, uint64_t>(); + +    uint64_t V = DB.getDemandedBits(I).getZExtValue(); +    DBits[Leader] |= V; +    DBits[I] = V; + +    // Casts, loads and instructions outside of our range terminate a chain +    // successfully. +    if (isa<SExtInst>(I) || isa<ZExtInst>(I) || isa<LoadInst>(I) || +        !InstructionSet.count(I)) +      continue; + +    // Unsafe casts terminate a chain unsuccessfully. We can't do anything +    // useful with bitcasts, ptrtoints or inttoptrs and it'd be unsafe to +    // transform anything that relies on them. +    if (isa<BitCastInst>(I) || isa<PtrToIntInst>(I) || isa<IntToPtrInst>(I) || +        !I->getType()->isIntegerTy()) { +      DBits[Leader] |= ~0ULL; +      continue; +    } + +    // We don't modify the types of PHIs. Reductions will already have been +    // truncated if possible, and inductions' sizes will have been chosen by +    // indvars. +    if (isa<PHINode>(I)) +      continue; + +    if (DBits[Leader] == ~0ULL) +      // All bits demanded, no point continuing. +      continue; + +    for (Value *O : cast<User>(I)->operands()) { +      ECs.unionSets(Leader, O); +      Worklist.push_back(O); +    } +  } + +  // Now we've discovered all values, walk them to see if there are +  // any users we didn't see. If there are, we can't optimize that +  // chain. +  for (auto &I : DBits) +    for (auto *U : I.first->users()) +      if (U->getType()->isIntegerTy() && DBits.count(U) == 0) +        DBits[ECs.getOrInsertLeaderValue(I.first)] |= ~0ULL; + +  for (auto I = ECs.begin(), E = ECs.end(); I != E; ++I) { +    uint64_t LeaderDemandedBits = 0; +    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) +      LeaderDemandedBits |= DBits[*MI]; + +    uint64_t MinBW = (sizeof(LeaderDemandedBits) * 8) - +                     llvm::countLeadingZeros(LeaderDemandedBits); +    // Round up to a power of 2 +    if (!isPowerOf2_64((uint64_t)MinBW)) +      MinBW = NextPowerOf2(MinBW); + +    // We don't modify the types of PHIs. Reductions will already have been +    // truncated if possible, and inductions' sizes will have been chosen by +    // indvars. +    // If we are required to shrink a PHI, abandon this entire equivalence class. +    bool Abort = false; +    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) +      if (isa<PHINode>(*MI) && MinBW < (*MI)->getType()->getScalarSizeInBits()) { +        Abort = true; +        break; +      } +    if (Abort) +      continue; + +    for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { +      if (!isa<Instruction>(*MI)) +        continue; +      Type *Ty = (*MI)->getType(); +      if (Roots.count(*MI)) +        Ty = cast<Instruction>(*MI)->getOperand(0)->getType(); +      if (MinBW < Ty->getScalarSizeInBits()) +        MinBWs[cast<Instruction>(*MI)] = MinBW; +    } +  } + +  return MinBWs; +} + +/// \returns \p I after propagating metadata from \p VL. 
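+/// The metadata kinds handled here (tbaa, alias.scope, noalias, fpmath,
+/// nontemporal, invariant.load) are combined conservatively across all
+/// instructions in \p VL: TBAA, alias-scope and fpmath nodes are merged to
+/// their most generic form, the remaining kinds are intersected, and any kind
+/// missing from some instruction in \p VL is dropped from the result.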
+Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { +  Instruction *I0 = cast<Instruction>(VL[0]); +  SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; +  I0->getAllMetadataOtherThanDebugLoc(Metadata); + +  for (auto Kind : +       {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, +        LLVMContext::MD_noalias, LLVMContext::MD_fpmath, +        LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load}) { +    MDNode *MD = I0->getMetadata(Kind); + +    for (int J = 1, E = VL.size(); MD && J != E; ++J) { +      const Instruction *IJ = cast<Instruction>(VL[J]); +      MDNode *IMD = IJ->getMetadata(Kind); +      switch (Kind) { +      case LLVMContext::MD_tbaa: +        MD = MDNode::getMostGenericTBAA(MD, IMD); +        break; +      case LLVMContext::MD_alias_scope: +        MD = MDNode::getMostGenericAliasScope(MD, IMD); +        break; +      case LLVMContext::MD_fpmath: +        MD = MDNode::getMostGenericFPMath(MD, IMD); +        break; +      case LLVMContext::MD_noalias: +      case LLVMContext::MD_nontemporal: +      case LLVMContext::MD_invariant_load: +        MD = MDNode::intersect(MD, IMD); +        break; +      default: +        llvm_unreachable("unhandled metadata"); +      } +    } + +    Inst->setMetadata(Kind, MD); +  } + +  return Inst; +} + +Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF, +                                     unsigned NumVecs) { +  SmallVector<Constant *, 16> Mask; +  for (unsigned i = 0; i < VF; i++) +    for (unsigned j = 0; j < NumVecs; j++) +      Mask.push_back(Builder.getInt32(j * VF + i)); + +  return ConstantVector::get(Mask); +} + +Constant *llvm::createStrideMask(IRBuilder<> &Builder, unsigned Start, +                                 unsigned Stride, unsigned VF) { +  SmallVector<Constant *, 16> Mask; +  for (unsigned i = 0; i < VF; i++) +    Mask.push_back(Builder.getInt32(Start + i * Stride)); + +  return ConstantVector::get(Mask); +} + +Constant *llvm::createSequentialMask(IRBuilder<> &Builder, unsigned Start, +                                     unsigned NumInts, unsigned NumUndefs) { +  SmallVector<Constant *, 16> Mask; +  for (unsigned i = 0; i < NumInts; i++) +    Mask.push_back(Builder.getInt32(Start + i)); + +  Constant *Undef = UndefValue::get(Builder.getInt32Ty()); +  for (unsigned i = 0; i < NumUndefs; i++) +    Mask.push_back(Undef); + +  return ConstantVector::get(Mask); +} + +/// A helper function for concatenating vectors. This function concatenates two +/// vectors having the same element type. If the second vector has fewer +/// elements than the first, it is padded with undefs. +static Value *concatenateTwoVectors(IRBuilder<> &Builder, Value *V1, +                                    Value *V2) { +  VectorType *VecTy1 = dyn_cast<VectorType>(V1->getType()); +  VectorType *VecTy2 = dyn_cast<VectorType>(V2->getType()); +  assert(VecTy1 && VecTy2 && +         VecTy1->getScalarType() == VecTy2->getScalarType() && +         "Expect two vectors with the same element type"); + +  unsigned NumElts1 = VecTy1->getNumElements(); +  unsigned NumElts2 = VecTy2->getNumElements(); +  assert(NumElts1 >= NumElts2 && "Unexpect the first vector has less elements"); + +  if (NumElts1 > NumElts2) { +    // Extend with UNDEFs. 
+    Constant *ExtMask =
+        createSequentialMask(Builder, 0, NumElts2, NumElts1 - NumElts2);
+    V2 = Builder.CreateShuffleVector(V2, UndefValue::get(VecTy2), ExtMask);
+  }
+
+  Constant *Mask = createSequentialMask(Builder, 0, NumElts1 + NumElts2, 0);
+  return Builder.CreateShuffleVector(V1, V2, Mask);
+}
+
+Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) {
+  unsigned NumVecs = Vecs.size();
+  assert(NumVecs > 1 && "Should be at least two vectors");
+
+  SmallVector<Value *, 8> ResList;
+  ResList.append(Vecs.begin(), Vecs.end());
+  do {
+    SmallVector<Value *, 8> TmpList;
+    for (unsigned i = 0; i < NumVecs - 1; i += 2) {
+      Value *V0 = ResList[i], *V1 = ResList[i + 1];
+      assert((V0->getType() == V1->getType() || i == NumVecs - 2) &&
+             "Only the last vector may have a different type");
+
+      TmpList.push_back(concatenateTwoVectors(Builder, V0, V1));
+    }
+
+    // Push the last vector if the total number of vectors is odd.
+    if (NumVecs % 2 != 0)
+      TmpList.push_back(ResList[NumVecs - 1]);
+
+    ResList = TmpList;
+    NumVecs = ResList.size();
+  } while (NumVecs > 1);
+
+  return ResList[0];
+}
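A short usage sketch for the mask and concatenation helpers in VectorUtils.cpp
above, assuming an IRBuilder already positioned inside a function; the wrapper
name buildInterleaved and the <4 x i32> operand shape are illustrative
assumptions, not part of the patch:

  #include "llvm/Analysis/VectorUtils.h"
  #include "llvm/IR/Constants.h"
  #include "llvm/IR/IRBuilder.h"

  using namespace llvm;

  // Interleave two <4 x i32> values A and B into a single <8 x i32> vector
  // <a0, b0, a1, b1, a2, b2, a3, b3>.
  static Value *buildInterleaved(IRBuilder<> &Builder, Value *A, Value *B) {
    // concatenateVectors joins the inputs into one wide <8 x i32> vector:
    // <a0, a1, a2, a3, b0, b1, b2, b3>.
    Value *Wide = concatenateVectors(Builder, {A, B});
    // createInterleaveMask(Builder, /*VF=*/4, /*NumVecs=*/2) produces
    // <0, 4, 1, 5, 2, 6, 3, 7>, i.e. it picks lanes alternately from A and B.
    Constant *Mask = createInterleaveMask(Builder, /*VF=*/4, /*NumVecs=*/2);
    return Builder.CreateShuffleVector(Wide, UndefValue::get(Wide->getType()),
                                       Mask);
  }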
