Diffstat (limited to 'llvm/lib/Transforms/Scalar')
61 files changed, 6579 insertions, 4020 deletions
diff --git a/llvm/lib/Transforms/Scalar/ADCE.cpp b/llvm/lib/Transforms/Scalar/ADCE.cpp
index cc3d3bf7cdbf..c3709b9afffb 100644
--- a/llvm/lib/Transforms/Scalar/ADCE.cpp
+++ b/llvm/lib/Transforms/Scalar/ADCE.cpp
@@ -182,7 +182,7 @@ class AggressiveDeadCodeElimination {
 
   /// Identify connected sections of the control flow graph which have
   /// dead terminators and rewrite the control flow graph to remove them.
-  void updateDeadRegions();
+  bool updateDeadRegions();
 
   /// Set the BlockInfo::PostOrder field based on a post-order
   /// numbering of the reverse control flow graph.
@@ -505,7 +505,7 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() {
 //===----------------------------------------------------------------------===//
 bool AggressiveDeadCodeElimination::removeDeadInstructions() {
   // Updates control and dataflow around dead blocks
-  updateDeadRegions();
+  bool RegionsUpdated = updateDeadRegions();
 
   LLVM_DEBUG({
     for (Instruction &I : instructions(F)) {
@@ -556,11 +556,11 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() {
     I->eraseFromParent();
   }
 
-  return !Worklist.empty();
+  return !Worklist.empty() || RegionsUpdated;
 }
 
 // A dead region is the set of dead blocks with a common live post-dominator.
-void AggressiveDeadCodeElimination::updateDeadRegions() {
+bool AggressiveDeadCodeElimination::updateDeadRegions() {
   LLVM_DEBUG({
     dbgs() << "final dead terminator blocks: " << '\n';
     for (auto *BB : BlocksWithDeadTerminators)
@@ -570,6 +570,7 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {
 
   // Don't compute the post ordering unless we needed it.
   bool HavePostOrder = false;
+  bool Changed = false;
 
   for (auto *BB : BlocksWithDeadTerminators) {
     auto &Info = BlockInfo[BB];
@@ -624,7 +625,10 @@ void AggressiveDeadCodeElimination::updateDeadRegions() {
         .applyUpdates(DeletedEdges);
 
     NumBranchesRemoved += 1;
+    Changed = true;
   }
+
+  return Changed;
 }
 
 // reverse top-sort order
@@ -685,10 +689,14 @@ PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) {
     return PreservedAnalyses::all();
 
   PreservedAnalyses PA;
-  PA.preserveSet<CFGAnalyses>();
+  // TODO: We could track if we have actually done CFG changes.
+  if (!RemoveControlFlowFlag)
+    PA.preserveSet<CFGAnalyses>();
+  else {
+    PA.preserve<DominatorTreeAnalysis>();
+    PA.preserve<PostDominatorTreeAnalysis>();
+  }
   PA.preserve<GlobalsAA>();
-  PA.preserve<DominatorTreeAnalysis>();
-  PA.preserve<PostDominatorTreeAnalysis>();
   return PA;
 }
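A note on the ADCE change above, not part of the patch: updateDeadRegions() now reports whether it rewrote any branches, and the pass only claims full CFG preservation when RemoveControlFlowFlag is unset. The same conditional-preservation idiom, reduced to a minimal sketch (ExamplePass and transform() are hypothetical names):

```cpp
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/PassManager.h"

using namespace llvm;

// Hypothetical pass, shown only to isolate the preservation idiom.
struct ExamplePass : PassInfoMixin<ExamplePass> {
  bool transform(Function &F, bool &ChangedCFG); // assumed to exist elsewhere

  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM) {
    bool ChangedCFG = false;
    if (!transform(F, ChangedCFG))
      return PreservedAnalyses::all();

    PreservedAnalyses PA;
    if (!ChangedCFG) {
      // No edges were added or removed, so anything that depends only on
      // the shape of the CFG remains valid.
      PA.preserveSet<CFGAnalyses>();
    } else {
      // The CFG did change, but (as ADCE does via DomTreeUpdater) the pass
      // kept these two trees up to date itself, so they can still be kept.
      PA.preserve<DominatorTreeAnalysis>();
      PA.preserve<PostDominatorTreeAnalysis>();
    }
    return PA;
  }
};
```

Unconditionally claiming CFGAnalyses after deleting branches would let downstream passes consume stale dominator information, which is the hazard this hunk removes.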
diff --git a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
index 06deaf3c4f9a..bccf94fc217f 100644
--- a/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
+++ b/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -15,6 +15,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/IR/Instructions.h"
 #include "llvm/InitializePasses.h"
 #define AA_NAME "alignment-from-assumptions"
 #define DEBUG_TYPE AA_NAME
@@ -30,6 +31,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/Debug.h"
@@ -90,9 +92,9 @@ FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
 // to a constant. Using SCEV to compute alignment handles the case where
 // DiffSCEV is a recurrence with constant start such that the aligned offset
 // is constant. e.g. {16,+,32} % 32 -> 16.
-static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
-                                    const SCEV *AlignSCEV,
-                                    ScalarEvolution *SE) {
+static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
+                                      const SCEV *AlignSCEV,
+                                      ScalarEvolution *SE) {
   // DiffUnits = Diff % int64_t(Alignment)
   const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
@@ -107,26 +109,30 @@ static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV,
     // displaced pointer has the same alignment as the aligned pointer, so
     // return the alignment value.
     if (!DiffUnits)
-      return (unsigned)
-        cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue();
+      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();
 
     // If the displacement is not an exact multiple, but the remainder is a
     // constant, then return this remainder (but only if it is a power of 2).
     uint64_t DiffUnitsAbs = std::abs(DiffUnits);
     if (isPowerOf2_64(DiffUnitsAbs))
-      return (unsigned) DiffUnitsAbs;
+      return Align(DiffUnitsAbs);
   }
 
-  return 0;
+  return None;
 }
 
 // There is an address given by an offset OffSCEV from AASCEV which has an
 // alignment AlignSCEV. Use that information, if possible, to compute a new
 // alignment for Ptr.
-static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
-                                const SCEV *OffSCEV, Value *Ptr,
-                                ScalarEvolution *SE) {
+static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
+                             const SCEV *OffSCEV, Value *Ptr,
+                             ScalarEvolution *SE) {
   const SCEV *PtrSCEV = SE->getSCEV(Ptr);
+  // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
+  // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
+  // may disagree. Trunc/extend so they agree.
+  PtrSCEV = SE->getTruncateOrZeroExtend(
+      PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
   const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
 
   // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
@@ -141,13 +147,12 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                     << *AlignSCEV << " and offset " << *OffSCEV
                     << " using diff " << *DiffSCEV << "\n");
 
-  unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE);
-  LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n");
+  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
+    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
+    return *NewAlignment;
+  }
 
-  if (NewAlignment) {
-    return NewAlignment;
-  } else if (const SCEVAddRecExpr *DiffARSCEV =
-             dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
+  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
     // The relative offset to the alignment assumption did not yield a constant,
     // but we should try harder: if we assume that a is 32-byte aligned, then in
     // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
@@ -165,134 +170,67 @@ static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
     // first iteration, and also the alignment using the per-iteration delta.
     // If these are the same, then use that answer. Otherwise, use the smaller
     // one, but only if it divides the larger one.
-    NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
-    unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
-
-    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n");
-    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n");
-
-    if (!NewAlignment || !NewIncAlignment) {
-      return 0;
-    } else if (NewAlignment > NewIncAlignment) {
-      if (NewAlignment % NewIncAlignment == 0) {
-        LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment
-                          << "\n");
-        return NewIncAlignment;
-      }
-    } else if (NewIncAlignment > NewAlignment) {
-      if (NewIncAlignment % NewAlignment == 0) {
-        LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
-                          << "\n");
-        return NewAlignment;
-      }
-    } else if (NewIncAlignment == NewAlignment) {
-      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment
+    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
+    MaybeAlign NewIncAlignment =
+        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
+
+    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
+                      << "\n");
+    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
+                      << "\n");
+
+    if (!NewAlignment || !NewIncAlignment)
+      return Align(1);
+
+    const Align NewAlign = *NewAlignment;
+    const Align NewIncAlign = *NewIncAlignment;
+    if (NewAlign > NewIncAlign) {
+      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
+                        << DebugStr(NewIncAlign) << "\n");
+      return NewIncAlign;
+    }
+    if (NewIncAlign > NewAlign) {
+      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                         << "\n");
-      return NewAlignment;
+      return NewAlign;
     }
+    assert(NewIncAlign == NewAlign);
+    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
+                      << "\n");
+    return NewAlign;
   }
 
-  return 0;
+  return Align(1);
 }
 
 bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
+                                                        unsigned Idx,
                                                         Value *&AAPtr,
                                                         const SCEV *&AlignSCEV,
                                                         const SCEV *&OffSCEV) {
-  // An alignment assume must be a statement about the least-significant
-  // bits of the pointer being zero, possibly with some offset.
-  ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
-  if (!ICI)
-    return false;
-
-  // This must be an expression of the form: x & m == 0.
-  if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
-    return false;
-
-  // Swap things around so that the RHS is 0.
-  Value *CmpLHS = ICI->getOperand(0);
-  Value *CmpRHS = ICI->getOperand(1);
-  const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
-  const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
-  if (CmpLHSSCEV->isZero())
-    std::swap(CmpLHS, CmpRHS);
-  else if (!CmpRHSSCEV->isZero())
+  Type *Int64Ty = Type::getInt64Ty(I->getContext());
+  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
+  if (AlignOB.getTagName() != "align")
     return false;
-
-  BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
-  if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
-    return false;
-
-  // Swap things around so that the right operand of the and is a constant
-  // (the mask); we cannot deal with variable masks.
-  Value *AndLHS = CmpBO->getOperand(0);
-  Value *AndRHS = CmpBO->getOperand(1);
-  const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
-  const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
-  if (isa<SCEVConstant>(AndLHSSCEV)) {
-    std::swap(AndLHS, AndRHS);
-    std::swap(AndLHSSCEV, AndRHSSCEV);
-  }
-
-  const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
-  if (!MaskSCEV)
-    return false;
-
-  // The mask must have some trailing ones (otherwise the condition is
-  // trivial and tells us nothing about the alignment of the left operand).
-  unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
-  if (!TrailingOnes)
-    return false;
-
-  // Cap the alignment at the maximum with which LLVM can deal (and make sure
-  // we don't overflow the shift).
-  uint64_t Alignment;
-  TrailingOnes = std::min(TrailingOnes,
-    unsigned(sizeof(unsigned) * CHAR_BIT - 1));
-  Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
-
-  Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
-  AlignSCEV = SE->getConstant(Int64Ty, Alignment);
-
-  // The LHS might be a ptrtoint instruction, or it might be the pointer
-  // with an offset.
-  AAPtr = nullptr;
-  OffSCEV = nullptr;
-  if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
-    AAPtr = PToI->getPointerOperand();
+  assert(AlignOB.Inputs.size() >= 2);
+  AAPtr = AlignOB.Inputs[0].get();
+  // TODO: Consider accumulating the offset to the base.
+  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
+  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
+  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
+  if (AlignOB.Inputs.size() == 3)
+    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
+  else
     OffSCEV = SE->getZero(Int64Ty);
-  } else if (const SCEVAddExpr* AndLHSAddSCEV =
-             dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
-    // Try to find the ptrtoint; subtract it and the rest is the offset.
-    for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
-         JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
-      if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
-        if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
-          AAPtr = PToI->getPointerOperand();
-          OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
-          break;
-        }
-  }
-
-  if (!AAPtr)
-    return false;
-
-  // Sign extend the offset to 64 bits (so that it is like all of the other
-  // expressions).
-  unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
-  if (OffSCEVBits < 64)
-    OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
-  else if (OffSCEVBits > 64)
-    return false;
-
-  AAPtr = AAPtr->stripPointerCasts();
+  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
   return true;
 }
 
-bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
+bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
+                                                     unsigned Idx) {
   Value *AAPtr;
   const SCEV *AlignSCEV, *OffSCEV;
-  if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
+  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
     return false;
 
   // Skip ConstantPointerNull and UndefValue.  Assumptions on these shouldn't
@@ -310,35 +248,38 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
       continue;
 
     if (Instruction *K = dyn_cast<Instruction>(J))
-      if (isValidAssumeForContext(ACall, K, DT))
         WorkList.push_back(K);
   }
 
   while (!WorkList.empty()) {
     Instruction *J = WorkList.pop_back_val();
-
     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
-      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        LI->getPointerOperand(), SE);
-
-      if (NewAlignment > LI->getAlignment()) {
-        LI->setAlignment(MaybeAlign(NewAlignment));
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
+      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
+                                           LI->getPointerOperand(), SE);
+      if (NewAlignment > LI->getAlign()) {
+        LI->setAlignment(NewAlignment);
         ++NumLoadAlignChanged;
       }
     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
-      unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        SI->getPointerOperand(), SE);
-
-      if (NewAlignment > SI->getAlignment()) {
-        SI->setAlignment(MaybeAlign(NewAlignment));
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
+      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
+                                           SI->getPointerOperand(), SE);
+      if (NewAlignment > SI->getAlign()) {
+        SI->setAlignment(NewAlignment);
         ++NumStoreAlignChanged;
       }
     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
-      unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-        MI->getDest(), SE);
-
-      LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";);
-      if (NewDestAlignment > MI->getDestAlignment()) {
+      if (!isValidAssumeForContext(ACall, J, DT))
+        continue;
+      Align NewDestAlignment =
+          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
+
+      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
+                        << "\n";);
+      if (NewDestAlignment > *MI->getDestAlign()) {
         MI->setDestAlignment(NewDestAlignment);
         ++NumMemIntAlignChanged;
       }
@@ -346,12 +287,13 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
       // For memory transfers, there is also a source alignment that
       // can be set.
       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
-        unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
-          MTI->getSource(), SE);
+        Align NewSrcAlignment =
+            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
 
-        LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";);
+        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
+                          << "\n";);
 
-        if (NewSrcAlignment > MTI->getSourceAlignment()) {
+        if (NewSrcAlignment > *MTI->getSourceAlign()) {
           MTI->setSourceAlignment(NewSrcAlignment);
           ++NumMemIntAlignChanged;
         }
@@ -363,7 +305,7 @@ bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
     Visited.insert(J);
     for (User *UJ : J->users()) {
       Instruction *K = cast<Instruction>(UJ);
-      if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
+      if (!Visited.count(K))
        WorkList.push_back(K);
     }
   }
@@ -390,8 +332,11 @@ bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
   bool Changed = false;
   for (auto &AssumeVH : AC.assumptions())
-    if (AssumeVH)
-      Changed |= processAssumption(cast<CallInst>(AssumeVH));
+    if (AssumeVH) {
+      CallInst *Call = cast<CallInst>(AssumeVH);
+      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
+        Changed |= processAssumption(Call, Idx);
+    }
 
   return Changed;
 }
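A note on the AlignmentFromAssumptions rewrite above, not part of the patch: extractAlignmentInfo no longer pattern-matches `icmp eq (and (ptrtoint P), Mask), 0` assume conditions; it reads dedicated "align" operand bundles on llvm.assume, whose layout is pointer, alignment, optional offset. A sketch of inspecting such a bundle with the same OperandBundleUse API (the helper name is invented):

```cpp
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Hypothetical helper: decode the Idx-th operand bundle of an llvm.assume
// call if it is an alignment assumption. Bundle layout per the patch:
// Inputs[0] = pointer, Inputs[1] = alignment, optional Inputs[2] = offset.
static bool getAlignBundleOperands(CallInst *Assume, unsigned Idx, Value *&Ptr,
                                   Value *&Alignment, Value *&Offset) {
  OperandBundleUse OB = Assume->getOperandBundleAt(Idx);
  if (OB.getTagName() != "align" || OB.Inputs.size() < 2)
    return false;
  Ptr = OB.Inputs[0].get();
  Alignment = OB.Inputs[1].get();
  Offset = OB.Inputs.size() == 3 ? OB.Inputs[2].get() : nullptr;
  return true;
}
```

Since one assume call can carry several bundles, runImpl now loops over every bundle index rather than processing each assume once.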
diff --git a/llvm/lib/Transforms/Scalar/BDCE.cpp b/llvm/lib/Transforms/Scalar/BDCE.cpp
index 0fa38fa80b17..767c7656dcfa 100644
--- a/llvm/lib/Transforms/Scalar/BDCE.cpp
+++ b/llvm/lib/Transforms/Scalar/BDCE.cpp
@@ -9,7 +9,8 @@
 // This file implements the Bit-Tracking Dead Code Elimination pass. Some
 // instructions (shifts, some ands, ors, etc.) kill some of their input bits.
 // We track these dead bits and remove instructions that compute only these
-// dead bits.
+// dead bits. We also simplify sext that generates unused extension bits,
+// converting it to a zext.
 //
 //===----------------------------------------------------------------------===//
 
@@ -19,6 +20,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/DemandedBits.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/InitializePasses.h"
@@ -33,6 +35,8 @@ using namespace llvm;
 
 STATISTIC(NumRemoved, "Number of instructions removed (unused)");
 STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)");
+STATISTIC(NumSExt2ZExt,
+          "Number of sign extension instructions converted to zero extension");
 
 /// If an instruction is trivialized (dead), then the chain of users of that
 /// instruction may need to be cleared of assumptions that can no longer be
@@ -102,13 +106,31 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) {
         (I.getType()->isIntOrIntVectorTy() &&
         DB.getDemandedBits(&I).isNullValue() &&
         wouldInstructionBeTriviallyDead(&I))) {
-      salvageDebugInfoOrMarkUndef(I);
+      salvageDebugInfo(I);
       Worklist.push_back(&I);
       I.dropAllReferences();
       Changed = true;
       continue;
     }
 
+    // Convert SExt into ZExt if none of the extension bits is required
+    if (SExtInst *SE = dyn_cast<SExtInst>(&I)) {
+      APInt Demanded = DB.getDemandedBits(SE);
+      const uint32_t SrcBitSize = SE->getSrcTy()->getScalarSizeInBits();
+      auto *const DstTy = SE->getDestTy();
+      const uint32_t DestBitSize = DstTy->getScalarSizeInBits();
+      if (Demanded.countLeadingZeros() >= (DestBitSize - SrcBitSize)) {
+        clearAssumptionsOfUsers(SE, DB);
+        IRBuilder<> Builder(SE);
+        I.replaceAllUsesWith(
+            Builder.CreateZExt(SE->getOperand(0), DstTy, SE->getName()));
+        Worklist.push_back(SE);
+        Changed = true;
+        NumSExt2ZExt++;
+        continue;
+      }
+    }
+
     for (Use &U : I.operands()) {
       // DemandedBits only detects dead integer uses.
       if (!U->getType()->isIntOrIntVectorTy())
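A note on the BDCE hunk above, not part of the patch: a sext becomes a zext when no user demands any of the extension bits, because the copies of the sign bit are then irrelevant and zext is generally friendlier to later analyses. The guard, isolated as a sketch with illustrative values:

```cpp
#include "llvm/ADT/APInt.h"

using namespace llvm;

// The condition from the hunk above in isolation: a sign extension can
// become a zero extension when every demanded bit lies below the extension
// region, i.e. at least DestBits - SrcBits leading bits are dead.
static bool sextReplaceableByZExt(const APInt &Demanded, unsigned SrcBits,
                                  unsigned DestBits) {
  return Demanded.countLeadingZeros() >= (DestBits - SrcBits);
}

// Example: sext i16 %x to i32 where only the low 16 bits are ever used.
// Demanded = 0x0000FFFF has 16 leading zeros and 32 - 16 = 16, so the
// instruction can be rewritten as zext i16 %x to i32.
```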
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
index e34c011b1c87..b26bd1114bd4 100644
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -85,37 +85,36 @@ static cl::opt<unsigned>
                                   "their cost is below DuplicationThreshold"),
                          cl::init(5));
 
-static void addNonNullAttribute(CallSite CS, Value *Op) {
+static void addNonNullAttribute(CallBase &CB, Value *Op) {
   unsigned ArgNo = 0;
-  for (auto &I : CS.args()) {
+  for (auto &I : CB.args()) {
     if (&*I == Op)
-      CS.addParamAttr(ArgNo, Attribute::NonNull);
+      CB.addParamAttr(ArgNo, Attribute::NonNull);
     ++ArgNo;
   }
 }
 
-static void setConstantInArgument(CallSite CS, Value *Op,
+static void setConstantInArgument(CallBase &CB, Value *Op,
                                   Constant *ConstValue) {
   unsigned ArgNo = 0;
-  for (auto &I : CS.args()) {
+  for (auto &I : CB.args()) {
     if (&*I == Op) {
       // It is possible we have already added the non-null attribute to the
       // parameter by using an earlier constraining condition.
-      CS.removeParamAttr(ArgNo, Attribute::NonNull);
-      CS.setArgument(ArgNo, ConstValue);
+      CB.removeParamAttr(ArgNo, Attribute::NonNull);
+      CB.setArgOperand(ArgNo, ConstValue);
     }
     ++ArgNo;
   }
 }
 
-static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
+static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallBase &CB) {
   assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand.");
   Value *Op0 = Cmp->getOperand(0);
   unsigned ArgNo = 0;
-  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
-       ++I, ++ArgNo) {
+  for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I, ++ArgNo) {
     // Don't consider constants or arguments that are already known non-null.
-    if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull))
+    if (isa<Constant>(*I) || CB.paramHasAttr(ArgNo, Attribute::NonNull))
       continue;
 
     if (*I == Op0)
@@ -128,8 +127,8 @@ typedef std::pair<ICmpInst *, unsigned> ConditionTy;
 typedef SmallVector<ConditionTy, 2> ConditionsTy;
 
 /// If From has a conditional jump to To, add the condition to Conditions,
-/// if it is relevant to any argument at CS.
-static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To,
+/// if it is relevant to any argument at CB.
+static void recordCondition(CallBase &CB, BasicBlock *From, BasicBlock *To,
                             ConditionsTy &Conditions) {
   auto *BI = dyn_cast<BranchInst>(From->getTerminator());
   if (!BI || !BI->isConditional())
@@ -142,38 +141,38 @@ static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To,
   ICmpInst *Cmp = cast<ICmpInst>(Cond);
   if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
-    if (isCondRelevantToAnyCallArgument(Cmp, CS))
+    if (isCondRelevantToAnyCallArgument(Cmp, CB))
       Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
                                      ? Pred
                                      : Cmp->getInversePredicate()});
 }
 
-/// Record ICmp conditions relevant to any argument in CS following Pred's
+/// Record ICmp conditions relevant to any argument in CB following Pred's
 /// single predecessors. If there are conflicting conditions along a path, like
 /// x == 1 and x == 0, the first condition will be used. We stop once we reach
 /// an edge to StopAt.
-static void recordConditions(CallSite CS, BasicBlock *Pred,
+static void recordConditions(CallBase &CB, BasicBlock *Pred,
                              ConditionsTy &Conditions, BasicBlock *StopAt) {
   BasicBlock *From = Pred;
   BasicBlock *To = Pred;
   SmallPtrSet<BasicBlock *, 4> Visited;
   while (To != StopAt && !Visited.count(From->getSinglePredecessor()) &&
          (From = From->getSinglePredecessor())) {
-    recordCondition(CS, From, To, Conditions);
+    recordCondition(CB, From, To, Conditions);
     Visited.insert(From);
     To = From;
   }
 }
 
-static void addConditions(CallSite CS, const ConditionsTy &Conditions) {
+static void addConditions(CallBase &CB, const ConditionsTy &Conditions) {
   for (auto &Cond : Conditions) {
     Value *Arg = Cond.first->getOperand(0);
     Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1));
     if (Cond.second == ICmpInst::ICMP_EQ)
-      setConstantInArgument(CS, Arg, ConstVal);
+      setConstantInArgument(CB, Arg, ConstVal);
     else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) {
       assert(Cond.second == ICmpInst::ICMP_NE);
-      addNonNullAttribute(CS, Arg);
+      addNonNullAttribute(CB, Arg);
     }
   }
 }
@@ -184,17 +183,16 @@ static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) {
   return Preds;
 }
 
-static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) {
-  if (CS.isConvergent() || CS.cannotDuplicate())
+static bool canSplitCallSite(CallBase &CB, TargetTransformInfo &TTI) {
+  if (CB.isConvergent() || CB.cannotDuplicate())
     return false;
 
   // FIXME: As of now we handle only CallInst. InvokeInst could be handled
   // without too much effort.
-  Instruction *Instr = CS.getInstruction();
-  if (!isa<CallInst>(Instr))
+  if (!isa<CallInst>(CB))
     return false;
 
-  BasicBlock *CallSiteBB = Instr->getParent();
+  BasicBlock *CallSiteBB = CB.getParent();
   // Need 2 predecessors and cannot split an edge from an IndirectBrInst.
   SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB));
   if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) ||
@@ -212,7 +210,7 @@ static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) {
   // corresponding uses will be updated.
   unsigned Cost = 0;
   for (auto &InstBeforeCall :
-       llvm::make_range(CallSiteBB->begin(), Instr->getIterator())) {
+       llvm::make_range(CallSiteBB->begin(), CB.getIterator())) {
     Cost += TTI.getInstructionCost(&InstBeforeCall,
                                    TargetTransformInfo::TCK_CodeSize);
     if (Cost >= DuplicationThreshold)
@@ -304,24 +302,23 @@ static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI,
 /// predecessors, new call-sites with more constrained arguments will be
 /// created in createCallSitesOnPredicatedArgument().
 static void splitCallSite(
-    CallSite CS,
+    CallBase &CB,
     const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds,
     DomTreeUpdater &DTU) {
-  Instruction *Instr = CS.getInstruction();
-  BasicBlock *TailBB = Instr->getParent();
-  bool IsMustTailCall = CS.isMustTailCall();
+  BasicBlock *TailBB = CB.getParent();
+  bool IsMustTailCall = CB.isMustTailCall();
 
   PHINode *CallPN = nullptr;
 
   // `musttail` calls must be followed by optional `bitcast`, and `ret`. The
   // split blocks will be terminated right after that so there're no users for
   // this phi in a `TailBB`.
-  if (!IsMustTailCall && !Instr->use_empty()) {
-    CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call");
-    CallPN->setDebugLoc(Instr->getDebugLoc());
+  if (!IsMustTailCall && !CB.use_empty()) {
+    CallPN = PHINode::Create(CB.getType(), Preds.size(), "phi.call");
+    CallPN->setDebugLoc(CB.getDebugLoc());
   }
 
-  LLVM_DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
+  LLVM_DEBUG(dbgs() << "split call-site : " << CB << " into \n");
 
   assert(Preds.size() == 2 && "The ValueToValueMaps array has size 2.");
   // ValueToValueMapTy is neither copy nor moveable, so we use a simple array
@@ -330,21 +327,20 @@ static void splitCallSite(
   for (unsigned i = 0; i < Preds.size(); i++) {
     BasicBlock *PredBB = Preds[i].first;
     BasicBlock *SplitBlock = DuplicateInstructionsInSplitBetween(
-        TailBB, PredBB, &*std::next(Instr->getIterator()), ValueToValueMaps[i],
+        TailBB, PredBB, &*std::next(CB.getIterator()), ValueToValueMaps[i],
         DTU);
     assert(SplitBlock && "Unexpected new basic block split.");
 
-    Instruction *NewCI =
-        &*std::prev(SplitBlock->getTerminator()->getIterator());
-    CallSite NewCS(NewCI);
-    addConditions(NewCS, Preds[i].second);
+    auto *NewCI =
+        cast<CallBase>(&*std::prev(SplitBlock->getTerminator()->getIterator()));
+    addConditions(*NewCI, Preds[i].second);
 
     // Handle PHIs used as arguments in the call-site.
     for (PHINode &PN : TailBB->phis()) {
       unsigned ArgNo = 0;
-      for (auto &CI : CS.args()) {
+      for (auto &CI : CB.args()) {
         if (&*CI == &PN) {
-          NewCS.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock));
+          NewCI->setArgOperand(ArgNo, PN.getIncomingValueForBlock(SplitBlock));
         }
         ++ArgNo;
       }
@@ -356,7 +352,7 @@ static void splitCallSite(
 
     // Clone and place bitcast and return instructions before `TI`
     if (IsMustTailCall)
-      copyMustTailReturn(SplitBlock, Instr, NewCI);
+      copyMustTailReturn(SplitBlock, &CB, NewCI);
   }
 
   NumCallSiteSplit++;
@@ -383,7 +379,7 @@ static void splitCallSite(
   // Replace users of the original call with a PHI merging call-sites split.
   if (CallPN) {
     CallPN->insertBefore(OriginalBegin);
-    Instr->replaceAllUsesWith(CallPN);
+    CB.replaceAllUsesWith(CallPN);
   }
 
   // Remove instructions moved to split blocks from TailBB, from the duplicated
@@ -393,7 +389,7 @@ static void splitCallSite(
   // instruction, so we do not end up deleting them. By using reverse-order, we
   // do not introduce unnecessary PHI nodes for def-use chains from the call
   // instruction to the beginning of the block.
-  auto I = Instr->getReverseIterator();
+  auto I = CB.getReverseIterator();
   while (I != TailBB->rend()) {
     Instruction *CurrentI = &*I++;
     if (!CurrentI->use_empty()) {
@@ -418,28 +414,25 @@ static void splitCallSite(
 
 // Return true if the call-site has an argument which is a PHI with only
 // constant incoming values.
-static bool isPredicatedOnPHI(CallSite CS) {
-  Instruction *Instr = CS.getInstruction();
-  BasicBlock *Parent = Instr->getParent();
-  if (Instr != Parent->getFirstNonPHIOrDbg())
+static bool isPredicatedOnPHI(CallBase &CB) {
+  BasicBlock *Parent = CB.getParent();
+  if (&CB != Parent->getFirstNonPHIOrDbg())
     return false;
 
-  for (auto &BI : *Parent) {
-    if (PHINode *PN = dyn_cast<PHINode>(&BI)) {
-      for (auto &I : CS.args())
-        if (&*I == PN) {
-          assert(PN->getNumIncomingValues() == 2 &&
-                 "Unexpected number of incoming values");
-          if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1))
-            return false;
-          if (PN->getIncomingValue(0) == PN->getIncomingValue(1))
-            continue;
-          if (isa<Constant>(PN->getIncomingValue(0)) &&
-              isa<Constant>(PN->getIncomingValue(1)))
-            return true;
-        }
+  for (auto &PN : Parent->phis()) {
+    for (auto &Arg : CB.args()) {
+      if (&*Arg != &PN)
+        continue;
+      assert(PN.getNumIncomingValues() == 2 &&
+             "Unexpected number of incoming values");
+      if (PN.getIncomingBlock(0) == PN.getIncomingBlock(1))
+        return false;
+      if (PN.getIncomingValue(0) == PN.getIncomingValue(1))
+        continue;
+      if (isa<Constant>(PN.getIncomingValue(0)) &&
+          isa<Constant>(PN.getIncomingValue(1)))
+        return true;
     }
-    break;
   }
   return false;
 }
@@ -448,20 +441,20 @@ using PredsWithCondsTy = SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2>;
 
 // Check if any of the arguments in CS are predicated on a PHI node and return
 // the set of predecessors we should use for splitting.
-static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallSite CS) {
-  if (!isPredicatedOnPHI(CS))
+static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallBase &CB) {
+  if (!isPredicatedOnPHI(CB))
     return {};
 
-  auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
+  auto Preds = getTwoPredecessors(CB.getParent());
   return {{Preds[0], {}}, {Preds[1], {}}};
 }
 
 // Checks if any of the arguments in CS are predicated in a predecessor and
 // returns a list of predecessors with the conditions that hold on their edges
 // to CS.
-static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS,
+static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallBase &CB,
                                                         DomTreeUpdater &DTU) {
-  auto Preds = getTwoPredecessors(CS.getInstruction()->getParent());
+  auto Preds = getTwoPredecessors(CB.getParent());
   if (Preds[0] == Preds[1])
     return {};
 
@@ -470,16 +463,16 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS,
   // that node will be the same for all paths to the call site and splitting
   // is not beneficial.
   assert(DTU.hasDomTree() && "We need a DTU with a valid DT!");
-  auto *CSDTNode = DTU.getDomTree().getNode(CS.getInstruction()->getParent());
+  auto *CSDTNode = DTU.getDomTree().getNode(CB.getParent());
   BasicBlock *StopAt = CSDTNode ? CSDTNode->getIDom()->getBlock() : nullptr;
 
   SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS;
   for (auto *Pred : make_range(Preds.rbegin(), Preds.rend())) {
     ConditionsTy Conditions;
     // Record condition on edge BB(CS) <- Pred
-    recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions);
+    recordCondition(CB, Pred, CB.getParent(), Conditions);
     // Record conditions following Pred's single predecessors.
-    recordConditions(CS, Pred, Conditions, StopAt);
+    recordConditions(CB, Pred, Conditions, StopAt);
     PredsCS.push_back({Pred, Conditions});
   }
 
@@ -491,19 +484,19 @@ static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS,
   return PredsCS;
 }
 
-static bool tryToSplitCallSite(CallSite CS, TargetTransformInfo &TTI,
+static bool tryToSplitCallSite(CallBase &CB, TargetTransformInfo &TTI,
                                DomTreeUpdater &DTU) {
   // Check if we can split the call site.
-  if (!CS.arg_size() || !canSplitCallSite(CS, TTI))
+  if (!CB.arg_size() || !canSplitCallSite(CB, TTI))
     return false;
 
-  auto PredsWithConds = shouldSplitOnPredicatedArgument(CS, DTU);
+  auto PredsWithConds = shouldSplitOnPredicatedArgument(CB, DTU);
   if (PredsWithConds.empty())
-    PredsWithConds = shouldSplitOnPHIPredicatedArgument(CS);
+    PredsWithConds = shouldSplitOnPHIPredicatedArgument(CB);
   if (PredsWithConds.empty())
     return false;
 
-  splitCallSite(CS, PredsWithConds, DTU);
+  splitCallSite(CB, PredsWithConds, DTU);
   return true;
 }
 
@@ -521,20 +514,19 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI,
     // case, IE will be invalidated and we also have to check the current
     // terminator.
     while (II != IE && &*II != BB.getTerminator()) {
-      Instruction *I = &*II++;
-      CallSite CS(cast<Value>(I));
-      if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI))
+      CallBase *CB = dyn_cast<CallBase>(&*II++);
+      if (!CB || isa<IntrinsicInst>(CB) || isInstructionTriviallyDead(CB, &TLI))
        continue;
 
-      Function *Callee = CS.getCalledFunction();
+      Function *Callee = CB->getCalledFunction();
       if (!Callee || Callee->isDeclaration())
        continue;
 
       // Successful musttail call-site splits result in erased CI and erased BB.
       // Check if such path is possible before attempting the splitting.
-      bool IsMustTail = CS.isMustTailCall();
+      bool IsMustTail = CB->isMustTailCall();
 
-      Changed |= tryToSplitCallSite(CS, TTI, DTU);
+      Changed |= tryToSplitCallSite(*CB, TTI, DTU);
 
       // There're no interesting instructions after this. The call site
       // itself might have been erased on splitting.
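A note on the CallSiteSplitting diff above, not part of the patch: the changes are a mechanical CallSite-to-CallBase migration, but the transformation itself is worth restating in source terms (illustrative C++, not from the patch):

```cpp
extern void callee(int *Ptr, int Len);

// Before: one call merges two predecessors, hiding their argument facts.
void beforeSplit(int *Ptr, int Len) {
  if (Ptr == nullptr)
    Len = 0;        // predecessor 1: Ptr is known to be null here
  callee(Ptr, Len); // the merged call sees neither fact
}

// After: the call is duplicated into each predecessor, so one copy gets
// constant arguments and the other can carry a nonnull attribute.
void afterSplit(int *Ptr, int Len) {
  if (Ptr == nullptr)
    callee(nullptr, 0); // constant arguments made explicit
  else
    callee(Ptr, Len);   // Ptr provably nonnull on this path
}
```

This is why setConstantInArgument and addNonNullAttribute above operate per split call: each duplicated call site inherits only the facts valid on its own incoming edge.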
diff --git a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
index 5bfece010bec..7c14b69d658d 100644
--- a/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
+++ b/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -250,7 +250,7 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
   Orders.push_back(Entry);
   while (Idx != Orders.size()) {
     BasicBlock *Node = Orders[Idx++];
-    for (auto ChildDomNode : DT.getNode(Node)->getChildren()) {
+    for (auto ChildDomNode : DT.getNode(Node)->children()) {
       if (Candidates.count(ChildDomNode->getBlock()))
         Orders.push_back(ChildDomNode->getBlock());
     }
@@ -363,10 +363,12 @@ void ConstantHoistingPass::collectConstantCandidates(
   // instruction and operand index.
   if (auto IntrInst = dyn_cast<IntrinsicInst>(Inst))
     Cost = TTI->getIntImmCostIntrin(IntrInst->getIntrinsicID(), Idx,
-                                    ConstInt->getValue(), ConstInt->getType());
+                                    ConstInt->getValue(), ConstInt->getType(),
+                                    TargetTransformInfo::TCK_SizeAndLatency);
   else
     Cost = TTI->getIntImmCostInst(Inst->getOpcode(), Idx, ConstInt->getValue(),
-                                  ConstInt->getType());
+                                  ConstInt->getType(),
+                                  TargetTransformInfo::TCK_SizeAndLatency);
 
   // Ignore cheap integer constants.
   if (Cost > TargetTransformInfo::TCC_Basic) {
@@ -416,7 +418,8 @@ void ConstantHoistingPass::collectConstantCandidates(
   // usually lowered to a load from constant pool. Such operation is unlikely
   // to be cheaper than compute it by <Base + Offset>, which can be lowered to
   // an ADD instruction or folded into Load/Store instruction.
-  int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy);
+  int Cost = TTI->getIntImmCostInst(Instruction::Add, 1, Offset, PtrIntTy,
+                                    TargetTransformInfo::TCK_SizeAndLatency);
   ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV];
   ConstCandMapType::iterator Itr;
   bool Inserted;
@@ -491,7 +494,7 @@ void ConstantHoistingPass::collectConstantCandidates(
     // take constant variables is lower than `TargetTransformInfo::TCC_Basic`.
     // So it's safe for us to collect constant candidates from all
     // IntrinsicInsts.
-    if (canReplaceOperandWithVariable(Inst, Idx) || isa<IntrinsicInst>(Inst)) {
+    if (canReplaceOperandWithVariable(Inst, Idx)) {
       collectConstantCandidates(ConstCandMap, Inst, Idx);
     }
   } // end of for all operands
@@ -582,7 +585,8 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S,
     for (auto User : ConstCand->Uses) {
       unsigned Opcode = User.Inst->getOpcode();
       unsigned OpndIdx = User.OpndIdx;
-      Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty);
+      Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty,
+                                     TargetTransformInfo::TCK_SizeAndLatency);
       LLVM_DEBUG(dbgs() << "Cost: " << Cost << "\n");
 
       for (auto C2 = S; C2 != E; ++C2) {
@@ -975,8 +979,8 @@ PreservedAnalyses ConstantHoistingPass::run(Function &F,
   auto BFI = ConstHoistWithBlockFrequency
                  ? &AM.getResult<BlockFrequencyAnalysis>(F)
                  : nullptr;
-  auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
-  auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+  auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+  auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
 
   if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock(), PSI))
     return PreservedAnalyses::all();
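A note on the ConstantHoisting hunks above, not part of the patch: every immediate-cost query now names an explicit TargetTransformInfo cost kind instead of relying on a default. The resulting call shape, in isolation (immCostForHoisting is an invented wrapper; the return type is int as of this patch):

```cpp
#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/TargetTransformInfo.h"

using namespace llvm;

// Hypothetical wrapper showing the new query shape used throughout the pass.
static int immCostForHoisting(const TargetTransformInfo &TTI, unsigned Opcode,
                              unsigned OpndIdx, const APInt &Imm, Type *Ty) {
  // TCK_SizeAndLatency asks how expensive materializing the immediate is in
  // both code size and latency, which is the tradeoff constant hoisting
  // weighs when deciding whether an immediate is cheap enough to ignore.
  return TTI.getIntImmCostInst(Opcode, OpndIdx, Imm, Ty,
                               TargetTransformInfo::TCK_SizeAndLatency);
}
```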
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 3435bc7f5eaa..cd2f4ca36f3b 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -22,7 +22,6 @@
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
@@ -125,7 +124,7 @@ Pass *llvm::createCorrelatedValuePropagationPass() {
 static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
   if (S->getType()->isVectorTy()) return false;
-  if (isa<Constant>(S->getOperand(0))) return false;
+  if (isa<Constant>(S->getCondition())) return false;
 
   Constant *C = LVI->getConstant(S->getCondition(), S->getParent(), S);
   if (!C) return false;
@@ -133,11 +132,7 @@ static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
   ConstantInt *CI = dyn_cast<ConstantInt>(C);
   if (!CI) return false;
 
-  Value *ReplaceWith = S->getTrueValue();
-  Value *Other = S->getFalseValue();
-  if (!CI->isOne()) std::swap(ReplaceWith, Other);
-  if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType());
-
+  Value *ReplaceWith = CI->isOne() ? S->getTrueValue() : S->getFalseValue();
   S->replaceAllUsesWith(ReplaceWith);
   S->eraseFromParent();
 
@@ -310,9 +305,10 @@ static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
   // the comparison is testing local values.  While LVI can sometimes reason
   // about such cases, it's not its primary purpose.  We do make sure to do
   // the block local query for uses from terminator instructions, but that's
-  // handled in the code for each terminator.
+  // handled in the code for each terminator. As an exception, we allow phi
+  // nodes, for which LVI can thread the condition into predecessors.
   auto *I = dyn_cast<Instruction>(Op0);
-  if (I && I->getParent() == Cmp->getParent())
+  if (I && I->getParent() == Cmp->getParent() && !isa<PHINode>(I))
     return false;
 
   LazyValueInfo::Tristate Result =
@@ -535,18 +531,18 @@ static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
 }
 
 /// Infer nonnull attributes for the arguments at the specified callsite.
-static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
+static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
   SmallVector<unsigned, 4> ArgNos;
   unsigned ArgNo = 0;
 
-  if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) {
+  if (auto *WO = dyn_cast<WithOverflowInst>(&CB)) {
     if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
       processOverflowIntrinsic(WO, LVI);
       return true;
     }
   }
 
-  if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) {
+  if (auto *SI = dyn_cast<SaturatingInst>(&CB)) {
     if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
       processSaturatingInst(SI, LVI);
       return true;
@@ -559,8 +555,8 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
   // desirable since it may allow further optimization of that value (e.g. via
   // single use rules in instcombine).  Since deopt uses tend to,
   // idiomatically, appear along rare conditional paths, it's reasonably likely
-  // we may have a conditional fact with which LVI can fold.
-  if (auto DeoptBundle = CS.getOperandBundle(LLVMContext::OB_deopt)) {
+  // we may have a conditional fact with which LVI can fold.
+  if (auto DeoptBundle = CB.getOperandBundle(LLVMContext::OB_deopt)) {
     bool Progress = false;
     for (const Use &ConstU : DeoptBundle->Inputs) {
       Use &U = const_cast<Use&>(ConstU);
@@ -568,7 +564,7 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
       if (V->getType()->isVectorTy()) continue;
       if (isa<Constant>(V)) continue;
 
-      Constant *C = LVI->getConstant(V, CS.getParent(), CS.getInstruction());
+      Constant *C = LVI->getConstant(V, CB.getParent(), &CB);
       if (!C) continue;
       U.set(C);
       Progress = true;
@@ -577,30 +573,30 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
       return true;
   }
 
-  for (Value *V : CS.args()) {
+  for (Value *V : CB.args()) {
     PointerType *Type = dyn_cast<PointerType>(V->getType());
     // Try to mark pointer typed parameters as non-null.  We skip the
     // relatively expensive analysis for constants which are obviously either
     // null or non-null to start with.
-    if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
+    if (Type && !CB.paramHasAttr(ArgNo, Attribute::NonNull) &&
         !isa<Constant>(V) &&
         LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,
                             ConstantPointerNull::get(Type),
-                            CS.getInstruction()) == LazyValueInfo::False)
+                            &CB) == LazyValueInfo::False)
       ArgNos.push_back(ArgNo);
     ArgNo++;
   }
 
-  assert(ArgNo == CS.arg_size() && "sanity check");
+  assert(ArgNo == CB.arg_size() && "sanity check");
 
   if (ArgNos.empty())
     return false;
 
-  AttributeList AS = CS.getAttributes();
-  LLVMContext &Ctx = CS.getInstruction()->getContext();
+  AttributeList AS = CB.getAttributes();
+  LLVMContext &Ctx = CB.getContext();
   AS = AS.addParamAttribute(Ctx, ArgNos,
                             Attribute::get(Ctx, Attribute::NonNull));
-  CS.setAttributes(AS);
+  CB.setAttributes(AS);
 
   return true;
 }
@@ -793,7 +789,10 @@ static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) {
   if (!RHS || !RHS->getValue().isMask())
     return false;
 
-  ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp);
+  // We can only replace the AND with LHS based on range info if the range does
+  // not include undef.
+  ConstantRange LRange =
+      LVI->getConstantRange(LHS, BB, BinOp, /*UndefAllowed=*/false);
   if (!LRange.getUnsignedMax().ule(RHS->getValue()))
     return false;
 
@@ -856,7 +855,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
         break;
       case Instruction::Call:
      case Instruction::Invoke:
-        BBChanged |= processCallSite(CallSite(II), LVI);
+        BBChanged |= processCallSite(cast<CallBase>(*II), LVI);
         break;
       case Instruction::SRem:
         BBChanged |= processSRem(cast<BinaryOperator>(II), LVI);
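A note on the nonnull inference in processCallSite above, not part of the patch: a source-level picture of what the LVI query proves (illustrative only):

```cpp
extern int consume(int *P);

// LVI can prove P is nonnull at the call because of the dominating check,
// so CVP attaches a nonnull parameter attribute to this call site even
// though consume() does not declare one.
int guarded(int *P) {
  if (P == nullptr)
    return 0;
  return consume(P); // argument provably nonnull here
}
```

The attribute then lets later passes (inlining followed by instcombine, for example) drop redundant null checks inside the callee.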
diff --git a/llvm/lib/Transforms/Scalar/DCE.cpp b/llvm/lib/Transforms/Scalar/DCE.cpp
index a4b0c8df98f6..28947482e303 100644
--- a/llvm/lib/Transforms/Scalar/DCE.cpp
+++ b/llvm/lib/Transforms/Scalar/DCE.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/DebugCounter.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
@@ -127,6 +128,7 @@ static bool DCEInstruction(Instruction *I,
       return false;
 
     salvageDebugInfo(*I);
+    salvageKnowledge(I);
 
     // Null out all of the instruction's operands to see if any operand becomes
     // dead as we go.
diff --git a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
index 1ba4aab999e1..e58db03225ee 100644
--- a/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -29,17 +30,19 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
 #include "llvm/Analysis/MemoryLocation.h"
-#include "llvm/Analysis/OrderedBasicBlock.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -48,16 +51,19 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <algorithm>
 #include <cassert>
@@ -68,14 +74,23 @@
 #include <utility>
 
 using namespace llvm;
+using namespace PatternMatch;
 
 #define DEBUG_TYPE "dse"
 
+STATISTIC(NumRemainingStores, "Number of stores remaining after DSE");
 STATISTIC(NumRedundantStores, "Number of redundant stores deleted");
 STATISTIC(NumFastStores, "Number of stores deleted");
 STATISTIC(NumFastOther, "Number of other instrs removed");
 STATISTIC(NumCompletePartials, "Number of stores dead by later partials");
 STATISTIC(NumModifiedStores, "Number of stores modified");
+STATISTIC(NumNoopStores, "Number of noop stores deleted");
+STATISTIC(NumCFGChecks, "Number of blocks checked when analyzing the CFG");
+STATISTIC(NumCFGTries, "Number of CFG walks attempted");
+STATISTIC(NumCFGSuccess, "Number of CFG walks that succeeded");
+
+DEBUG_COUNTER(MemorySSACounter, "dse-memoryssa",
+              "Controls which MemoryDefs are eliminated.");
 
 static cl::opt<bool>
 EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking",
@@ -87,6 +102,25 @@ EnablePartialStoreMerging("enable-dse-partial-store-merging",
   cl::init(true), cl::Hidden,
   cl::desc("Enable partial store merging in DSE"));
 
+static cl::opt<bool>
+    EnableMemorySSA("enable-dse-memoryssa", cl::init(false), cl::Hidden,
+                    cl::desc("Use the new MemorySSA-backed DSE."));
+
+static cl::opt<unsigned>
+    MemorySSAScanLimit("dse-memoryssa-scanlimit", cl::init(100), cl::Hidden,
+                       cl::desc("The number of memory instructions to scan for "
+                                "dead store elimination (default = 100)"));
+
+static cl::opt<unsigned> MemorySSADefsPerBlockLimit(
+    "dse-memoryssa-defs-per-block-limit", cl::init(5000), cl::Hidden,
+    cl::desc("The number of MemoryDefs we consider as candidates to eliminate "
+             "other stores per basic block (default = 5000)"));
+
+static cl::opt<unsigned> MemorySSAPathCheckLimit(
+    "dse-memoryssa-path-check-limit", cl::init(50), cl::Hidden,
+    cl::desc("The maximum number of blocks to check when trying to prove that "
+             "all paths to an exit go through a killing block (default = 50)"));
+
 //===----------------------------------------------------------------------===//
 // Helper functions
 //===----------------------------------------------------------------------===//
@@ -100,7 +134,7 @@ using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>;
 static void
 deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI,
                       MemoryDependenceResults &MD, const TargetLibraryInfo &TLI,
-                      InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB,
+                      InstOverlapIntervalsTy &IOL,
                       MapVector<Instruction *, bool> &ThrowableInst,
                       SmallSetVector<const Value *, 16> *ValueSet = nullptr) {
   SmallVector<Instruction*, 32> NowDeadInsts;
@@ -123,6 +157,7 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI,
 
     // Try to preserve debug information attached to the dead instruction.
     salvageDebugInfo(*DeadInst);
+    salvageKnowledge(DeadInst);
 
     // This instruction is dead, zap it, in stages.  Start by removing it from
     // MemDep, which needs to know the operands and needs it to be in the
@@ -143,7 +178,6 @@ deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI,
 
     if (ValueSet) ValueSet->remove(DeadInst);
     IOL.erase(DeadInst);
-    OBB.eraseInstruction(DeadInst);
 
     if (NewIter == DeadInst->getIterator())
       NewIter = DeadInst->eraseFromParent();
@@ -177,19 +211,17 @@ static bool hasAnalyzableMemoryWrite(Instruction *I,
       return true;
     }
   }
-  if (auto CS = CallSite(I)) {
-    if (Function *F = CS.getCalledFunction()) {
-      LibFunc LF;
-      if (TLI.getLibFunc(*F, LF) && TLI.has(LF)) {
-        switch (LF) {
-        case LibFunc_strcpy:
-        case LibFunc_strncpy:
-        case LibFunc_strcat:
-        case LibFunc_strncat:
-          return true;
-        default:
-          return false;
-        }
+  if (auto *CB = dyn_cast<CallBase>(I)) {
+    LibFunc LF;
+    if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) {
+      switch (LF) {
+      case LibFunc_strcpy:
+      case LibFunc_strncpy:
+      case LibFunc_strcat:
+      case LibFunc_strncat:
+        return true;
+      default:
+        return false;
       }
     }
   }
@@ -222,10 +254,10 @@ static MemoryLocation getLocForWrite(Instruction *Inst) {
     }
     }
   }
-  if (auto CS = CallSite(Inst))
+  if (auto *CB = dyn_cast<CallBase>(Inst))
     // All the supported TLI functions so far happen to have dest as their
     // first argument.
-    return MemoryLocation(CS.getArgument(0));
+    return MemoryLocation(CB->getArgOperand(0));
   return MemoryLocation();
 }
 
@@ -272,8 +304,8 @@ static bool isRemovable(Instruction *I) {
   }
 
   // note: only get here for calls with analyzable writes - i.e. libcalls
-  if (auto CS = CallSite(I))
-    return CS.getInstruction()->use_empty();
+  if (auto *CB = dyn_cast<CallBase>(I))
+    return CB->use_empty();
 
   return false;
 }
@@ -597,51 +629,82 @@ static bool isPossibleSelfRead(Instruction *Inst,
 /// instruction.
 static bool memoryIsNotModifiedBetween(Instruction *FirstI,
                                        Instruction *SecondI,
-                                       AliasAnalysis *AA) {
-  SmallVector<BasicBlock *, 16> WorkList;
-  SmallPtrSet<BasicBlock *, 8> Visited;
+                                       AliasAnalysis *AA,
+                                       const DataLayout &DL,
+                                       DominatorTree *DT) {
+  // Do a backwards scan through the CFG from SecondI to FirstI. Look for
+  // instructions which can modify the memory location accessed by SecondI.
+  //
+  // While doing the walk keep track of the address to check. It might be
+  // different in different basic blocks due to PHI translation.
+  using BlockAddressPair = std::pair<BasicBlock *, PHITransAddr>;
+  SmallVector<BlockAddressPair, 16> WorkList;
+  // Keep track of the address we visited each block with. Bail out if we
+  // visit a block with different addresses.
+  DenseMap<BasicBlock *, Value *> Visited;
+
   BasicBlock::iterator FirstBBI(FirstI);
   ++FirstBBI;
   BasicBlock::iterator SecondBBI(SecondI);
   BasicBlock *FirstBB = FirstI->getParent();
   BasicBlock *SecondBB = SecondI->getParent();
   MemoryLocation MemLoc = MemoryLocation::get(SecondI);
+  auto *MemLocPtr = const_cast<Value *>(MemLoc.Ptr);
 
-  // Start checking the store-block.
-  WorkList.push_back(SecondBB);
+  // Start checking the SecondBB.
+  WorkList.push_back(
+      std::make_pair(SecondBB, PHITransAddr(MemLocPtr, DL, nullptr)));
   bool isFirstBlock = true;
 
-  // Check all blocks going backward until we reach the load-block.
+  // Check all blocks going backward until we reach the FirstBB.
   while (!WorkList.empty()) {
-    BasicBlock *B = WorkList.pop_back_val();
+    BlockAddressPair Current = WorkList.pop_back_val();
+    BasicBlock *B = Current.first;
+    PHITransAddr &Addr = Current.second;
+    Value *Ptr = Addr.getAddr();
 
-    // Ignore instructions before LI if this is the FirstBB.
+    // Ignore instructions before FirstI if this is the FirstBB.
     BasicBlock::iterator BI = (B == FirstBB ? FirstBBI : B->begin());
 
     BasicBlock::iterator EI;
     if (isFirstBlock) {
-      // Ignore instructions after SI if this is the first visit of SecondBB.
+      // Ignore instructions after SecondI if this is the first visit of SecondBB.
       assert(B == SecondBB && "first block is not the store block");
       EI = SecondBBI;
       isFirstBlock = false;
     } else {
       // It's not SecondBB or (in case of a loop) the second visit of SecondBB.
-      // In this case we also have to look at instructions after SI.
+      // In this case we also have to look at instructions after SecondI.
       EI = B->end();
     }
     for (; BI != EI; ++BI) {
       Instruction *I = &*BI;
       if (I->mayWriteToMemory() && I != SecondI)
-        if (isModSet(AA->getModRefInfo(I, MemLoc)))
+        if (isModSet(AA->getModRefInfo(I, MemLoc.getWithNewPtr(Ptr))))
          return false;
     }
     if (B != FirstBB) {
      assert(B != &FirstBB->getParent()->getEntryBlock() &&
          "Should not hit the entry block because SI must be dominated by LI");
      for (auto PredI = pred_begin(B), PE = pred_end(B); PredI != PE; ++PredI) {
-        if (!Visited.insert(*PredI).second)
+        PHITransAddr PredAddr = Addr;
+        if (PredAddr.NeedsPHITranslationFromBlock(B)) {
+          if (!PredAddr.IsPotentiallyPHITranslatable())
+            return false;
+          if (PredAddr.PHITranslateValue(B, *PredI, DT, false))
+            return false;
+        }
+        Value *TranslatedPtr = PredAddr.getAddr();
+        auto Inserted = Visited.insert(std::make_pair(*PredI, TranslatedPtr));
+        if (!Inserted.second) {
+          // We already visited this block before. If it was with a different
+          // address - bail out!
+          if (TranslatedPtr != Inserted.first->second)
+            return false;
+          // ... otherwise just skip it.
          continue;
-        WorkList.push_back(*PredI);
+        }
+        WorkList.push_back(std::make_pair(*PredI, PredAddr));
       }
     }
   }
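A note on the memoryIsNotModifiedBetween rewrite above, not part of the patch: the walk now carries a PHITransAddr because the queried address may be a phi whose incoming pointer differs per predecessor. Roughly the shape this handles (illustrative C++ only):

```cpp
// The store's address P only exists as a phi of A and B. Scanning backward
// from the store, the walk enters each predecessor and PHI-translates P to
// the incoming pointer (A or B) before querying alias analysis there,
// instead of conservatively querying with the phi value everywhere.
void phiAddressedStore(bool Cond, int *A, int *B, int V) {
  int *P;
  if (Cond)
    P = A; // predecessor 1: P translates to A
  else
    P = B; // predecessor 2: P translates to B
  *P = V;  // store whose address is a phi in the IR
}
```

The Visited map keyed by translated pointer is what makes revisits safe: reaching the same block with two different translated addresses is the ambiguous case, and the function bails out rather than guess.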
BasicBlock::iterator BBI(Dependency); -      deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, OBB, +      deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL,                              ThrowableInst);        ++NumFastStores;        MadeChange = true; @@ -762,7 +825,7 @@ static void removeAccessedObjects(const MemoryLocation &LoadedLoc,  static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,                             MemoryDependenceResults *MD,                             const TargetLibraryInfo *TLI, -                           InstOverlapIntervalsTy &IOL, OrderedBasicBlock &OBB, +                           InstOverlapIntervalsTy &IOL,                             MapVector<Instruction *, bool> &ThrowableInst) {    bool MadeChange = false; @@ -785,7 +848,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,    // Treat byval or inalloca arguments the same, stores to them are dead at the    // end of the function.    for (Argument &AI : BB.getParent()->args()) -    if (AI.hasByValOrInAllocaAttr()) +    if (AI.hasPassPointeeByValueAttr())        DeadStackObjects.insert(&AI);    const DataLayout &DL = BB.getModule()->getDataLayout(); @@ -824,7 +887,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,                     << '\n');          // DCE instructions only used to calculate that store. -        deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst, +        deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, ThrowableInst,                                &DeadStackObjects);          ++NumFastStores;          MadeChange = true; @@ -836,7 +899,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA,      if (isInstructionTriviallyDead(&*BBI, TLI)) {        LLVM_DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n  DEAD: "                          << *&*BBI << '\n'); -      deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst, +      deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, ThrowableInst,                              &DeadStackObjects);        ++NumFastOther;        MadeChange = true; @@ -1043,8 +1106,8 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,                                 const DataLayout &DL,                                 const TargetLibraryInfo *TLI,                                 InstOverlapIntervalsTy &IOL, -                               OrderedBasicBlock &OBB, -                               MapVector<Instruction *, bool> &ThrowableInst) { +                               MapVector<Instruction *, bool> &ThrowableInst, +                               DominatorTree *DT) {    // Must be a store instruction.    StoreInst *SI = dyn_cast<StoreInst>(Inst);    if (!SI) @@ -1054,13 +1117,14 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,    // then the store can be removed.    
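   // (Hypothetical illustration, not part of the patch.)
   //   %v = load i32, i32* %p
   //   store i32 %v, i32* %p   ; no-op if nothing may write to %p's memory
   //                           ; between the load and the store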
if (LoadInst *DepLoad = dyn_cast<LoadInst>(SI->getValueOperand())) {      if (SI->getPointerOperand() == DepLoad->getPointerOperand() && -        isRemovable(SI) && memoryIsNotModifiedBetween(DepLoad, SI, AA)) { +        isRemovable(SI) && +        memoryIsNotModifiedBetween(DepLoad, SI, AA, DL, DT)) {        LLVM_DEBUG(            dbgs() << "DSE: Remove Store Of Load from same pointer:\n  LOAD: "                   << *DepLoad << "\n  STORE: " << *SI << '\n'); -      deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst); +      deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, ThrowableInst);        ++NumRedundantStores;        return true;      } @@ -1073,12 +1137,12 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,          dyn_cast<Instruction>(GetUnderlyingObject(SI->getPointerOperand(), DL));      if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) && -        memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA)) { +        memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA, DL, DT)) {        LLVM_DEBUG(            dbgs() << "DSE: Remove null store to the calloc'ed object:\n  DEAD: "                   << *Inst << "\n  OBJECT: " << *UnderlyingPointer << '\n'); -      deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, OBB, ThrowableInst); +      deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, ThrowableInst);        ++NumRedundantStores;        return true;      } @@ -1086,13 +1150,58 @@ static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI,    return false;  } +static Constant * +tryToMergePartialOverlappingStores(StoreInst *Earlier, StoreInst *Later, +                                   int64_t InstWriteOffset, +                                   int64_t DepWriteOffset, const DataLayout &DL, +                                   AliasAnalysis *AA, DominatorTree *DT) { + +  if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && +      DL.typeSizeEqualsStoreSize(Earlier->getValueOperand()->getType()) && +      Later && isa<ConstantInt>(Later->getValueOperand()) && +      DL.typeSizeEqualsStoreSize(Later->getValueOperand()->getType()) && +      memoryIsNotModifiedBetween(Earlier, Later, AA, DL, DT)) { +    // If the store we find is: +    //   a) partially overwritten by the store to 'Loc' +    //   b) the later store is fully contained in the earlier one and +    //   c) they both have a constant value +    //   d) none of the two stores need padding +    // Merge the two stores, replacing the earlier store's value with a +    // merge of both values. +    // TODO: Deal with other constant types (vectors, etc), and probably +    // some mem intrinsics (if needed) + +    APInt EarlierValue = +        cast<ConstantInt>(Earlier->getValueOperand())->getValue(); +    APInt LaterValue = cast<ConstantInt>(Later->getValueOperand())->getValue(); +    unsigned LaterBits = LaterValue.getBitWidth(); +    assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); +    LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); + +    // Offset of the smaller store inside the larger store +    unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; +    unsigned LShiftAmount = DL.isBigEndian() ? 
EarlierValue.getBitWidth() - +                                                   BitOffsetDiff - LaterBits +                                             : BitOffsetDiff; +    APInt Mask = APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, +                                   LShiftAmount + LaterBits); +    // Clear the bits we'll be replacing, then OR with the smaller +    // store, shifted appropriately. +    APInt Merged = (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); +    LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n  Earlier: " << *Earlier +                      << "\n  Later: " << *Later +                      << "\n  Merged Value: " << Merged << '\n'); +    return ConstantInt::get(Earlier->getValueOperand()->getType(), Merged); +  } +  return nullptr; +} +  static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,                                  MemoryDependenceResults *MD, DominatorTree *DT,                                  const TargetLibraryInfo *TLI) {    const DataLayout &DL = BB.getModule()->getDataLayout();    bool MadeChange = false; -  OrderedBasicBlock OBB(&BB);    MapVector<Instruction *, bool> ThrowableInst;    // A map of interval maps representing partially-overwritten value parts. @@ -1102,7 +1211,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,    for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) {      // Handle 'free' calls specially.      if (CallInst *F = isFreeCall(&*BBI, TLI)) { -      MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, OBB, ThrowableInst); +      MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, ThrowableInst);        // Increment BBI after handleFree has potentially deleted instructions.        // This ensures we maintain a valid iterator.        ++BBI; @@ -1121,14 +1230,14 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,        continue;      // eliminateNoopStore will update in iterator, if necessary. -    if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, OBB, -                           ThrowableInst)) { +    if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, +                           ThrowableInst, DT)) {        MadeChange = true;        continue;      }      // If we find something that writes memory, get its memory dependence. -    MemDepResult InstDep = MD->getDependency(Inst, &OBB); +    MemDepResult InstDep = MD->getDependency(Inst);      // Ignore any store where we can't find a local dependence.      // FIXME: cross-block DSE would be fun. :) @@ -1179,7 +1288,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,        // If the underlying object is a non-escaping memory allocation, any store        // to it is dead along the unwind edge. Otherwise, we need to preserve        // the store. -      if (LastThrowing && OBB.dominates(DepWrite, LastThrowing)) { +      if (LastThrowing && DepWrite->comesBefore(LastThrowing)) {          const Value* Underlying = GetUnderlyingObject(DepLoc.Ptr, DL);          bool IsStoreDeadOnUnwind = isa<AllocaInst>(Underlying);          if (!IsStoreDeadOnUnwind) { @@ -1210,13 +1319,13 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,                              << "\n  KILLER: " << *Inst << '\n');            // Delete the store and now-dead instructions that feed it. 
-          deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB, +          deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL,                                  ThrowableInst);            ++NumFastStores;            MadeChange = true;            // We erased DepWrite; start over. -          InstDep = MD->getDependency(Inst, &OBB); +          InstDep = MD->getDependency(Inst);            continue;          } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) ||                     ((OR == OW_Begin && @@ -1234,53 +1343,12 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,                     OR == OW_PartialEarlierWithFullLater) {            auto *Earlier = dyn_cast<StoreInst>(DepWrite);            auto *Later = dyn_cast<StoreInst>(Inst); -          if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && -              DL.typeSizeEqualsStoreSize( -                  Earlier->getValueOperand()->getType()) && -              Later && isa<ConstantInt>(Later->getValueOperand()) && -              DL.typeSizeEqualsStoreSize( -                  Later->getValueOperand()->getType()) && -              memoryIsNotModifiedBetween(Earlier, Later, AA)) { -            // If the store we find is: -            //   a) partially overwritten by the store to 'Loc' -            //   b) the later store is fully contained in the earlier one and -            //   c) they both have a constant value -            //   d) none of the two stores need padding -            // Merge the two stores, replacing the earlier store's value with a -            // merge of both values. -            // TODO: Deal with other constant types (vectors, etc), and probably -            // some mem intrinsics (if needed) - -            APInt EarlierValue = -                cast<ConstantInt>(Earlier->getValueOperand())->getValue(); -            APInt LaterValue = -                cast<ConstantInt>(Later->getValueOperand())->getValue(); -            unsigned LaterBits = LaterValue.getBitWidth(); -            assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); -            LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); - -            // Offset of the smaller store inside the larger store -            unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; -            unsigned LShiftAmount = -                DL.isBigEndian() -                    ? EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits -                    : BitOffsetDiff; -            APInt Mask = -                APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, -                                  LShiftAmount + LaterBits); -            // Clear the bits we'll be replacing, then OR with the smaller -            // store, shifted appropriately. 
-            APInt Merged = -                (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); -            LLVM_DEBUG(dbgs() << "DSE: Merge Stores:\n  Earlier: " << *DepWrite -                              << "\n  Later: " << *Inst -                              << "\n  Merged Value: " << Merged << '\n'); - +          if (Constant *C = tryToMergePartialOverlappingStores( +                  Earlier, Later, InstWriteOffset, DepWriteOffset, DL, AA, +                  DT)) {              auto *SI = new StoreInst( -                ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), -                Earlier->getPointerOperand(), false, -                MaybeAlign(Earlier->getAlignment()), Earlier->getOrdering(), -                Earlier->getSyncScopeID(), DepWrite); +                C, Earlier->getPointerOperand(), false, Earlier->getAlign(), +                Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite);              unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa,                                     LLVMContext::MD_alias_scope, @@ -1289,13 +1357,10 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,              SI->copyMetadata(*DepWrite, MDToKeep);              ++NumModifiedStores; -            // Remove earlier, wider, store -            OBB.replaceInstruction(DepWrite, SI); -              // Delete the old stores and now-dead instructions that feed them. -            deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, OBB, +            deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL,                                    ThrowableInst); -            deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, OBB, +            deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL,                                    ThrowableInst);              MadeChange = true; @@ -1331,7 +1396,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA,    // If this block ends in a return, unwind, or unreachable, all allocas are    // dead at its end, which means stores to them are also dead.    if (BB.getTerminator()->getNumSuccessors() == 0) -    MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, OBB, ThrowableInst); +    MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, ThrowableInst);    return MadeChange;  } @@ -1349,22 +1414,913 @@ static bool eliminateDeadStores(Function &F, AliasAnalysis *AA,    return MadeChange;  } +namespace { +//============================================================================= +// MemorySSA backed dead store elimination. +// +// The code below implements dead store elimination using MemorySSA. It uses +// the following general approach: given a MemoryDef, walk upwards to find +// clobbering MemoryDefs that may be killed by the starting def. Then check +// that there are no uses that may read the location of the original MemoryDef +// in between both MemoryDefs. A bit more concretely: +// +// For all MemoryDefs StartDef: +// 1. Get the next dominating clobbering MemoryDef (DomAccess) by walking +//    upwards. +// 2. Check that there are no reads between DomAccess and the StartDef by +//    checking all uses starting at DomAccess and walking until we see StartDef. +// 3. For each found DomDef, check that: +//   1. There are no barrier instructions between DomDef and StartDef (like +//       throws or stores with ordering constraints). +//   2. StartDef is executed whenever DomDef is executed. +//   3. StartDef completely overwrites DomDef. +// 4. Erase DomDef from the function and MemorySSA. 
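// (Hypothetical sketch, not part of the patch.) In MemorySSA notation, for:
//
//   ; 1 = MemoryDef(liveOnEntry)
//   store i32 0, i32* %p   ; DomAccess / DomDef
//   ; 2 = MemoryDef(1)
//   store i32 1, i32* %p   ; StartDef
//
// the upward walk from StartDef reaches DomAccess, no MemoryUse reads %p in
// between, and StartDef completely overwrites DomDef, so the first store can
// be erased.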
+
+// Returns true if \p M is an intrinsic that does not read or write memory.
+bool isNoopIntrinsic(MemoryUseOrDef *M) {
+  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(M->getMemoryInst())) {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::lifetime_start:
+    case Intrinsic::lifetime_end:
+    case Intrinsic::invariant_end:
+    case Intrinsic::launder_invariant_group:
+    case Intrinsic::assume:
+      return true;
+    case Intrinsic::dbg_addr:
+    case Intrinsic::dbg_declare:
+    case Intrinsic::dbg_label:
+    case Intrinsic::dbg_value:
+      llvm_unreachable("Intrinsic should not be modeled in MemorySSA");
+    default:
+      return false;
+    }
+  }
+  return false;
+}
+
+// Check if we can ignore \p D for DSE.
+bool canSkipDef(MemoryDef *D, bool DefVisibleToCaller) {
+  Instruction *DI = D->getMemoryInst();
+  // Calls that only access inaccessible memory cannot read or write any memory
+  // locations we consider for elimination.
+  if (auto *CB = dyn_cast<CallBase>(DI))
+    if (CB->onlyAccessesInaccessibleMemory())
+      return true;
+
+  // We can eliminate stores to locations not visible to the caller across
+  // throwing instructions.
+  if (DI->mayThrow() && !DefVisibleToCaller)
+    return true;
+
+  // We can remove the dead stores, irrespective of the fence and its ordering
+  // (release/acquire/seq_cst). Fences only constrain the ordering of already
+  // visible stores; they do not make a store visible to other threads. So,
+  // skipping over a fence does not change a store from being dead.
+  if (isa<FenceInst>(DI))
+    return true;
+
+  // Skip intrinsics that do not really read or modify memory.
+  if (isNoopIntrinsic(D))
+    return true;
+
+  return false;
+}
+
+struct DSEState {
+  Function &F;
+  AliasAnalysis &AA;
+  MemorySSA &MSSA;
+  DominatorTree &DT;
+  PostDominatorTree &PDT;
+  const TargetLibraryInfo &TLI;
+
+  // All MemoryDefs that potentially could kill other MemDefs.
+  SmallVector<MemoryDef *, 64> MemDefs;
+  // Any that should be skipped as they are already deleted.
+  SmallPtrSet<MemoryAccess *, 4> SkipStores;
+  // Keep track of all of the objects that are invisible to the caller before
+  // the function returns.
+  SmallPtrSet<const Value *, 16> InvisibleToCallerBeforeRet;
+  // Keep track of all of the objects that are invisible to the caller after
+  // the function returns.
+  SmallPtrSet<const Value *, 16> InvisibleToCallerAfterRet;
+  // Keep track of blocks with throwing instructions not modeled in MemorySSA.
+  SmallPtrSet<BasicBlock *, 16> ThrowingBlocks;
+  // Post-order numbers for each basic block. Used to figure out if memory
+  // accesses are executed before another access.
+  DenseMap<BasicBlock *, unsigned> PostOrderNumbers;
+
+  /// Keep track of instructions (partly) overlapping with killing MemoryDefs
+  /// per basic block.
+  DenseMap<BasicBlock *, InstOverlapIntervalsTy> IOLs;
+
+  DSEState(Function &F, AliasAnalysis &AA, MemorySSA &MSSA, DominatorTree &DT,
+           PostDominatorTree &PDT, const TargetLibraryInfo &TLI)
+      : F(F), AA(AA), MSSA(MSSA), DT(DT), PDT(PDT), TLI(TLI) {}
+
+  static DSEState get(Function &F, AliasAnalysis &AA, MemorySSA &MSSA,
+                      DominatorTree &DT, PostDominatorTree &PDT,
+                      const TargetLibraryInfo &TLI) {
+    DSEState State(F, AA, MSSA, DT, PDT, TLI);
+    // Collect blocks with throwing instructions not modeled in MemorySSA and
+    // alloc-like objects.
+    unsigned PO = 0;
+    for (BasicBlock *BB : post_order(&F)) {
+      State.PostOrderNumbers[BB] = PO++;
+      for (Instruction &I : *BB) {
+        MemoryAccess *MA = MSSA.getMemoryAccess(&I);
+        if (I.mayThrow() && !MA)
+          State.ThrowingBlocks.insert(I.getParent());
+
+        auto *MD = dyn_cast_or_null<MemoryDef>(MA);
+        if (MD && State.MemDefs.size() < MemorySSADefsPerBlockLimit &&
+            (State.getLocForWriteEx(&I) || State.isMemTerminatorInst(&I)))
+          State.MemDefs.push_back(MD);
+
+        // Track whether alloca and alloca-like objects are visible in the
+        // caller before and after the function returns. Alloca objects are
+        // invalid in the caller, so they are neither visible before nor after
+        // the function returns.
+        if (isa<AllocaInst>(&I)) {
+          State.InvisibleToCallerBeforeRet.insert(&I);
+          State.InvisibleToCallerAfterRet.insert(&I);
+        }
+
+        // For alloca-like objects we need to check if they are captured before
+        // the function returns and if the return might capture the object.
+        if (isAllocLikeFn(&I, &TLI)) {
+          bool CapturesBeforeRet = PointerMayBeCaptured(&I, false, true);
+          if (!CapturesBeforeRet) {
+            State.InvisibleToCallerBeforeRet.insert(&I);
+            if (!PointerMayBeCaptured(&I, true, false))
+              State.InvisibleToCallerAfterRet.insert(&I);
+          }
+        }
+      }
+    }
+
+    // Treat byval or inalloca arguments the same as Allocas; stores to them
+    // are dead at the end of the function.
+    for (Argument &AI : F.args())
+      if (AI.hasPassPointeeByValueAttr()) {
+        // For byval, the caller doesn't know the address of the allocation.
+        if (AI.hasByValAttr())
+          State.InvisibleToCallerBeforeRet.insert(&AI);
+        State.InvisibleToCallerAfterRet.insert(&AI);
+      }
+
+    return State;
+  }
+
+  Optional<MemoryLocation> getLocForWriteEx(Instruction *I) const {
+    if (!I->mayWriteToMemory())
+      return None;
+
+    if (auto *MTI = dyn_cast<AnyMemIntrinsic>(I))
+      return {MemoryLocation::getForDest(MTI)};
+
+    if (auto *CB = dyn_cast<CallBase>(I)) {
+      LibFunc LF;
+      if (TLI.getLibFunc(*CB, LF) && TLI.has(LF)) {
+        switch (LF) {
+        case LibFunc_strcpy:
+        case LibFunc_strncpy:
+        case LibFunc_strcat:
+        case LibFunc_strncat:
+          return {MemoryLocation(CB->getArgOperand(0))};
+        default:
+          break;
+        }
+      }
+      return None;
+    }
+
+    return MemoryLocation::getOrNone(I);
+  }
+
+  /// Returns true if \p UseInst completely overwrites \p DefLoc.
+  bool isCompleteOverwrite(MemoryLocation DefLoc, Instruction *UseInst) const {
+    // UseInst has a MemoryDef associated in MemorySSA. It's possible for a
+    // MemoryDef to not write to memory, e.g. a volatile load is modeled as a
+    // MemoryDef.
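// (Hypothetical illustration, not part of the patch.) A volatile load gets
// a MemoryDef but writes nothing:
//   %v = load volatile i32, i32* %p   ; 1 = MemoryDef(liveOnEntry)
// Hence the mayWriteToMemory() check below.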
+    if (!UseInst->mayWriteToMemory())
+      return false;
+
+    if (auto *CB = dyn_cast<CallBase>(UseInst))
+      if (CB->onlyAccessesInaccessibleMemory())
+        return false;
+
+    int64_t InstWriteOffset, DepWriteOffset;
+    auto CC = getLocForWriteEx(UseInst);
+    InstOverlapIntervalsTy IOL;
+
+    const DataLayout &DL = F.getParent()->getDataLayout();
+
+    return CC &&
+           isOverwrite(*CC, DefLoc, DL, TLI, DepWriteOffset, InstWriteOffset,
+                       UseInst, IOL, AA, &F) == OW_Complete;
+  }
+
+  /// Returns true if \p Def is not read before returning from the function.
+  bool isWriteAtEndOfFunction(MemoryDef *Def) {
+    LLVM_DEBUG(dbgs() << "  Check if def " << *Def << " ("
+                      << *Def->getMemoryInst()
+                      << ") is at the end of the function\n");
+
+    auto MaybeLoc = getLocForWriteEx(Def->getMemoryInst());
+    if (!MaybeLoc) {
+      LLVM_DEBUG(dbgs() << "  ... could not get location for write.\n");
+      return false;
+    }
+
+    SmallVector<MemoryAccess *, 4> WorkList;
+    SmallPtrSet<MemoryAccess *, 8> Visited;
+    auto PushMemUses = [&WorkList, &Visited](MemoryAccess *Acc) {
+      if (!Visited.insert(Acc).second)
+        return;
+      for (Use &U : Acc->uses())
+        WorkList.push_back(cast<MemoryAccess>(U.getUser()));
+    };
+    PushMemUses(Def);
+    for (unsigned I = 0; I < WorkList.size(); I++) {
+      if (WorkList.size() >= MemorySSAScanLimit) {
+        LLVM_DEBUG(dbgs() << "  ... hit exploration limit.\n");
+        return false;
+      }
+
+      MemoryAccess *UseAccess = WorkList[I];
+      if (isa<MemoryPhi>(UseAccess)) {
+        PushMemUses(UseAccess);
+        continue;
+      }
+
+      // TODO: Checking for aliasing is expensive. Consider reducing the amount
+      // of times this is called and/or caching it.
+      Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst();
+      if (isReadClobber(*MaybeLoc, UseInst)) {
+        LLVM_DEBUG(dbgs() << "  ... hit read clobber " << *UseInst << ".\n");
+        return false;
+      }
+
+      if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess))
+        PushMemUses(UseDef);
+    }
+    return true;
+  }
+
+  /// If \p I is a memory terminator like llvm.lifetime.end or free, return a
+  /// pair with the MemoryLocation terminated by \p I and a boolean flag
+  /// indicating whether \p I is a free-like call.
+  Optional<std::pair<MemoryLocation, bool>>
+  getLocForTerminator(Instruction *I) const {
+    uint64_t Len;
+    Value *Ptr;
+    if (match(I, m_Intrinsic<Intrinsic::lifetime_end>(m_ConstantInt(Len),
+                                                      m_Value(Ptr))))
+      return {std::make_pair(MemoryLocation(Ptr, Len), false)};
+
+    if (auto *CB = dyn_cast<CallBase>(I)) {
+      if (isFreeCall(I, &TLI))
+        return {std::make_pair(MemoryLocation(CB->getArgOperand(0)), true)};
+    }
+
+    return None;
+  }
+
+  /// Returns true if \p I is a memory terminator instruction like
+  /// llvm.lifetime.end or free.
+  bool isMemTerminatorInst(Instruction *I) const {
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+    return (II && II->getIntrinsicID() == Intrinsic::lifetime_end) ||
+           isFreeCall(I, &TLI);
+  }
+
+  /// Returns true if \p MaybeTerm is a memory terminator for the same
+  /// underlying object as \p DefLoc.
+  bool isMemTerminator(MemoryLocation DefLoc, Instruction *MaybeTerm) const {
+    Optional<std::pair<MemoryLocation, bool>> MaybeTermLoc =
+        getLocForTerminator(MaybeTerm);
+
+    if (!MaybeTermLoc)
+      return false;
+
+    // If the terminator is a free-like call, all accesses to the underlying
+    // object can be considered terminated.
+    if (MaybeTermLoc->second) {
+      DataLayout DL = MaybeTerm->getParent()->getModule()->getDataLayout();
+      DefLoc = MemoryLocation(GetUnderlyingObject(DefLoc.Ptr, DL));
+    }
+    return AA.isMustAlias(MaybeTermLoc->first, DefLoc);
+  }
+
+  // Returns true if \p UseInst may read from \p DefLoc.
+  bool isReadClobber(MemoryLocation DefLoc, Instruction *UseInst) const {
+    if (!UseInst->mayReadFromMemory())
+      return false;
+
+    if (auto *CB = dyn_cast<CallBase>(UseInst))
+      if (CB->onlyAccessesInaccessibleMemory())
+        return false;
+
+    ModRefInfo MR = AA.getModRefInfo(UseInst, DefLoc);
+    // If necessary, perform additional analysis.
+    if (isRefSet(MR))
+      MR = AA.callCapturesBefore(UseInst, DefLoc, &DT);
+    return isRefSet(MR);
+  }
+
+  // Find a MemoryDef writing to \p DefLoc and dominating \p Current, with no
+  // read access between them or on any other path to a function exit block if
+  // \p DefLoc is not accessible after the function returns. If there is no
+  // such MemoryDef, return None. The returned value may not (completely)
+  // overwrite \p DefLoc. Currently we bail out when we encounter an aliasing
+  // MemoryUse (read).
+  Optional<MemoryAccess *>
+  getDomMemoryDef(MemoryDef *KillingDef, MemoryAccess *Current,
+                  MemoryLocation DefLoc, bool DefVisibleToCallerBeforeRet,
+                  bool DefVisibleToCallerAfterRet, int &ScanLimit) const {
+    MemoryAccess *DomAccess;
+    bool StepAgain;
+    LLVM_DEBUG(dbgs() << "  trying to get dominating access for " << *Current
+                      << "\n");
+    // Find the next clobbering Mod access for DefLoc, starting at Current.
+    do {
+      StepAgain = false;
+      // Reached TOP.
+      if (MSSA.isLiveOnEntryDef(Current))
+        return None;
+
+      if (isa<MemoryPhi>(Current)) {
+        DomAccess = Current;
+        break;
+      }
+      MemoryUseOrDef *CurrentUD = cast<MemoryUseOrDef>(Current);
+      // Look for accesses that clobber DefLoc.
+      DomAccess = MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(CurrentUD,
+                                                                      DefLoc);
+      if (MSSA.isLiveOnEntryDef(DomAccess))
+        return None;
+
+      if (isa<MemoryPhi>(DomAccess))
+        break;
+
+      // Check if we can skip DomDef for DSE.
+      MemoryDef *DomDef = dyn_cast<MemoryDef>(DomAccess);
+      if (DomDef && canSkipDef(DomDef, DefVisibleToCallerBeforeRet)) {
+        StepAgain = true;
+        Current = DomDef->getDefiningAccess();
+      }
+
+    } while (StepAgain);
+
+    // Accesses to objects accessible after the function returns can only be
+    // eliminated if the access is killed along all paths to the exit. Collect
+    // the blocks with killing (=completely overwriting) MemoryDefs and check
+    // if they cover all paths from DomAccess to any function exit.
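// (Hypothetical sketch, not part of the patch.) With a caller-visible @g:
//
//   entry:
//     store i32 0, i32* @g   ; DomAccess
//     br i1 %c, label %a, label %b
//   a:
//     store i32 1, i32* @g   ; killing block: completely overwrites @g
//     ret void
//   b:
//     ret void               ; @g still holds 0 on this exit
//
// The killing blocks {%a} do not cover the path through %b, so the store in
// %entry must be kept.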
+    SmallPtrSet<BasicBlock *, 16> KillingBlocks = {KillingDef->getBlock()};
+    LLVM_DEBUG({
+      dbgs() << "  Checking for reads of " << *DomAccess;
+      if (isa<MemoryDef>(DomAccess))
+        dbgs() << " (" << *cast<MemoryDef>(DomAccess)->getMemoryInst() << ")\n";
+      else
+        dbgs() << '\n';
+    });
+
+    SmallSetVector<MemoryAccess *, 32> WorkList;
+    auto PushMemUses = [&WorkList](MemoryAccess *Acc) {
+      for (Use &U : Acc->uses())
+        WorkList.insert(cast<MemoryAccess>(U.getUser()));
+    };
+    PushMemUses(DomAccess);
+
+    // Check if DomDef may be read.
+    for (unsigned I = 0; I < WorkList.size(); I++) {
+      MemoryAccess *UseAccess = WorkList[I];
+
+      LLVM_DEBUG(dbgs() << "   " << *UseAccess);
+      if (--ScanLimit == 0) {
+        LLVM_DEBUG(dbgs() << "\n    ...  hit scan limit\n");
+        return None;
+      }
+
+      if (isa<MemoryPhi>(UseAccess)) {
+        LLVM_DEBUG(dbgs() << "\n    ... adding PHI uses\n");
+        PushMemUses(UseAccess);
+        continue;
+      }
+
+      Instruction *UseInst = cast<MemoryUseOrDef>(UseAccess)->getMemoryInst();
+      LLVM_DEBUG(dbgs() << " (" << *UseInst << ")\n");
+
+      if (isNoopIntrinsic(cast<MemoryUseOrDef>(UseAccess))) {
+        LLVM_DEBUG(dbgs() << "    ... adding uses of intrinsic\n");
+        PushMemUses(UseAccess);
+        continue;
+      }
+
+      // A memory terminator kills all preceding MemoryDefs and all succeeding
+      // MemoryAccesses. We do not have to check its users.
+      if (isMemTerminator(DefLoc, UseInst))
+        continue;
+
+      // Uses which may read the original MemoryDef mean we cannot eliminate
+      // the original MD. Stop the walk.
+      if (isReadClobber(DefLoc, UseInst)) {
+        LLVM_DEBUG(dbgs() << "    ... found read clobber\n");
+        return None;
+      }
+
+      // For the KillingDef and DomAccess we only have to check if it reads the
+      // memory location.
+      // TODO: It would probably be better to check for self-reads before
+      // calling the function.
+      if (KillingDef == UseAccess || DomAccess == UseAccess) {
+        LLVM_DEBUG(dbgs() << "    ... skipping killing def/dom access\n");
+        continue;
+      }
+
+      // Check all uses for MemoryDefs, except for defs completely overwriting
+      // the original location. Otherwise we have to check uses of *all*
+      // MemoryDefs we discover, including non-aliasing ones. Otherwise we might
+      // miss cases like the following
+      //   1 = Def(LoE) ; <----- DomDef stores [0,1]
+      //   2 = Def(1)   ; (2, 1) = NoAlias,   stores [2,3]
+      //   Use(2)       ; MayAlias 2 *and* 1, loads [0, 3].
+      //                  (The Use points to the *first* Def it may alias)
+      //   3 = Def(1)   ; <---- Current  (3, 2) = NoAlias, (3,1) = MayAlias,
+      //                  stores [0,1]
+      if (MemoryDef *UseDef = dyn_cast<MemoryDef>(UseAccess)) {
+        if (isCompleteOverwrite(DefLoc, UseInst)) {
+          if (DefVisibleToCallerAfterRet && UseAccess != DomAccess) {
+            BasicBlock *MaybeKillingBlock = UseInst->getParent();
+            if (PostOrderNumbers.find(MaybeKillingBlock)->second <
+                PostOrderNumbers.find(DomAccess->getBlock())->second) {
+
+              LLVM_DEBUG(dbgs() << "    ...
found killing block " +                                << MaybeKillingBlock->getName() << "\n"); +              KillingBlocks.insert(MaybeKillingBlock); +            } +          } +        } else +          PushMemUses(UseDef); +      } +    } + +    // For accesses to locations visible after the function returns, make sure +    // that the location is killed (=overwritten) along all paths from DomAccess +    // to the exit. +    if (DefVisibleToCallerAfterRet) { +      assert(!KillingBlocks.empty() && +             "Expected at least a single killing block"); +      // Find the common post-dominator of all killing blocks. +      BasicBlock *CommonPred = *KillingBlocks.begin(); +      for (auto I = std::next(KillingBlocks.begin()), E = KillingBlocks.end(); +           I != E; I++) { +        if (!CommonPred) +          break; +        CommonPred = PDT.findNearestCommonDominator(CommonPred, *I); +      } + +      // If CommonPred is in the set of killing blocks, just check if it +      // post-dominates DomAccess. +      if (KillingBlocks.count(CommonPred)) { +        if (PDT.dominates(CommonPred, DomAccess->getBlock())) +          return {DomAccess}; +        return None; +      } + +      // If the common post-dominator does not post-dominate DomAccess, there +      // is a path from DomAccess to an exit not going through a killing block. +      if (PDT.dominates(CommonPred, DomAccess->getBlock())) { +        SetVector<BasicBlock *> WorkList; + +        // DomAccess's post-order number provides an upper bound of the blocks +        // on a path starting at DomAccess. +        unsigned UpperBound = +            PostOrderNumbers.find(DomAccess->getBlock())->second; + +        // If CommonPred is null, there are multiple exits from the function. +        // They all have to be added to the worklist. +        if (CommonPred) +          WorkList.insert(CommonPred); +        else +          for (BasicBlock *R : PDT.roots()) +            WorkList.insert(R); + +        NumCFGTries++; +        // Check if all paths starting from an exit node go through one of the +        // killing blocks before reaching DomAccess. +        for (unsigned I = 0; I < WorkList.size(); I++) { +          NumCFGChecks++; +          BasicBlock *Current = WorkList[I]; +          if (KillingBlocks.count(Current)) +            continue; +          if (Current == DomAccess->getBlock()) +            return None; + +          // DomAccess is reachable from the entry, so we don't have to explore +          // unreachable blocks further. +          if (!DT.isReachableFromEntry(Current)) +            continue; + +          unsigned CPO = PostOrderNumbers.find(Current)->second; +          // Current block is not on a path starting at DomAccess. +          if (CPO > UpperBound) +            continue; +          for (BasicBlock *Pred : predecessors(Current)) +            WorkList.insert(Pred); + +          if (WorkList.size() >= MemorySSAPathCheckLimit) +            return None; +        } +        NumCFGSuccess++; +        return {DomAccess}; +      } +      return None; +    } + +    // No aliasing MemoryUses of DomAccess found, DomAccess is potentially dead. 
+    return {DomAccess};
+  }
+
+  // Delete dead memory defs.
+  void deleteDeadInstruction(Instruction *SI) {
+    MemorySSAUpdater Updater(&MSSA);
+    SmallVector<Instruction *, 32> NowDeadInsts;
+    NowDeadInsts.push_back(SI);
+    --NumFastOther;
+
+    while (!NowDeadInsts.empty()) {
+      Instruction *DeadInst = NowDeadInsts.pop_back_val();
+      ++NumFastOther;
+
+      // Try to preserve debug information attached to the dead instruction.
+      salvageDebugInfo(*DeadInst);
+      salvageKnowledge(DeadInst);
+
+      // Remove the Instruction from MSSA.
+      if (MemoryAccess *MA = MSSA.getMemoryAccess(DeadInst)) {
+        if (MemoryDef *MD = dyn_cast<MemoryDef>(MA)) {
+          SkipStores.insert(MD);
+        }
+        Updater.removeMemoryAccess(MA);
+      }
+
+      auto I = IOLs.find(DeadInst->getParent());
+      if (I != IOLs.end())
+        I->second.erase(DeadInst);
+      // Remove its operands.
+      for (Use &O : DeadInst->operands())
+        if (Instruction *OpI = dyn_cast<Instruction>(O)) {
+          O = nullptr;
+          if (isInstructionTriviallyDead(OpI, &TLI))
+            NowDeadInsts.push_back(OpI);
+        }
+
+      DeadInst->eraseFromParent();
+    }
+  }
+
+  // Check for any extra throws between SI and NI that block DSE. This only
+  // checks extra maythrows (those that aren't MemoryDefs). MemoryDefs that may
+  // throw are handled during the walk from one def to the next.
+  bool mayThrowBetween(Instruction *SI, Instruction *NI,
+                       const Value *SILocUnd) const {
+    // First see if we can ignore it by using the fact that SI is an
+    // alloca/alloca-like object that is not visible to the caller during
+    // execution of the function.
+    if (SILocUnd && InvisibleToCallerBeforeRet.count(SILocUnd))
+      return false;
+
+    if (SI->getParent() == NI->getParent())
+      return ThrowingBlocks.count(SI->getParent());
+    return !ThrowingBlocks.empty();
+  }
+
+  // Check if \p NI acts as a DSE barrier for \p SI. The following instructions
+  // act as barriers:
+  //  * A memory instruction that may throw, where \p SI accesses a non-stack
+  //    object.
+  //  * Atomic stores stronger than monotonic.
+  bool isDSEBarrier(const Value *SILocUnd, Instruction *NI) const {
+    // If NI may throw it acts as a barrier, unless SI accesses an
+    // alloca/alloca-like object that does not escape.
+    if (NI->mayThrow() && !InvisibleToCallerBeforeRet.count(SILocUnd))
+      return true;
+
+    // If NI is an atomic load/store stronger than monotonic, do not try to
+    // eliminate/reorder it.
+    if (NI->isAtomic()) {
+      if (auto *LI = dyn_cast<LoadInst>(NI))
+        return isStrongerThanMonotonic(LI->getOrdering());
+      if (auto *SI = dyn_cast<StoreInst>(NI))
+        return isStrongerThanMonotonic(SI->getOrdering());
+      llvm_unreachable("other instructions should be skipped in MemorySSA");
+    }
+    return false;
+  }
+
+  /// Eliminate writes to objects that are not visible in the caller and are
+  /// not accessed before returning from the function.
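/// (Hypothetical sketch, not part of the patch.) For example:
///   define void @f() {
///     %a = alloca i32
///     store i32 42, i32* %a   ; never read before the return
///     ret void
///   }
/// The store can be removed: %a is invisible to the caller both before and
/// after the function returns.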
+  bool eliminateDeadWritesAtEndOfFunction() { +    const DataLayout &DL = F.getParent()->getDataLayout(); +    bool MadeChange = false; +    LLVM_DEBUG( +        dbgs() +        << "Trying to eliminate MemoryDefs at the end of the function\n"); +    for (int I = MemDefs.size() - 1; I >= 0; I--) { +      MemoryDef *Def = MemDefs[I]; +      if (SkipStores.find(Def) != SkipStores.end() || +          !isRemovable(Def->getMemoryInst())) +        continue; + +      // TODO: Consider doing the underlying object check first, if it is +      // beneficial compile-time wise. +      if (isWriteAtEndOfFunction(Def)) { +        Instruction *DefI = Def->getMemoryInst(); +        // See through pointer-to-pointer bitcasts +        SmallVector<const Value *, 4> Pointers; +        GetUnderlyingObjects(getLocForWriteEx(DefI)->Ptr, Pointers, DL); + +        LLVM_DEBUG(dbgs() << "   ... MemoryDef is not accessed until the end " +                             "of the function\n"); +        bool CanKill = true; +        for (const Value *Pointer : Pointers) { +          if (!InvisibleToCallerAfterRet.count(Pointer)) { +            CanKill = false; +            break; +          } +        } + +        if (CanKill) { +          deleteDeadInstruction(DefI); +          ++NumFastStores; +          MadeChange = true; +        } +      } +    } +    return MadeChange; +  } + +  /// \returns true if \p Def is a no-op store, either because it +  /// directly stores back a loaded value or stores zero to a calloced object. +  bool storeIsNoop(MemoryDef *Def, MemoryLocation DefLoc, const Value *DefUO) { +    StoreInst *Store = dyn_cast<StoreInst>(Def->getMemoryInst()); +    if (!Store) +      return false; + +    if (auto *LoadI = dyn_cast<LoadInst>(Store->getOperand(0))) { +      if (LoadI->getPointerOperand() == Store->getOperand(1)) { +        auto *LoadAccess = MSSA.getMemoryAccess(LoadI)->getDefiningAccess(); +        // If both accesses share the same defining access, no instructions +        // between them can modify the memory location. +        return LoadAccess == Def->getDefiningAccess(); +      } +    } + +    Constant *StoredConstant = dyn_cast<Constant>(Store->getOperand(0)); +    if (StoredConstant && StoredConstant->isNullValue()) { +      auto *DefUOInst = dyn_cast<Instruction>(DefUO); +      if (DefUOInst && isCallocLikeFn(DefUOInst, &TLI)) { +        auto *UnderlyingDef = cast<MemoryDef>(MSSA.getMemoryAccess(DefUOInst)); +        // If UnderlyingDef is the clobbering access of Def, no instructions +        // between them can modify the memory location. 
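        // (Hypothetical illustration, not part of the patch.)
        //   %m = call i8* @calloc(i64 1, i64 8)   ; UnderlyingDef
        //   store i8 0, i8* %m                    ; no-op: the memory is
        //                                         ; already zero and calloc is
        //                                         ; the store's clobber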
+        auto *ClobberDef = +            MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(Def); +        return UnderlyingDef == ClobberDef; +      } +    } +    return false; +  } +}; + +bool eliminateDeadStoresMemorySSA(Function &F, AliasAnalysis &AA, +                                  MemorySSA &MSSA, DominatorTree &DT, +                                  PostDominatorTree &PDT, +                                  const TargetLibraryInfo &TLI) { +  const DataLayout &DL = F.getParent()->getDataLayout(); +  bool MadeChange = false; + +  DSEState State = DSEState::get(F, AA, MSSA, DT, PDT, TLI); +  // For each store: +  for (unsigned I = 0; I < State.MemDefs.size(); I++) { +    MemoryDef *KillingDef = State.MemDefs[I]; +    if (State.SkipStores.count(KillingDef)) +      continue; +    Instruction *SI = KillingDef->getMemoryInst(); + +    auto MaybeSILoc = State.getLocForWriteEx(SI); +    if (State.isMemTerminatorInst(SI)) +      MaybeSILoc = State.getLocForTerminator(SI).map( +          [](const std::pair<MemoryLocation, bool> &P) { return P.first; }); +    else +      MaybeSILoc = State.getLocForWriteEx(SI); + +    if (!MaybeSILoc) { +      LLVM_DEBUG(dbgs() << "Failed to find analyzable write location for " +                        << *SI << "\n"); +      continue; +    } +    MemoryLocation SILoc = *MaybeSILoc; +    assert(SILoc.Ptr && "SILoc should not be null"); +    const Value *SILocUnd = GetUnderlyingObject(SILoc.Ptr, DL); + +    // Check if the store is a no-op. +    if (isRemovable(SI) && State.storeIsNoop(KillingDef, SILoc, SILocUnd)) { +      LLVM_DEBUG(dbgs() << "DSE: Remove No-Op Store:\n  DEAD: " << *SI << '\n'); +      State.deleteDeadInstruction(SI); +      NumNoopStores++; +      MadeChange = true; +      continue; +    } + +    Instruction *DefObj = +        const_cast<Instruction *>(dyn_cast<Instruction>(SILocUnd)); +    bool DefVisibleToCallerBeforeRet = +        !State.InvisibleToCallerBeforeRet.count(SILocUnd); +    bool DefVisibleToCallerAfterRet = +        !State.InvisibleToCallerAfterRet.count(SILocUnd); +    if (DefObj && isAllocLikeFn(DefObj, &TLI)) { +      if (DefVisibleToCallerBeforeRet) +        DefVisibleToCallerBeforeRet = +            PointerMayBeCapturedBefore(DefObj, false, true, SI, &DT); +    } + +    MemoryAccess *Current = KillingDef; +    LLVM_DEBUG(dbgs() << "Trying to eliminate MemoryDefs killed by " +                      << *KillingDef << " (" << *SI << ")\n"); + +    int ScanLimit = MemorySSAScanLimit; +    // Worklist of MemoryAccesses that may be killed by KillingDef. +    SetVector<MemoryAccess *> ToCheck; +    ToCheck.insert(KillingDef->getDefiningAccess()); + +    // Check if MemoryAccesses in the worklist are killed by KillingDef. +    for (unsigned I = 0; I < ToCheck.size(); I++) { +      Current = ToCheck[I]; +      if (State.SkipStores.count(Current)) +        continue; + +      Optional<MemoryAccess *> Next = State.getDomMemoryDef( +          KillingDef, Current, SILoc, DefVisibleToCallerBeforeRet, +          DefVisibleToCallerAfterRet, ScanLimit); + +      if (!Next) { +        LLVM_DEBUG(dbgs() << "  finished walk\n"); +        continue; +      } + +      MemoryAccess *DomAccess = *Next; +      LLVM_DEBUG(dbgs() << " Checking if we can kill " << *DomAccess); +      if (isa<MemoryPhi>(DomAccess)) { +        LLVM_DEBUG(dbgs() << "\n  ... 
adding incoming values to worklist\n"); +        for (Value *V : cast<MemoryPhi>(DomAccess)->incoming_values()) { +          MemoryAccess *IncomingAccess = cast<MemoryAccess>(V); +          BasicBlock *IncomingBlock = IncomingAccess->getBlock(); +          BasicBlock *PhiBlock = DomAccess->getBlock(); + +          // We only consider incoming MemoryAccesses that come before the +          // MemoryPhi. Otherwise we could discover candidates that do not +          // strictly dominate our starting def. +          if (State.PostOrderNumbers[IncomingBlock] > +              State.PostOrderNumbers[PhiBlock]) +            ToCheck.insert(IncomingAccess); +        } +        continue; +      } +      MemoryDef *NextDef = dyn_cast<MemoryDef>(DomAccess); +      Instruction *NI = NextDef->getMemoryInst(); +      LLVM_DEBUG(dbgs() << " (" << *NI << ")\n"); + +      // Before we try to remove anything, check for any extra throwing +      // instructions that block us from DSEing +      if (State.mayThrowBetween(SI, NI, SILocUnd)) { +        LLVM_DEBUG(dbgs() << "  ... skip, may throw!\n"); +        break; +      } + +      // Check for anything that looks like it will be a barrier to further +      // removal +      if (State.isDSEBarrier(SILocUnd, NI)) { +        LLVM_DEBUG(dbgs() << "  ... skip, barrier\n"); +        continue; +      } + +      ToCheck.insert(NextDef->getDefiningAccess()); + +      if (!hasAnalyzableMemoryWrite(NI, TLI)) { +        LLVM_DEBUG(dbgs() << "  ... skip, cannot analyze def\n"); +        continue; +      } + +      if (!isRemovable(NI)) { +        LLVM_DEBUG(dbgs() << "  ... skip, cannot remove def\n"); +        continue; +      } + +      if (!DebugCounter::shouldExecute(MemorySSACounter)) +        continue; + +      MemoryLocation NILoc = *State.getLocForWriteEx(NI); + +      if (State.isMemTerminatorInst(SI)) { +        const Value *NIUnd = GetUnderlyingObject(NILoc.Ptr, DL); +        if (!SILocUnd || SILocUnd != NIUnd) +          continue; +        LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *NI +                          << "\n  KILLER: " << *SI << '\n'); +        State.deleteDeadInstruction(NI); +        ++NumFastStores; +        MadeChange = true; +      } else { +        // Check if NI overwrites SI. +        int64_t InstWriteOffset, DepWriteOffset; +        auto Iter = State.IOLs.insert( +            std::make_pair<BasicBlock *, InstOverlapIntervalsTy>( +                NI->getParent(), InstOverlapIntervalsTy())); +        auto &IOL = Iter.first->second; +        OverwriteResult OR = isOverwrite(SILoc, NILoc, DL, TLI, DepWriteOffset, +                                         InstWriteOffset, NI, IOL, AA, &F); + +        if (EnablePartialStoreMerging && OR == OW_PartialEarlierWithFullLater) { +          auto *Earlier = dyn_cast<StoreInst>(NI); +          auto *Later = dyn_cast<StoreInst>(SI); +          if (Constant *Merged = tryToMergePartialOverlappingStores( +                  Earlier, Later, InstWriteOffset, DepWriteOffset, DL, &AA, +                  &DT)) { + +            // Update stored value of earlier store to merged constant. +            Earlier->setOperand(0, Merged); +            ++NumModifiedStores; +            MadeChange = true; + +            // Remove later store and remove any outstanding overlap intervals +            // for the updated store. 
+            State.deleteDeadInstruction(Later); +            auto I = State.IOLs.find(Earlier->getParent()); +            if (I != State.IOLs.end()) +              I->second.erase(Earlier); +            break; +          } +        } + +        if (OR == OW_Complete) { +          LLVM_DEBUG(dbgs() << "DSE: Remove Dead Store:\n  DEAD: " << *NI +                            << "\n  KILLER: " << *SI << '\n'); +          State.deleteDeadInstruction(NI); +          ++NumFastStores; +          MadeChange = true; +        } +      } +    } +  } + +  if (EnablePartialOverwriteTracking) +    for (auto &KV : State.IOLs) +      MadeChange |= removePartiallyOverlappedStores(&AA, DL, KV.second); + +  MadeChange |= State.eliminateDeadWritesAtEndOfFunction(); +  return MadeChange; +} +} // end anonymous namespace +  //===----------------------------------------------------------------------===//  // DSE Pass  //===----------------------------------------------------------------------===//  PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { -  AliasAnalysis *AA = &AM.getResult<AAManager>(F); -  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); -  MemoryDependenceResults *MD = &AM.getResult<MemoryDependenceAnalysis>(F); -  const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); +  AliasAnalysis &AA = AM.getResult<AAManager>(F); +  const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F); +  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + +  bool Changed = false; +  if (EnableMemorySSA) { +    MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); +    PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); -  if (!eliminateDeadStores(F, AA, MD, DT, TLI)) +    Changed = eliminateDeadStoresMemorySSA(F, AA, MSSA, DT, PDT, TLI); +  } else { +    MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F); + +    Changed = eliminateDeadStores(F, &AA, &MD, &DT, &TLI); +  } + +#ifdef LLVM_ENABLE_STATS +  if (AreStatisticsEnabled()) +    for (auto &I : instructions(F)) +      NumRemainingStores += isa<StoreInst>(&I); +#endif + +  if (!Changed)      return PreservedAnalyses::all();    PreservedAnalyses PA;    PA.preserveSet<CFGAnalyses>();    PA.preserve<GlobalsAA>(); -  PA.preserve<MemoryDependenceAnalysis>(); +  if (EnableMemorySSA) +    PA.preserve<MemorySSAAnalysis>(); +  else +    PA.preserve<MemoryDependenceAnalysis>();    return PA;  } @@ -1383,25 +2339,51 @@ public:      if (skipFunction(F))        return false; -    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); -    AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); -    MemoryDependenceResults *MD = -        &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); -    const TargetLibraryInfo *TLI = -        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); +    AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +    DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +    const TargetLibraryInfo &TLI = +        getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); + +    bool Changed = false; +    if (EnableMemorySSA) { +      MemorySSA &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); +      PostDominatorTree &PDT = +          getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + +      Changed = eliminateDeadStoresMemorySSA(F, AA, MSSA, DT, PDT, TLI); +    } else { +      MemoryDependenceResults &MD = +          
getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + +      Changed = eliminateDeadStores(F, &AA, &MD, &DT, &TLI); +    } -    return eliminateDeadStores(F, AA, MD, DT, TLI); +#ifdef LLVM_ENABLE_STATS +    if (AreStatisticsEnabled()) +      for (auto &I : instructions(F)) +        NumRemainingStores += isa<StoreInst>(&I); +#endif + +    return Changed;    }    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.setPreservesCFG(); -    AU.addRequired<DominatorTreeWrapperPass>();      AU.addRequired<AAResultsWrapperPass>(); -    AU.addRequired<MemoryDependenceWrapperPass>();      AU.addRequired<TargetLibraryInfoWrapperPass>(); -    AU.addPreserved<DominatorTreeWrapperPass>();      AU.addPreserved<GlobalsAAWrapperPass>(); -    AU.addPreserved<MemoryDependenceWrapperPass>(); +    AU.addRequired<DominatorTreeWrapperPass>(); +    AU.addPreserved<DominatorTreeWrapperPass>(); + +    if (EnableMemorySSA) { +      AU.addRequired<PostDominatorTreeWrapperPass>(); +      AU.addRequired<MemorySSAWrapperPass>(); +      AU.addPreserved<PostDominatorTreeWrapperPass>(); +      AU.addPreserved<MemorySSAWrapperPass>(); +    } else { +      AU.addRequired<MemoryDependenceWrapperPass>(); +      AU.addPreserved<MemoryDependenceWrapperPass>(); +    }    }  }; @@ -1412,8 +2394,10 @@ char DSELegacyPass::ID = 0;  INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false,                        false)  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)  INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)  INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)  INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass)  INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)  INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false, diff --git a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp index 132dfc8f6da1..d44a5979a8b2 100644 --- a/llvm/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/llvm/lib/Transforms/Scalar/DivRemPairs.cpp @@ -17,6 +17,7 @@  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/PatternMatch.h" @@ -71,6 +72,7 @@ static llvm::Optional<ExpandedMatch> matchExpandedRem(Instruction &I) {    return M;  } +namespace {  /// A thin wrapper to store two values that we matched as div-rem pair.  /// We want this extra indirection to avoid dealing with RAUW'ing the map keys.  struct DivRemPairWorklistEntry { @@ -111,6 +113,7 @@ struct DivRemPairWorklistEntry {      }    }  }; +} // namespace  using DivRemWorklistTy = SmallVector<DivRemPairWorklistEntry, 4>;  /// Find matching pairs of integer div/rem ops (they have the same numerator, @@ -218,6 +221,7 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,        NumRecomposed++;        // Note that we have left ((X / Y) * Y) around.        // If it had other uses we could rewrite it as X - X % Y +      Changed = true;      }      assert((!E.isRemExpanded() || !HasDivRemOp) && @@ -301,6 +305,29 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI,        Mul->insertAfter(RemInst);        Sub->insertAfter(Mul); +      // If X can be undef, X should be frozen first. 
+      // For example, let's assume that Y = 1 and X = undef:
+      //   %div = sdiv undef, 1 // %div = undef
+      //   %rem = srem undef, 1 // %rem = 0
+      // =>
+      //   %div = sdiv undef, 1 // %div = undef
+      //   %mul = mul %div, 1   // %mul = undef
+      //   %rem = sub %x, %mul  // %rem = undef - undef = undef
+      // If X is not frozen, %rem becomes undef after the transformation.
+      // TODO: We need an undef-specific checking function in ValueTracking
+      if (!isGuaranteedNotToBeUndefOrPoison(X, DivInst, &DT)) {
+        auto *FrX = new FreezeInst(X, X->getName() + ".frozen", DivInst);
+        DivInst->setOperand(0, FrX);
+        Sub->setOperand(0, FrX);
+      }
+      // Same for Y. If X = 1 and Y = (undef | 1), %rem in src is either 1 or 0,
+      // but %rem in tgt can be one of many integer values.
+      if (!isGuaranteedNotToBeUndefOrPoison(Y, DivInst, &DT)) {
+        auto *FrY = new FreezeInst(Y, Y->getName() + ".frozen", DivInst);
+        DivInst->setOperand(1, FrY);
+        Mul->setOperand(1, FrY);
+      }
+
       // Now kill the explicit remainder. We have replaced it with:
       // (sub X, (mul (div X, Y), Y)
       Sub->setName(RemInst->getName() + ".decomposed");
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
index 40c1ba88354f..ddfc8555b0a0 100644
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -41,6 +41,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Statepoint.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/Value.h"
@@ -54,6 +55,7 @@
 #include "llvm/Support/RecyclingAllocator.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
 #include "llvm/Transforms/Utils/GuardUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <cassert>
@@ -114,7 +116,7 @@ struct SimpleValue {
            isa<CmpInst>(Inst) || isa<SelectInst>(Inst) ||
            isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
            isa<ShuffleVectorInst>(Inst) || isa<ExtractValueInst>(Inst) ||
-           isa<InsertValueInst>(Inst);
+           isa<InsertValueInst>(Inst) || isa<FreezeInst>(Inst);
   }
 };
@@ -152,13 +154,50 @@ static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
     std::swap(A, B);
   }
-  // Set flavor if we find a match, or set it to unknown otherwise; in
-  // either case, return true to indicate that this is a select we can
-  // process.
-  if (auto *CmpI = dyn_cast<ICmpInst>(Cond))
-    Flavor = matchDecomposedSelectPattern(CmpI, A, B, A, B).Flavor;
-  else
-    Flavor = SPF_UNKNOWN;
+  // Match canonical forms of abs/nabs/min/max. We are not using ValueTracking's
+  // more powerful matchSelectPattern() because it may rely on instruction flags
+  // such as "nsw". That would be incompatible with the current hashing
+  // mechanism that may remove flags to increase the likelihood of CSE.
+
+  // These are the canonical forms of abs(X) and nabs(X) created by instcombine:
+  // %N = sub i32 0, %X
+  // %C = icmp slt i32 %X, 0
+  // %ABS = select i1 %C, i32 %N, i32 %X
+  //
+  // %N = sub i32 0, %X
+  // %C = icmp slt i32 %X, 0
+  // %NABS = select i1 %C, i32 %X, i32 %N
+  Flavor = SPF_UNKNOWN;
+  CmpInst::Predicate Pred;
+  if (match(Cond, m_ICmp(Pred, m_Specific(B), m_ZeroInt())) &&
+      Pred == ICmpInst::ICMP_SLT && match(A, m_Neg(m_Specific(B)))) {
+    // ABS: B < 0 ? -B : B
+    Flavor = SPF_ABS;
+    return true;
+  }
+  if (match(Cond, m_ICmp(Pred, m_Specific(A), m_ZeroInt())) &&
+      Pred == ICmpInst::ICMP_SLT && match(B, m_Neg(m_Specific(A)))) {
+    // NABS: A < 0 ? A : -A
+    Flavor = SPF_NABS;
+    return true;
+  }
+
+  if (!match(Cond, m_ICmp(Pred, m_Specific(A), m_Specific(B)))) {
+    // Check for commuted variants of min/max by swapping predicate.
+    // If we do not match the standard or commuted patterns, this is not a
+    // recognized form of min/max, but it is still a select, so return true.
+    if (!match(Cond, m_ICmp(Pred, m_Specific(B), m_Specific(A))))
+      return true;
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  switch (Pred) {
+  case CmpInst::ICMP_UGT: Flavor = SPF_UMAX; break;
+  case CmpInst::ICMP_ULT: Flavor = SPF_UMIN; break;
+  case CmpInst::ICMP_SGT: Flavor = SPF_SMAX; break;
+  case CmpInst::ICMP_SLT: Flavor = SPF_SMIN; break;
+  default: break;
+  }

   return true;
 }
@@ -231,6 +270,9 @@ static unsigned getHashValueImpl(SimpleValue Val) {
   if (CastInst *CI = dyn_cast<CastInst>(Inst))
     return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0));
+  if (FreezeInst *FI = dyn_cast<FreezeInst>(Inst))
+    return hash_combine(FI->getOpcode(), FI->getOperand(0));
+
   if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(Inst))
     return hash_combine(EVI->getOpcode(), EVI->getOperand(0),
                         hash_combine_range(EVI->idx_begin(), EVI->idx_end()));
@@ -242,7 +284,8 @@ static unsigned getHashValueImpl(SimpleValue Val) {
   assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) ||
           isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
-          isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst)) &&
+          isa<ShuffleVectorInst>(Inst) || isa<UnaryOperator>(Inst) ||
+          isa<FreezeInst>(Inst)) &&
          "Invalid/unknown instruction");
   // Mix in the opcode.
@@ -414,6 +457,14 @@ template <> struct DenseMapInfo<CallValue> {
 unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) {
   Instruction *Inst = Val.Inst;
+
+  // gc.relocate is a 'special' call: its second and third operands are
+  // not real values, but indices into the statepoint's argument list.
+  // Get the values they point to.
+  if (const GCRelocateInst *GCR = dyn_cast<GCRelocateInst>(Inst))
+    return hash_combine(GCR->getOpcode(), GCR->getOperand(0),
+                        GCR->getBasePtr(), GCR->getDerivedPtr());
+
   // Hash all of the operands as pointers and mix in the opcode.
   return hash_combine(
       Inst->getOpcode(),
@@ -424,6 +475,14 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
   Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;
   if (LHS.isSentinel() || RHS.isSentinel())
     return LHSI == RHSI;
+
+  // See comment above in `getHashValue()`.
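+  // (Hypothetical IR, assuming the same value appears twice in the
+  // statepoint's gc argument list; the two relocates below hash and compare
+  // equal even though their index operands differ:
+  //   %r1 = call i8 addrspace(1)* @llvm.experimental.gc.relocate(token %tok, i32 7, i32 7)
+  //   %r2 = call i8 addrspace(1)* @llvm.experimental.gc.relocate(token %tok, i32 9, i32 9))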
+  if (const GCRelocateInst *GCR1 = dyn_cast<GCRelocateInst>(LHSI)) +    if (const GCRelocateInst *GCR2 = dyn_cast<GCRelocateInst>(RHSI)) +      return GCR1->getOperand(0) == GCR2->getOperand(0) && +             GCR1->getBasePtr() == GCR2->getBasePtr() && +             GCR1->getDerivedPtr() == GCR2->getDerivedPtr(); +    return LHSI->isIdenticalTo(RHSI);  } @@ -561,8 +620,8 @@ private:    public:      StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,                InvariantHTType &AvailableInvariants, CallHTType &AvailableCalls, -              unsigned cg, DomTreeNode *n, DomTreeNode::iterator child, -              DomTreeNode::iterator end) +              unsigned cg, DomTreeNode *n, DomTreeNode::const_iterator child, +              DomTreeNode::const_iterator end)          : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child),            EndIter(end),            Scopes(AvailableValues, AvailableLoads, AvailableInvariants, @@ -576,7 +635,7 @@ private:      unsigned childGeneration() { return ChildGeneration; }      void childGeneration(unsigned generation) { ChildGeneration = generation; }      DomTreeNode *node() { return Node; } -    DomTreeNode::iterator childIter() { return ChildIter; } +    DomTreeNode::const_iterator childIter() { return ChildIter; }      DomTreeNode *nextChild() {        DomTreeNode *child = *ChildIter; @@ -584,7 +643,7 @@ private:        return child;      } -    DomTreeNode::iterator end() { return EndIter; } +    DomTreeNode::const_iterator end() { return EndIter; }      bool isProcessed() { return Processed; }      void process() { Processed = true; } @@ -592,8 +651,8 @@ private:      unsigned CurrentGeneration;      unsigned ChildGeneration;      DomTreeNode *Node; -    DomTreeNode::iterator ChildIter; -    DomTreeNode::iterator EndIter; +    DomTreeNode::const_iterator ChildIter; +    DomTreeNode::const_iterator EndIter;      NodeScope Scopes;      bool Processed = false;    }; @@ -716,7 +775,7 @@ private:    bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration,                             Instruction *EarlierInst, Instruction *LaterInst); -  void removeMSSA(Instruction *Inst) { +  void removeMSSA(Instruction &Inst) {      if (!MSSA)        return;      if (VerifyMemorySSA) @@ -727,7 +786,7 @@ private:      // is handled by MemorySSA when passing OptimizePhis = true to      // removeMemoryAccess.  The non-optimized MemoryUse case is lazily updated      // by MemorySSA's getClobberingMemoryAccess. -    MSSAUpdater->removeMemoryAccess(Inst, true); +    MSSAUpdater->removeMemoryAccess(&Inst, true);    }  }; @@ -897,20 +956,19 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {    // See if any instructions in the block can be eliminated.  If so, do it.  If    // not, add them to AvailableValues. -  for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { -    Instruction *Inst = &*I++; - +  for (Instruction &Inst : make_early_inc_range(BB->getInstList())) {      // Dead instructions should just be removed. 
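    // ("Trivially dead" is in the sense of isInstructionTriviallyDead():
    // roughly, no remaining uses and no side effects, e.g. an unused
    // "%a = add i32 %x, 1".)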
-    if (isInstructionTriviallyDead(Inst, &TLI)) { -      LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n'); +    if (isInstructionTriviallyDead(&Inst, &TLI)) { +      LLVM_DEBUG(dbgs() << "EarlyCSE DCE: " << Inst << '\n');        if (!DebugCounter::shouldExecute(CSECounter)) {          LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");          continue;        } -      salvageDebugInfoOrMarkUndef(*Inst); +      salvageKnowledge(&Inst, &AC); +      salvageDebugInfo(Inst);        removeMSSA(Inst); -      Inst->eraseFromParent(); +      Inst.eraseFromParent();        Changed = true;        ++NumSimplify;        continue; @@ -920,21 +978,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {      // they're marked as such to ensure preservation of control dependencies),      // and this pass will not bother with its removal. However, we should mark      // its condition as true for all dominated blocks. -    if (match(Inst, m_Intrinsic<Intrinsic::assume>())) { +    if (match(&Inst, m_Intrinsic<Intrinsic::assume>())) {        auto *CondI = -          dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0)); +          dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0));        if (CondI && SimpleValue::canHandle(CondI)) { -        LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst +        LLVM_DEBUG(dbgs() << "EarlyCSE considering assumption: " << Inst                            << '\n');          AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));        } else -        LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n'); +        LLVM_DEBUG(dbgs() << "EarlyCSE skipping assumption: " << Inst << '\n');        continue;      }      // Skip sideeffect intrinsics, for the same reason as assume intrinsics. -    if (match(Inst, m_Intrinsic<Intrinsic::sideeffect>())) { -      LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n'); +    if (match(&Inst, m_Intrinsic<Intrinsic::sideeffect>())) { +      LLVM_DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << Inst << '\n');        continue;      } @@ -951,21 +1009,21 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {      // store 40, i8* p      // We can DSE the store to 30, since the store 40 to invariant location p      // causes undefined behaviour. -    if (match(Inst, m_Intrinsic<Intrinsic::invariant_start>())) { +    if (match(&Inst, m_Intrinsic<Intrinsic::invariant_start>())) {        // If there are any uses, the scope might end. -      if (!Inst->use_empty()) +      if (!Inst.use_empty())          continue; -      auto *CI = cast<CallInst>(Inst); -      MemoryLocation MemLoc = MemoryLocation::getForArgument(CI, 1, TLI); +      MemoryLocation MemLoc = +          MemoryLocation::getForArgument(&cast<CallInst>(Inst), 1, TLI);        // Don't start a scope if we already have a better one pushed        if (!AvailableInvariants.count(MemLoc))          AvailableInvariants.insert(MemLoc, CurrentGeneration);        continue;      } -    if (isGuard(Inst)) { +    if (isGuard(&Inst)) {        if (auto *CondI = -              dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) { +              dyn_cast<Instruction>(cast<CallInst>(Inst).getArgOperand(0))) {          if (SimpleValue::canHandle(CondI)) {            // Do we already know the actual value of this condition?            
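          // (Sketch, not from this patch: once a dominating
          // "call void @llvm.assume(i1 %c)" has recorded %c as true in
          // AvailableValues, a guard on the same %c that reaches this point
          // is deleted outright below.)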
if (auto *KnownCond = AvailableValues.lookup(CondI)) {
@@ -973,14 +1031,15 @@
             if (isa<ConstantInt>(KnownCond) &&
                 cast<ConstantInt>(KnownCond)->isOne()) {
               LLVM_DEBUG(dbgs()
-                         << "EarlyCSE removing guard: " << *Inst << '\n');
+                         << "EarlyCSE removing guard: " << Inst << '\n');
+              salvageKnowledge(&Inst, &AC);
               removeMSSA(Inst);
-              Inst->eraseFromParent();
+              Inst.eraseFromParent();
               Changed = true;
               continue;
             } else
               // Use the known value if it wasn't true.
-              cast<CallInst>(Inst)->setArgOperand(0, KnownCond);
+              cast<CallInst>(Inst).setArgOperand(0, KnownCond);
           }
           // The condition we're guarding on here is true for all dominated
           // locations.
@@ -997,20 +1056,21 @@
     // If the instruction can be simplified (e.g. X+0 = X) then replace it with
     // its simpler value.
-    if (Value *V = SimplifyInstruction(Inst, SQ)) {
-      LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << "  to: " << *V
+    if (Value *V = SimplifyInstruction(&Inst, SQ)) {
+      LLVM_DEBUG(dbgs() << "EarlyCSE Simplify: " << Inst << "  to: " << *V
                         << '\n');
       if (!DebugCounter::shouldExecute(CSECounter)) {
         LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
       } else {
         bool Killed = false;
-        if (!Inst->use_empty()) {
-          Inst->replaceAllUsesWith(V);
+        if (!Inst.use_empty()) {
+          Inst.replaceAllUsesWith(V);
           Changed = true;
         }
-        if (isInstructionTriviallyDead(Inst, &TLI)) {
+        if (isInstructionTriviallyDead(&Inst, &TLI)) {
+          salvageKnowledge(&Inst, &AC);
           removeMSSA(Inst);
-          Inst->eraseFromParent();
+          Inst.eraseFromParent();
           Changed = true;
           Killed = true;
         }
@@ -1022,31 +1082,32 @@
     }
     // If this is a simple instruction that we can value number, process it.
-    if (SimpleValue::canHandle(Inst)) {
+    if (SimpleValue::canHandle(&Inst)) {
       // See if the instruction has an available value.  If so, use it.
-      if (Value *V = AvailableValues.lookup(Inst)) {
-        LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << "  to: " << *V
+      if (Value *V = AvailableValues.lookup(&Inst)) {
+        LLVM_DEBUG(dbgs() << "EarlyCSE CSE: " << Inst << "  to: " << *V
                           << '\n');
         if (!DebugCounter::shouldExecute(CSECounter)) {
           LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
           continue;
         }
         if (auto *I = dyn_cast<Instruction>(V))
-          I->andIRFlags(Inst);
-        Inst->replaceAllUsesWith(V);
+          I->andIRFlags(&Inst);
+        Inst.replaceAllUsesWith(V);
+        salvageKnowledge(&Inst, &AC);
         removeMSSA(Inst);
-        Inst->eraseFromParent();
+        Inst.eraseFromParent();
         Changed = true;
         ++NumCSE;
         continue;
       }
       // Otherwise, just remember that this value is available.
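      // (AvailableValues is a ScopedHashTable whose scopes follow the
      // dominator-tree walk, see NodeScope above, so entries inserted here
      // are popped automatically once processing leaves the dominated region.)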
-      AvailableValues.insert(Inst, Inst);
+      AvailableValues.insert(&Inst, &Inst);
       continue;
     }
-    ParseMemoryInst MemInst(Inst, TTI);
+    ParseMemoryInst MemInst(&Inst, TTI);
     // If this is a non-volatile load, process it.
     if (MemInst.isValid() && MemInst.isLoad()) {
       // (conservatively) we can't peek past the ordering implied by this
@@ -1062,7 +1123,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {
         // We conservatively treat the invariant_load as that moment.  If we
         // pass an invariant load after already establishing a scope, don't
         // restart it since we want to preserve the earliest point seen.
-        auto MemLoc = MemoryLocation::get(Inst);
+        auto MemLoc = MemoryLocation::get(&Inst);
         if (!AvailableInvariants.count(MemLoc))
           AvailableInvariants.insert(MemLoc, CurrentGeneration);
       }
@@ -1081,21 +1142,22 @@
           !MemInst.isVolatile() && MemInst.isUnordered() &&
           // We can't replace an atomic load with one which isn't also atomic.
           InVal.IsAtomic >= MemInst.isAtomic() &&
-          (isOperatingOnInvariantMemAt(Inst, InVal.Generation) ||
+          (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) ||
            isSameMemGeneration(InVal.Generation, CurrentGeneration,
-                               InVal.DefInst, Inst))) {
-        Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType());
+                               InVal.DefInst, &Inst))) {
+        Value *Op = getOrCreateResult(InVal.DefInst, Inst.getType());
         if (Op != nullptr) {
-          LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst
+          LLVM_DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << Inst
                             << "  to: " << *InVal.DefInst << '\n');
           if (!DebugCounter::shouldExecute(CSECounter)) {
             LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");
             continue;
           }
-          if (!Inst->use_empty())
-            Inst->replaceAllUsesWith(Op);
+          if (!Inst.use_empty())
+            Inst.replaceAllUsesWith(Op);
+          salvageKnowledge(&Inst, &AC);
           removeMSSA(Inst);
-          Inst->eraseFromParent();
+          Inst.eraseFromParent();
           Changed = true;
           ++NumCSELoad;
           continue;
@@ -1103,10 +1165,10 @@
       }
       // Otherwise, remember that we have this instruction.
-      AvailableLoads.insert(
-          MemInst.getPointerOperand(),
-          LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(),
-                    MemInst.isAtomic()));
+      AvailableLoads.insert(MemInst.getPointerOperand(),
+                            LoadValue(&Inst, CurrentGeneration,
+                                      MemInst.getMatchingId(),
+                                      MemInst.isAtomic()));
       LastStore = nullptr;
       continue;
     }
@@ -1117,36 +1179,36 @@
     // may override this (e.g. so that a store intrinsic does not read from
     // memory, and thus will be treated the same as a regular store for
     // commoning purposes).
-    if ((Inst->mayReadFromMemory() || Inst->mayThrow()) &&
+    if ((Inst.mayReadFromMemory() || Inst.mayThrow()) &&
         !(MemInst.isValid() && !MemInst.mayReadFromMemory()))
       LastStore = nullptr;
     // If this is a read-only call, process it.
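    // (Illustrative only: two identical read-only calls such as
    //   %l1 = call i64 @strlen(i8* %p)
    //   %l2 = call i64 @strlen(i8* %p)
    // in the same memory generation allow the second to be replaced by the
    // first below.)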
-    if (CallValue::canHandle(Inst)) { +    if (CallValue::canHandle(&Inst)) {        // If we have an available version of this call, and if it is the right        // generation, replace this instruction. -      std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(Inst); +      std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(&Inst);        if (InVal.first != nullptr &&            isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first, -                              Inst)) { -        LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst +                              &Inst)) { +        LLVM_DEBUG(dbgs() << "EarlyCSE CSE CALL: " << Inst                            << "  to: " << *InVal.first << '\n');          if (!DebugCounter::shouldExecute(CSECounter)) {            LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");            continue;          } -        if (!Inst->use_empty()) -          Inst->replaceAllUsesWith(InVal.first); +        if (!Inst.use_empty()) +          Inst.replaceAllUsesWith(InVal.first); +        salvageKnowledge(&Inst, &AC);          removeMSSA(Inst); -        Inst->eraseFromParent(); +        Inst.eraseFromParent();          Changed = true;          ++NumCSECall;          continue;        }        // Otherwise, remember that we have this instruction. -      AvailableCalls.insert( -          Inst, std::pair<Instruction *, unsigned>(Inst, CurrentGeneration)); +      AvailableCalls.insert(&Inst, std::make_pair(&Inst, CurrentGeneration));        continue;      } @@ -1155,9 +1217,9 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {      // result, we don't need to consider it as writing to memory and don't need      // to advance the generation.  We do need to prevent DSE across the fence,      // but that's handled above. -    if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) +    if (auto *FI = dyn_cast<FenceInst>(&Inst))        if (FI->getOrdering() == AtomicOrdering::Release) { -        assert(Inst->mayReadFromMemory() && "relied on to prevent DSE above"); +        assert(Inst.mayReadFromMemory() && "relied on to prevent DSE above");          continue;        } @@ -1169,13 +1231,13 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {      if (MemInst.isValid() && MemInst.isStore()) {        LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());        if (InVal.DefInst && -          InVal.DefInst == getOrCreateResult(Inst, InVal.DefInst->getType()) && +          InVal.DefInst == getOrCreateResult(&Inst, InVal.DefInst->getType()) &&            InVal.MatchingId == MemInst.getMatchingId() &&            // We don't yet handle removing stores with ordering of any kind.            !MemInst.isVolatile() && MemInst.isUnordered() && -          (isOperatingOnInvariantMemAt(Inst, InVal.Generation) || +          (isOperatingOnInvariantMemAt(&Inst, InVal.Generation) ||             isSameMemGeneration(InVal.Generation, CurrentGeneration, -                               InVal.DefInst, Inst))) { +                               InVal.DefInst, &Inst))) {          // It is okay to have a LastStore to a different pointer here if MemorySSA          // tells us that the load and store are from the same memory generation.          
// In that case, LastStore should keep its present value since we're @@ -1185,13 +1247,14 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {                      MemInst.getPointerOperand() ||                  MSSA) &&                 "can't have an intervening store if not using MemorySSA!"); -        LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n'); +        LLVM_DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << Inst << '\n');          if (!DebugCounter::shouldExecute(CSECounter)) {            LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");            continue;          } +        salvageKnowledge(&Inst, &AC);          removeMSSA(Inst); -        Inst->eraseFromParent(); +        Inst.eraseFromParent();          Changed = true;          ++NumDSE;          // We can avoid incrementing the generation count since we were able @@ -1203,7 +1266,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {      // Okay, this isn't something we can CSE at all.  Check to see if it is      // something that could modify memory.  If so, our available memory values      // cannot be used so bump the generation count. -    if (Inst->mayWriteToMemory()) { +    if (Inst.mayWriteToMemory()) {        ++CurrentGeneration;        if (MemInst.isValid() && MemInst.isStore()) { @@ -1221,11 +1284,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {                   "Violated invariant");            if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {              LLVM_DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore -                              << "  due to: " << *Inst << '\n'); +                              << "  due to: " << Inst << '\n');              if (!DebugCounter::shouldExecute(CSECounter)) {                LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n");              } else { -              removeMSSA(LastStore); +              salvageKnowledge(&Inst, &AC); +              removeMSSA(*LastStore);                LastStore->eraseFromParent();                Changed = true;                ++NumDSE; @@ -1240,10 +1304,10 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {          // version of the pointer.  It is safe to forward from volatile stores          // to non-volatile loads, so we don't have to check for volatility of          // the store. -        AvailableLoads.insert( -            MemInst.getPointerOperand(), -            LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), -                      MemInst.isAtomic())); +        AvailableLoads.insert(MemInst.getPointerOperand(), +                              LoadValue(&Inst, CurrentGeneration, +                                        MemInst.getMatchingId(), +                                        MemInst.isAtomic()));          // Remember that this was the last unordered store we saw for DSE. We          // don't yet handle DSE on ordered or volatile stores since we don't @@ -1253,7 +1317,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) {          // it's not clear this is a profitable transform. Another option would          // be to merge the ordering with that of the post dominating store.          
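        // (The DSE handled here is the simple straight-line case; e.g. in
        //   store i32 1, i32* %p
        //   store i32 2, i32* %p
        // the first store is erased above when no intervening access can
        // observe it.)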
if (MemInst.isUnordered() && !MemInst.isVolatile()) -          LastStore = Inst; +          LastStore = &Inst;          else            LastStore = nullptr;        } diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp index af223cc837f2..83f4c402ed4d 100644 --- a/llvm/lib/Transforms/Scalar/Float2Int.cpp +++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -120,8 +120,7 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {  // Find the roots - instructions that convert from the FP domain to  // integer domain. -void Float2IntPass::findRoots(Function &F, const DominatorTree &DT, -                              SmallPtrSet<Instruction*,8> &Roots) { +void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {    for (BasicBlock &BB : F) {      // Unreachable code can take on strange forms that we are not prepared to      // handle. For example, an instruction may have itself as an operand. @@ -184,7 +183,7 @@ ConstantRange Float2IntPass::validateRange(ConstantRange R) {  // Breadth-first walk of the use-def graph; determine the set of nodes  // we care about and eagerly determine if some of them are poisonous. -void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) { +void Float2IntPass::walkBackwards() {    std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());    while (!Worklist.empty()) {      Instruction *I = Worklist.back(); @@ -327,7 +326,7 @@ void Float2IntPass::walkForwards() {          APFloat NewF = F;          auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven); -        if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) { +        if (Res != APFloat::opOK || NewF != F) {            seen(I, badRange());            Abort = true;            break; @@ -525,9 +524,9 @@ bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {    Ctx = &F.getParent()->getContext(); -  findRoots(F, DT, Roots); +  findRoots(F, DT); -  walkBackwards(Roots); +  walkBackwards();    walkForwards();    bool Modified = validateAndTransform(); diff --git a/llvm/lib/Transforms/Scalar/GVN.cpp b/llvm/lib/Transforms/Scalar/GVN.cpp index 1e6aab14e7b4..b16f8591b5a4 100644 --- a/llvm/lib/Transforms/Scalar/GVN.cpp +++ b/llvm/lib/Transforms/Scalar/GVN.cpp @@ -26,6 +26,7 @@  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumeBundleQueries.h"  #include "llvm/Analysis/AliasAnalysis.h"  #include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/CFG.h" @@ -42,7 +43,6 @@  #include "llvm/Config/llvm-config.h"  #include "llvm/IR/Attributes.h"  #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h"  #include "llvm/IR/Constant.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/DataLayout.h" @@ -72,6 +72,7 @@  #include "llvm/Support/Debug.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -97,10 +98,11 @@ STATISTIC(NumGVNSimpl,  "Number of instructions simplified");  STATISTIC(NumGVNEqProp, "Number of equalities propagated");  STATISTIC(NumPRELoad,   "Number of loads PRE'd"); -static cl::opt<bool> EnablePRE("enable-pre", -                               cl::init(true), cl::Hidden); -static cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true)); -static cl::opt<bool> 
EnableMemDep("enable-gvn-memdep", cl::init(true)); +static cl::opt<bool> GVNEnablePRE("enable-pre", cl::init(true), cl::Hidden); +static cl::opt<bool> GVNEnableLoadPRE("enable-load-pre", cl::init(true)); +static cl::opt<bool> GVNEnableLoadInLoopPRE("enable-load-in-loop-pre", +                                            cl::init(true)); +static cl::opt<bool> GVNEnableMemDep("enable-gvn-memdep", cl::init(true));  // Maximum allowed recursion depth.  static cl::opt<uint32_t> @@ -113,8 +115,8 @@ static cl::opt<uint32_t> MaxNumDeps(  struct llvm::GVN::Expression {    uint32_t opcode; -  Type *type = nullptr;    bool commutative = false; +  Type *type = nullptr;    SmallVector<uint32_t, 4> varargs;    Expression(uint32_t o = ~2U) : opcode(o) {} @@ -288,7 +290,7 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) {      e.commutative = true;    } -  if (CmpInst *C = dyn_cast<CmpInst>(I)) { +  if (auto *C = dyn_cast<CmpInst>(I)) {      // Sort the operand value numbers so x<y and y>x get the same value number.      CmpInst::Predicate Predicate = C->getPredicate();      if (e.varargs[0] > e.varargs[1]) { @@ -297,10 +299,11 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) {      }      e.opcode = (C->getOpcode() << 8) | Predicate;      e.commutative = true; -  } else if (InsertValueInst *E = dyn_cast<InsertValueInst>(I)) { -    for (InsertValueInst::idx_iterator II = E->idx_begin(), IE = E->idx_end(); -         II != IE; ++II) -      e.varargs.push_back(*II); +  } else if (auto *E = dyn_cast<InsertValueInst>(I)) { +    e.varargs.append(E->idx_begin(), E->idx_end()); +  } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) { +    ArrayRef<int> ShuffleMask = SVI->getShuffleMask(); +    e.varargs.append(ShuffleMask.begin(), ShuffleMask.end());    }    return e; @@ -530,6 +533,7 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) {      case Instruction::AddrSpaceCast:      case Instruction::BitCast:      case Instruction::Select: +    case Instruction::Freeze:      case Instruction::ExtractElement:      case Instruction::InsertElement:      case Instruction::ShuffleVector: @@ -610,6 +614,22 @@ void GVN::ValueTable::verifyRemoved(const Value *V) const {  //                                GVN Pass  //===----------------------------------------------------------------------===// +bool GVN::isPREEnabled() const { +  return Options.AllowPRE.getValueOr(GVNEnablePRE); +} + +bool GVN::isLoadPREEnabled() const { +  return Options.AllowLoadPRE.getValueOr(GVNEnableLoadPRE); +} + +bool GVN::isLoadInLoopPREEnabled() const { +  return Options.AllowLoadInLoopPRE.getValueOr(GVNEnableLoadInLoopPRE); +} + +bool GVN::isMemDepEnabled() const { +  return Options.AllowMemDep.getValueOr(GVNEnableMemDep); +} +  PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {    // FIXME: The order of evaluation of these 'getResult' calls is very    // significant! Re-ordering these variables will cause GVN when run alone to @@ -619,10 +639,11 @@ PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) {    auto &DT = AM.getResult<DominatorTreeAnalysis>(F);    auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);    auto &AA = AM.getResult<AAManager>(F); -  auto &MemDep = AM.getResult<MemoryDependenceAnalysis>(F); +  auto *MemDep = +      isMemDepEnabled() ? 
&AM.getResult<MemoryDependenceAnalysis>(F) : nullptr;
   auto *LI = AM.getCachedResult<LoopAnalysis>(F);
   auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
-  bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep, LI, &ORE);
+  bool Changed = runImpl(F, AC, DT, TLI, AA, MemDep, LI, &ORE);
   if (!Changed)
     return PreservedAnalyses::all();
   PreservedAnalyses PA;
@@ -927,6 +948,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo,
   // Loading the allocation -> undef.
   if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) ||
+      isAlignedAllocLikeFn(DepInst, TLI) ||
       // Loading immediately after lifetime begin -> undef.
       isLifetimeStart(DepInst)) {
     Res = AvailableValue::get(UndefValue::get(LI->getType()));
@@ -1245,7 +1267,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
     auto *NewLoad = new LoadInst(
         LI->getType(), LoadPtr, LI->getName() + ".pre", LI->isVolatile(),
-        MaybeAlign(LI->getAlignment()), LI->getOrdering(), LI->getSyncScopeID(),
+        LI->getAlign(), LI->getOrdering(), LI->getSyncScopeID(),
         UnavailablePred->getTerminator());
     NewLoad->setDebugLoc(LI->getDebugLoc());
@@ -1383,7 +1405,10 @@ bool GVN::processNonLocalLoad(LoadInst *LI) {
   }
   // Step 4: Eliminate partial redundancy.
-  if (!EnablePRE || !EnableLoadPRE)
+  if (!isPREEnabled() || !isLoadPREEnabled())
+    return false;
+  if (!isLoadInLoopPREEnabled() && this->LI &&
+      this->LI->getLoopFor(LI->getParent()))
     return false;
   return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks);
@@ -1428,7 +1453,7 @@ static bool impliesEquivalanceIfFalse(CmpInst* Cmp) {
       Value *LHS = Cmp->getOperand(0);
       Value *RHS = Cmp->getOperand(1);
       // If we can prove either side non-zero, then equality must imply
-      // equivalence.  
+      // equivalence.
       // FIXME: We should do this optimization if 'no signed zeros' is
       // applicable via an instruction-level fast-math-flag or some other
       // indicator that relaxed FP semantics are being used.
@@ -1465,7 +1490,8 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
                     Constant::getNullValue(Int8Ty->getPointerTo()),
                     IntrinsicI);
     }
-    markInstructionForDeletion(IntrinsicI);
+    if (isAssumeWithEmptyBundle(*IntrinsicI))
+      markInstructionForDeletion(IntrinsicI);
     return false;
   } else if (isa<Constant>(V)) {
     // If it's not false, and constant, it must evaluate to true. This means our
@@ -1493,10 +1519,10 @@
   // If we find an equality fact, canonicalize all dominated uses in this block
   // to one of the two values.  We heuristically choose the "oldest" of the
   // two where age is determined by value number. (Note that propagateEquality
-  // above handles the cross block case.)  
-  //  
+  // above handles the cross block case.)
+  //
   // Key cases to cover are:
-  // 1)  
+  // 1)
   // %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen
   // call void @llvm.assume(i1 %cmp)
   // ret float %0 ; will change it to ret float 3.000000e+00
@@ -1537,7 +1563,7 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
                  << *CmpLHS << " with "
                  << *CmpRHS << " in block "
                  << IntrinsicI->getParent()->getName() << "\n");
-      
+
       // Set up the replacement map - this handles uses within the same block
       if (hasUsersIn(CmpLHS, IntrinsicI->getParent()))
@@ -1710,7 +1736,8 @@ uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred,
     // instead of value numbers. Those index numbers should not be
     // translated.
     if ((i > 1 && Exp.opcode == Instruction::InsertValue) ||
-        (i > 0 && Exp.opcode == Instruction::ExtractValue))
+        (i > 0 && Exp.opcode == Instruction::ExtractValue) ||
+        (i > 1 && Exp.opcode == Instruction::ShuffleVector))
       continue;
     Exp.varargs[i] = phiTranslate(Pred, PhiBlock, Exp.varargs[i], Gvn);
   }
@@ -1802,7 +1829,7 @@ void GVN::assignBlockRPONumber(Function &F) {
 bool GVN::replaceOperandsForInBlockEquality(Instruction *Instr) const {
   bool Changed = false;
   for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) {
-    Value *Operand = Instr->getOperand(OpNum);  
+    Value *Operand = Instr->getOperand(OpNum);
     auto it = ReplaceOperandsWithMap.find(Operand);
     if (it != ReplaceOperandsWithMap.end()) {
       LLVM_DEBUG(dbgs() << "GVN replacing: " << *Operand << " with "
@@ -1922,7 +1949,7 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root,
       // If "A == B" is known true, or "A != B" is known false, then replace
       // A with B everywhere in the scope.  For floating point operations, we
-      // have to be careful since equality does not always imply equivalance.  
+      // have to be careful since equality does not always imply equivalence.
       if ((isKnownTrue && impliesEquivalanceIfTrue(Cmp)) ||
           (isKnownFalse && impliesEquivalanceIfFalse(Cmp)))
         Worklist.push_back(std::make_pair(Op0, Op1));
@@ -2117,7 +2144,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT,
   TLI = &RunTLI;
   VN.setAliasAnalysis(&RunAA);
   MD = RunMD;
-  ImplicitControlFlowTracking ImplicitCFT(DT);
+  ImplicitControlFlowTracking ImplicitCFT;
   ICF = &ImplicitCFT;
   this->LI = LI;
   VN.setMemDep(MD);
@@ -2148,7 +2175,7 @@
     ++Iteration;
   }
-  if (EnablePRE) {
+  if (isPREEnabled()) {
     // Fabricate val-num for dead-code in order to suppress assertion in
     // performPRE().
     assignValNumForDeadCode();
@@ -2206,6 +2233,7 @@ bool GVN::processBlock(BasicBlock *BB) {
     for (auto *I : InstrsToErase) {
       assert(I->getParent() == BB && "Removing instruction from wrong block?");
       LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n');
+      salvageKnowledge(I, AC);
       salvageDebugInfo(*I);
       if (MD) MD->removeInstruction(I);
       LLVM_DEBUG(verifyRemoved(I));
@@ -2478,8 +2506,11 @@ bool GVN::performPRE(Function &F) {
 /// Split the critical edge connecting the given two blocks, and return
 /// the block inserted into the critical edge.
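 /// (A critical edge runs from a block with multiple successors to a block
 /// with multiple predecessors; splitting it creates a block in which PRE can
 /// insert instructions that execute only along that edge.)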
BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) {
-  BasicBlock *BB =
-      SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT, LI));
+  // GVN does not require loop-simplify, so do not try to preserve it if that
+  // is not possible.
+  BasicBlock *BB = SplitCriticalEdge(
+      Pred, Succ,
+      CriticalEdgeSplittingOptions(DT, LI).unsetPreserveLoopSimplify());
   if (MD)
     MD->invalidateCachedPredecessors();
   InvalidBlockRPONumbers = true;
@@ -2682,8 +2713,8 @@ class llvm::gvn::GVNLegacyPass : public FunctionPass {
 public:
   static char ID; // Pass identification, replacement for typeid
-  explicit GVNLegacyPass(bool NoMemDepAnalysis = !EnableMemDep)
-      : FunctionPass(ID), NoMemDepAnalysis(NoMemDepAnalysis) {
+  explicit GVNLegacyPass(bool NoMemDepAnalysis = !GVNEnableMemDep)
+      : FunctionPass(ID), Impl(GVNOptions().setMemDep(!NoMemDepAnalysis)) {
     initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry());
   }
@@ -2698,9 +2729,9 @@ public:
         getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
         getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
         getAnalysis<AAResultsWrapperPass>().getAAResults(),
-        NoMemDepAnalysis
-            ? nullptr
-            : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(),
+        Impl.isMemDepEnabled()
+            ? &getAnalysis<MemoryDependenceWrapperPass>().getMemDep()
+            : nullptr,
         LIWP ? &LIWP->getLoopInfo() : nullptr,
         &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE());
   }
@@ -2710,7 +2741,7 @@ public:
     AU.addRequired<DominatorTreeWrapperPass>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
     AU.addRequired<LoopInfoWrapperPass>();
-    if (!NoMemDepAnalysis)
+    if (Impl.isMemDepEnabled())
       AU.addRequired<MemoryDependenceWrapperPass>();
     AU.addRequired<AAResultsWrapperPass>();
@@ -2718,12 +2749,10 @@ public:
     AU.addPreserved<GlobalsAAWrapperPass>();
     AU.addPreserved<TargetLibraryInfoWrapperPass>();
     AU.addPreserved<LoopInfoWrapperPass>();
-    AU.addPreservedID(LoopSimplifyID);
     AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
   }
 private:
-  bool NoMemDepAnalysis;
   GVN Impl;
 };
diff --git a/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
index e1796f6bf05a..9c4cdf2feb56 100644
--- a/llvm/lib/Transforms/Scalar/GVNHoist.cpp
+++ b/llvm/lib/Transforms/Scalar/GVNHoist.cpp
@@ -890,18 +890,16 @@ private:
   void updateAlignment(Instruction *I, Instruction *Repl) {
     if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) {
-      ReplacementLoad->setAlignment(MaybeAlign(std::min(
-          ReplacementLoad->getAlignment(), cast<LoadInst>(I)->getAlignment())));
+      ReplacementLoad->setAlignment(
+          std::min(ReplacementLoad->getAlign(), cast<LoadInst>(I)->getAlign()));
       ++NumLoadsRemoved;
     } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) {
-      ReplacementStore->setAlignment(
-          MaybeAlign(std::min(ReplacementStore->getAlignment(),
-                              cast<StoreInst>(I)->getAlignment())));
+      ReplacementStore->setAlignment(std::min(ReplacementStore->getAlign(),
+                                              cast<StoreInst>(I)->getAlign()));
       ++NumStoresRemoved;
     } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) {
-      ReplacementAlloca->setAlignment(
-          MaybeAlign(std::max(ReplacementAlloca->getAlignment(),
-                              
cast<AllocaInst>(I)->getAlignment()))); +      ReplacementAlloca->setAlignment(std::max( +          ReplacementAlloca->getAlign(), cast<AllocaInst>(I)->getAlign()));      } else if (isa<CallInst>(Repl)) {        ++NumCallsRemoved;      } diff --git a/llvm/lib/Transforms/Scalar/GVNSink.cpp b/llvm/lib/Transforms/Scalar/GVNSink.cpp index 6d0a4975e266..dfb4b7e038ba 100644 --- a/llvm/lib/Transforms/Scalar/GVNSink.cpp +++ b/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -350,6 +350,7 @@ using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>;  class InstructionUseExpr : public GVNExpression::BasicExpression {    unsigned MemoryUseOrder = -1;    bool Volatile = false; +  ArrayRef<int> ShuffleMask;  public:    InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R, @@ -359,6 +360,9 @@ public:      setOpcode(I->getOpcode());      setType(I->getType()); +    if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I)) +      ShuffleMask = SVI->getShuffleMask().copy(A); +      for (auto &U : I->uses())        op_push_back(U.getUser());      llvm::sort(op_begin(), op_end()); @@ -369,12 +373,12 @@ public:    hash_code getHashValue() const override {      return hash_combine(GVNExpression::BasicExpression::getHashValue(), -                        MemoryUseOrder, Volatile); +                        MemoryUseOrder, Volatile, ShuffleMask);    }    template <typename Function> hash_code getHashValue(Function MapFn) { -    hash_code H = -        hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile); +    hash_code H = hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile, +                               ShuffleMask);      for (auto *V : operands())        H = hash_combine(H, MapFn(V));      return H; @@ -475,6 +479,7 @@ public:      case Instruction::PtrToInt:      case Instruction::IntToPtr:      case Instruction::BitCast: +    case Instruction::AddrSpaceCast:      case Instruction::Select:      case Instruction::ExtractElement:      case Instruction::InsertElement: @@ -576,7 +581,7 @@ public:  private:    ValueTable VN; -  bool isInstructionBlacklisted(Instruction *I) { +  bool shouldAvoidSinkingInstruction(Instruction *I) {      // These instructions may change or break semantics if moved.      
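    // (PHIs encode control-flow merges, EH pads and static allocas are
    // position-sensitive, and token values must not be obscured, so sinking
    // any of these into a common successor could change semantics.)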
if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) ||          I->getType()->isTokenTy()) @@ -668,7 +673,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking(        NewInsts.push_back(I);    }    for (auto *I : NewInsts) -    if (isInstructionBlacklisted(I)) +    if (shouldAvoidSinkingInstruction(I))        return None;    // If we've restricted the incoming blocks, restrict all needed PHIs also diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index d8d7acae5c9f..0f36c3f772e6 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -38,8 +38,9 @@  #include "llvm/ADT/iterator_range.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/Analysis/TargetTransformInfo.h" @@ -81,6 +82,7 @@  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include "llvm/Transforms/Utils/SimplifyIndVar.h"  #include <cassert>  #include <cstdint> @@ -100,10 +102,10 @@ STATISTIC(NumElimIV      , "Number of congruent IVs eliminated");  // implement a strong expression equivalence checker in SCEV. Until then, we  // use the verify-indvars flag, which may assert in some cases.  static cl::opt<bool> VerifyIndvars( -  "verify-indvars", cl::Hidden, -  cl::desc("Verify the ScalarEvolution result after running indvars")); - -enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, NoHardUse, AlwaysRepl }; +    "verify-indvars", cl::Hidden, +    cl::desc("Verify the ScalarEvolution result after running indvars. Has no " +             "effect in release builds. (Note: this adds additional SCEV " +             "queries potentially changing the analysis result)"));  static cl::opt<ReplaceExitVal> ReplaceExitValue(      "replexitval", cl::Hidden, cl::init(OnlyCheapRepl), @@ -140,11 +142,10 @@ class IndVarSimplify {    const DataLayout &DL;    TargetLibraryInfo *TLI;    const TargetTransformInfo *TTI; +  std::unique_ptr<MemorySSAUpdater> MSSAU;    SmallVector<WeakTrackingVH, 16> DeadInsts; -  bool isValidRewrite(Value *FromVal, Value *ToVal); -    bool handleFloatingPointIV(Loop *L, PHINode *PH);    bool rewriteNonIntegerIVs(Loop *L); @@ -155,10 +156,7 @@ class IndVarSimplify {    /// iterations of the loop run when that is unobservable.    
bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); -  bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); -  bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter);    bool rewriteFirstIterationLoopExitValues(Loop *L); -  bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) const;    bool linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,                                   const SCEV *ExitCount, @@ -169,66 +167,17 @@ class IndVarSimplify {  public:    IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT,                   const DataLayout &DL, TargetLibraryInfo *TLI, -                 TargetTransformInfo *TTI) -      : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {} +                 TargetTransformInfo *TTI, MemorySSA *MSSA) +      : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) { +    if (MSSA) +      MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); +  }    bool run(Loop *L);  };  } // end anonymous namespace -/// Return true if the SCEV expansion generated by the rewriter can replace the -/// original value. SCEV guarantees that it produces the same value, but the way -/// it is produced may be illegal IR.  Ideally, this function will only be -/// called for verification. -bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) { -  // If an SCEV expression subsumed multiple pointers, its expansion could -  // reassociate the GEP changing the base pointer. This is illegal because the -  // final address produced by a GEP chain must be inbounds relative to its -  // underlying object. Otherwise basic alias analysis, among other things, -  // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid -  // producing an expression involving multiple pointers. Until then, we must -  // bail out here. -  // -  // Retrieve the pointer operand of the GEP. Don't use GetUnderlyingObject -  // because it understands lcssa phis while SCEV does not. -  Value *FromPtr = FromVal; -  Value *ToPtr = ToVal; -  if (auto *GEP = dyn_cast<GEPOperator>(FromVal)) { -    FromPtr = GEP->getPointerOperand(); -  } -  if (auto *GEP = dyn_cast<GEPOperator>(ToVal)) { -    ToPtr = GEP->getPointerOperand(); -  } -  if (FromPtr != FromVal || ToPtr != ToVal) { -    // Quickly check the common case -    if (FromPtr == ToPtr) -      return true; - -    // SCEV may have rewritten an expression that produces the GEP's pointer -    // operand. That's ok as long as the pointer operand has the same base -    // pointer. Unlike GetUnderlyingObject(), getPointerBase() will find the -    // base of a recurrence. This handles the case in which SCEV expansion -    // converts a pointer type recurrence into a nonrecurrent pointer base -    // indexed by an integer recurrence. - -    // If the GEP base pointer is a vector of pointers, abort. -    if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy()) -      return false; - -    const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr)); -    const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr)); -    if (FromBase == ToBase) -      return true; - -    LLVM_DEBUG(dbgs() << "INDVARS: GEP rewrite bail out " << *FromBase -                      << " != " << *ToBase << "\n"); - -    return false; -  } -  return true; -} -  /// Determine the insertion point for this user. By default, insert immediately  /// before the user. SCEVExpander or LICM will hoist loop invariants out of the  /// loop. 
For PHI nodes, there may be multiple uses, so compute the nearest @@ -477,11 +426,11 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {    // new comparison.    NewCompare->takeName(Compare);    Compare->replaceAllUsesWith(NewCompare); -  RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI); +  RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI, MSSAU.get());    // Delete the old floating point increment.    Incr->replaceAllUsesWith(UndefValue::get(Incr->getType())); -  RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI); +  RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI, MSSAU.get());    // If the FP induction variable still has uses, this is because something else    // in the loop uses its value.  In order to canonicalize the induction @@ -494,7 +443,7 @@ bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {      Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv",                                   &*PN->getParent()->getFirstInsertionPt());      PN->replaceAllUsesWith(Conv); -    RecursivelyDeleteTriviallyDeadInstructions(PN, TLI); +    RecursivelyDeleteTriviallyDeadInstructions(PN, TLI, MSSAU.get());    }    return true;  } @@ -522,222 +471,6 @@ bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) {    return Changed;  } -namespace { - -// Collect information about PHI nodes which can be transformed in -// rewriteLoopExitValues. -struct RewritePhi { -  PHINode *PN; - -  // Ith incoming value. -  unsigned Ith; - -  // Exit value after expansion. -  Value *Val; - -  // High Cost when expansion. -  bool HighCost; - -  RewritePhi(PHINode *P, unsigned I, Value *V, bool H) -      : PN(P), Ith(I), Val(V), HighCost(H) {} -}; - -} // end anonymous namespace - -//===----------------------------------------------------------------------===// -// rewriteLoopExitValues - Optimize IV users outside the loop. -// As a side effect, reduces the amount of IV processing within the loop. -//===----------------------------------------------------------------------===// - -bool IndVarSimplify::hasHardUserWithinLoop(const Loop *L, const Instruction *I) const { -  SmallPtrSet<const Instruction *, 8> Visited; -  SmallVector<const Instruction *, 8> WorkList; -  Visited.insert(I); -  WorkList.push_back(I); -  while (!WorkList.empty()) { -    const Instruction *Curr = WorkList.pop_back_val(); -    // This use is outside the loop, nothing to do. -    if (!L->contains(Curr)) -      continue; -    // Do we assume it is a "hard" use which will not be eliminated easily? -    if (Curr->mayHaveSideEffects()) -      return true; -    // Otherwise, add all its users to worklist. -    for (auto U : Curr->users()) { -      auto *UI = cast<Instruction>(U); -      if (Visited.insert(UI).second) -        WorkList.push_back(UI); -    } -  } -  return false; -} - -/// Check to see if this loop has a computable loop-invariant execution count. -/// If so, this means that we can compute the final value of any expressions -/// that are recurrent in the loop, and substitute the exit values from the loop -/// into any instructions outside of the loop that use the final values of the -/// current expressions. -/// -/// This is mostly redundant with the regular IndVarSimplify activities that -/// happen later, except that it's more powerful in some cases, because it's -/// able to brute-force evaluate arbitrary instructions as long as they have -/// constant operands at the beginning of the loop. 
-bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { -  // Check a pre-condition. -  assert(L->isRecursivelyLCSSAForm(*DT, *LI) && -         "Indvars did not preserve LCSSA!"); - -  SmallVector<BasicBlock*, 8> ExitBlocks; -  L->getUniqueExitBlocks(ExitBlocks); - -  SmallVector<RewritePhi, 8> RewritePhiSet; -  // Find all values that are computed inside the loop, but used outside of it. -  // Because of LCSSA, these values will only occur in LCSSA PHI Nodes.  Scan -  // the exit blocks of the loop to find them. -  for (BasicBlock *ExitBB : ExitBlocks) { -    // If there are no PHI nodes in this exit block, then no values defined -    // inside the loop are used on this path, skip it. -    PHINode *PN = dyn_cast<PHINode>(ExitBB->begin()); -    if (!PN) continue; - -    unsigned NumPreds = PN->getNumIncomingValues(); - -    // Iterate over all of the PHI nodes. -    BasicBlock::iterator BBI = ExitBB->begin(); -    while ((PN = dyn_cast<PHINode>(BBI++))) { -      if (PN->use_empty()) -        continue; // dead use, don't replace it - -      if (!SE->isSCEVable(PN->getType())) -        continue; - -      // It's necessary to tell ScalarEvolution about this explicitly so that -      // it can walk the def-use list and forget all SCEVs, as it may not be -      // watching the PHI itself. Once the new exit value is in place, there -      // may not be a def-use connection between the loop and every instruction -      // which got a SCEVAddRecExpr for that loop. -      SE->forgetValue(PN); - -      // Iterate over all of the values in all the PHI nodes. -      for (unsigned i = 0; i != NumPreds; ++i) { -        // If the value being merged in is not integer or is not defined -        // in the loop, skip it. -        Value *InVal = PN->getIncomingValue(i); -        if (!isa<Instruction>(InVal)) -          continue; - -        // If this pred is for a subloop, not L itself, skip it. -        if (LI->getLoopFor(PN->getIncomingBlock(i)) != L) -          continue; // The Block is in a subloop, skip it. - -        // Check that InVal is defined in the loop. -        Instruction *Inst = cast<Instruction>(InVal); -        if (!L->contains(Inst)) -          continue; - -        // Okay, this instruction has a user outside of the current loop -        // and varies predictably *inside* the loop.  Evaluate the value it -        // contains when the loop exits, if possible.  We prefer to start with -        // expressions which are true for all exits (so as to maximize -        // expression reuse by the SCEVExpander), but resort to per-exit -        // evaluation if that fails.   -        const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop()); -        if (isa<SCEVCouldNotCompute>(ExitValue) || -            !SE->isLoopInvariant(ExitValue, L) || -            !isSafeToExpand(ExitValue, *SE)) { -          // TODO: This should probably be sunk into SCEV in some way; maybe a -          // getSCEVForExit(SCEV*, L, ExitingBB)?  It can be generalized for -          // most SCEV expressions and other recurrence types (e.g. shift -          // recurrences).  Is there existing code we can reuse? 
-          const SCEV *ExitCount = SE->getExitCount(L, PN->getIncomingBlock(i)); -          if (isa<SCEVCouldNotCompute>(ExitCount)) -            continue; -          if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Inst))) -            if (AddRec->getLoop() == L) -              ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE); -          if (isa<SCEVCouldNotCompute>(ExitValue) || -              !SE->isLoopInvariant(ExitValue, L) || -              !isSafeToExpand(ExitValue, *SE)) -            continue; -        } -         -        // Computing the value outside of the loop brings no benefit if it is -        // definitely used inside the loop in a way which can not be optimized -        // away.  Avoid doing so unless we know we have a value which computes -        // the ExitValue already.  TODO: This should be merged into SCEV -        // expander to leverage its knowledge of existing expressions. -        if (ReplaceExitValue != AlwaysRepl && -            !isa<SCEVConstant>(ExitValue) && !isa<SCEVUnknown>(ExitValue) && -            hasHardUserWithinLoop(L, Inst)) -          continue; - -        bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst); -        Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst); - -        LLVM_DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal -                          << '\n' -                          << "  LoopVal = " << *Inst << "\n"); - -        if (!isValidRewrite(Inst, ExitVal)) { -          DeadInsts.push_back(ExitVal); -          continue; -        } - -#ifndef NDEBUG -        // If we reuse an instruction from a loop which is neither L nor one of -        // its containing loops, we end up breaking LCSSA form for this loop by -        // creating a new use of its instruction. -        if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal)) -          if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) -            if (EVL != L) -              assert(EVL->contains(L) && "LCSSA breach detected!"); -#endif - -        // Collect all the candidate PHINodes to be rewritten. -        RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost); -      } -    } -  } - -  bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); - -  bool Changed = false; -  // Transformation. -  for (const RewritePhi &Phi : RewritePhiSet) { -    PHINode *PN = Phi.PN; -    Value *ExitVal = Phi.Val; - -    // Only do the rewrite when the ExitValue can be expanded cheaply. -    // If LoopCanBeDel is true, rewrite exit value aggressively. -    if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) { -      DeadInsts.push_back(ExitVal); -      continue; -    } - -    Changed = true; -    ++NumReplaced; -    Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith)); -    PN->setIncomingValue(Phi.Ith, ExitVal); - -    // If this instruction is dead now, delete it. Don't do it now to avoid -    // invalidating iterators. -    if (isInstructionTriviallyDead(Inst, TLI)) -      DeadInsts.push_back(Inst); - -    // Replace PN with ExitVal if that is legal and does not break LCSSA. -    if (PN->getNumIncomingValues() == 1 && -        LI->replacementPreservesLCSSAForm(PN, ExitVal)) { -      PN->replaceAllUsesWith(ExitVal); -      PN->eraseFromParent(); -    } -  } - -  // The insertion point instruction may have been deleted; clear it out -  // so that the rewriter doesn't trip over it later. 
-  Rewriter.clearInsertPoint(); -  return Changed; -} -  //===---------------------------------------------------------------------===//  // rewriteFirstIterationLoopExitValues: Rewrite loop exit values if we know  // they will exit at the first iteration. @@ -813,61 +546,6 @@ bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) {    return MadeAnyChanges;  } -/// Check whether it is possible to delete the loop after rewriting exit -/// value. If it is possible, ignore ReplaceExitValue and do rewriting -/// aggressively. -bool IndVarSimplify::canLoopBeDeleted( -    Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) { -  BasicBlock *Preheader = L->getLoopPreheader(); -  // If there is no preheader, the loop will not be deleted. -  if (!Preheader) -    return false; - -  // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1. -  // We obviate multiple ExitingBlocks case for simplicity. -  // TODO: If we see testcase with multiple ExitingBlocks can be deleted -  // after exit value rewriting, we can enhance the logic here. -  SmallVector<BasicBlock *, 4> ExitingBlocks; -  L->getExitingBlocks(ExitingBlocks); -  SmallVector<BasicBlock *, 8> ExitBlocks; -  L->getUniqueExitBlocks(ExitBlocks); -  if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1) -    return false; - -  BasicBlock *ExitBlock = ExitBlocks[0]; -  BasicBlock::iterator BI = ExitBlock->begin(); -  while (PHINode *P = dyn_cast<PHINode>(BI)) { -    Value *Incoming = P->getIncomingValueForBlock(ExitingBlocks[0]); - -    // If the Incoming value of P is found in RewritePhiSet, we know it -    // could be rewritten to use a loop invariant value in transformation -    // phase later. Skip it in the loop invariant check below. -    bool found = false; -    for (const RewritePhi &Phi : RewritePhiSet) { -      unsigned i = Phi.Ith; -      if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) { -        found = true; -        break; -      } -    } - -    Instruction *I; -    if (!found && (I = dyn_cast<Instruction>(Incoming))) -      if (!L->hasLoopInvariantOperands(I)) -        return false; - -    ++BI; -  } - -  for (auto *BB : L->blocks()) -    if (llvm::any_of(*BB, [](Instruction &I) { -          return I.mayHaveSideEffects(); -        })) -      return false; - -  return true; -} -  //===----------------------------------------------------------------------===//  //  IV Widening - Extend the width of an IV to cover its widest uses.  //===----------------------------------------------------------------------===// @@ -1060,8 +738,8 @@ protected:    Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter);    bool widenLoopCompare(NarrowIVDefUse DU); -  bool widenWithVariantLoadUse(NarrowIVDefUse DU); -  void widenWithVariantLoadUseCodegen(NarrowIVDefUse DU); +  bool widenWithVariantUse(NarrowIVDefUse DU); +  void widenWithVariantUseCodegen(NarrowIVDefUse DU);    void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef);  }; @@ -1399,20 +1077,27 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) {    return true;  } -/// If the narrow use is an instruction whose two operands are the defining -/// instruction of DU and a load instruction, then we have the following: -/// if the load is hoisted outside the loop, then we do not reach this function -/// as scalar evolution analysis works fine in widenIVUse with variables -/// hoisted outside the loop and efficient code is subsequently generated by -/// not emitting truncate instructions. 
But when the load is not hoisted
-/// (whether due to limitation in alias analysis or due to a true legality),
-/// then scalar evolution can not proceed with loop variant values and
-/// inefficient code is generated. This function handles the non-hoisted load
-/// special case by making the optimization generate the same type of code for
-/// hoisted and non-hoisted load (widen use and eliminate sign extend
-/// instruction). This special case is important especially when the induction
-/// variables are affecting addressing mode in code generation.
-bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) {
+// widenIVUse avoids generating trunc by evaluating the use as an AddRec. This
+// will not work when:
+//    1) SCEV traces back to an instruction inside the loop that SCEV cannot
+// expand, e.g. add %indvar, (load %addr)
+//    2) SCEV finds a loop variant, e.g. add %indvar, %loopvariant
+// While SCEV fails to avoid trunc, we can still try an instruction combining
+// approach to prove that trunc is not required. This can be further extended
+// with other instruction combining checks, but for now we handle the following
+// case ("sub" can also be "add" or "mul"; "nsw + sext" can also be
+// "nuw + zext"):
+//
+// Src:
+//   %c = sub nsw %b, %indvar
+//   %d = sext %c to i64
+// Dst:
+//   %indvar.ext1 = sext %indvar to i64
+//   %m = sext %b to i64
+//   %d = sub nsw i64 %m, %indvar.ext1
+// Therefore, as long as the result of the add/sub/mul is extended to the wide
+// type, no trunc is required regardless of how %b is generated. This pattern
+// is common when calculating addresses on 64-bit architectures.
+bool WidenIV::widenWithVariantUse(NarrowIVDefUse DU) {
   Instruction *NarrowUse = DU.NarrowUse;
   Instruction *NarrowDef = DU.NarrowDef;
   Instruction *WideDef = DU.WideDef;
@@ -1443,12 +1128,6 @@ bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) {
   else
     return false;
 
-  // We are interested in the other operand being a load instruction.
-  // But, we should look into relaxing this restriction later on.
-  auto *I = dyn_cast<Instruction>(NarrowUse->getOperand(ExtendOperIdx));
-  if (I && I->getOpcode() != Instruction::Load)
-    return false;
-
   // Verifying that Defining operand is an AddRec
   const SCEV *Op1 = SE->getSCEV(WideDef);
   const SCEVAddRecExpr *AddRecOp1 = dyn_cast<SCEVAddRecExpr>(Op1);
@@ -1480,9 +1159,9 @@ bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) {
   return true;
 }
 
-/// Special Case for widening with variant Loads (see
-/// WidenIV::widenWithVariantLoadUse). This is the code generation part.
-void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) {
+/// Special case for widening with a loop variant operand (see
+/// WidenIV::widenWithVariantUse). This is the code generation part.
+void WidenIV::widenWithVariantUseCodegen(NarrowIVDefUse DU) {
   Instruction *NarrowUse = DU.NarrowUse;
   Instruction *NarrowDef = DU.NarrowDef;
   Instruction *WideDef = DU.WideDef;
@@ -1508,33 +1187,22 @@ void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) {
   Builder.Insert(WideBO);
   WideBO->copyIRFlags(NarrowBO);
 
-  if (ExtKind == SignExtended)
-    ExtendKindMap[NarrowUse] = SignExtended;
-  else
-    ExtendKindMap[NarrowUse] = ZeroExtended;
+  assert(ExtKind != Unknown && "Unknown ExtKind not handled");
 
-  // Update the Use.
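As a sanity check on the identity the new comment relies on, here is a standalone snippet (illustrative only, ordinary C++ rather than IR) verifying that sign-extending the operands of a no-wrap narrow sub and subtracting wide matches sign-extending the narrow result; the constants are arbitrary:

#include <cassert>
#include <cstdint>

int main() {
  int8_t B = 100, IndVar = -27; // narrow i8 values; 100 - (-27) = 127 fits, so "nsw" holds
  int8_t C = B - IndVar;        //   %c = sub nsw i8 %b, %indvar
  int64_t D1 = int64_t(C);      //   %d = sext i8 %c to i64
  int64_t D2 = int64_t(B) - int64_t(IndVar); // widened form: sub of the sext'ed operands
  assert(D1 == 127 && D1 == D2);             // no trunc needed on the widened path
  return 0;
}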
-  if (ExtKind == SignExtended) {
-    for (Use &U : NarrowUse->uses()) {
-      SExtInst *User = dyn_cast<SExtInst>(U.getUser());
-      if (User && User->getType() == WideType) {
-        LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by "
-                          << *WideBO << "\n");
-        ++NumElimExt;
-        User->replaceAllUsesWith(WideBO);
-        DeadInsts.emplace_back(User);
-      }
-    }
-  } else { // ExtKind == ZeroExtended
-    for (Use &U : NarrowUse->uses()) {
-      ZExtInst *User = dyn_cast<ZExtInst>(U.getUser());
-      if (User && User->getType() == WideType) {
-        LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by "
-                          << *WideBO << "\n");
-        ++NumElimExt;
-        User->replaceAllUsesWith(WideBO);
-        DeadInsts.emplace_back(User);
-      }
+  ExtendKindMap[NarrowUse] = ExtKind;
+
+  for (Use &U : NarrowUse->uses()) {
+    Instruction *User = nullptr;
+    if (ExtKind == SignExtended)
+      User = dyn_cast<SExtInst>(U.getUser());
+    else
+      User = dyn_cast<ZExtInst>(U.getUser());
+    if (User && User->getType() == WideType) {
+      LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by "
+                        << *WideBO << "\n");
+      ++NumElimExt;
+      User->replaceAllUsesWith(WideBO);
+      DeadInsts.emplace_back(User);
    }
   }
 }
@@ -1641,8 +1309,8 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) {
     // in WideAddRec.first does not indicate a polynomial induction expression.
     // In that case, look at the operands of the use instruction to determine
     // if we can still widen the use instead of truncating its operand.
-    if (widenWithVariantLoadUse(DU)) {
-      widenWithVariantLoadUseCodegen(DU);
+    if (widenWithVariantUse(DU)) {
+      widenWithVariantUseCodegen(DU);
       return nullptr;
     }
 
@@ -1992,8 +1660,8 @@ bool IndVarSimplify::simplifyAndExtend(Loop *L,
       // Information about sign/zero extensions of CurrIV.
       IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT);
 
-      Changed |=
-          simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, Rewriter, &Visitor);
+      Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, TTI, DeadInsts, Rewriter,
+                                   &Visitor);
 
       if (Visitor.WI.WidestNativeType) {
         WideIVs.push_back(Visitor.WI);
@@ -2017,7 +1685,7 @@
 
 /// Given a Value which is hoped to be part of an add recurrence in the given
 /// loop, return the associated Phi node if so.  Otherwise, return null.  Note
-/// that this is less general than SCEVs AddRec checking.   
+/// that this is less general than SCEV's AddRec checking.
 static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L) {
   Instruction *IncI = dyn_cast<Instruction>(IncV);
   if (!IncI)
@@ -2079,7 +1747,7 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) {
   BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());
   if (L->isLoopInvariant(BI->getCondition()))
     return false;
-   
+
   // Do LFTR to simplify the exit condition to an ICMP.
   ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition());
   if (!Cond)
@@ -2122,9 +1790,9 @@ static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) {
 /// actually poison. 
This can be used to assess whether a new use of Root can
/// be added at a location which is control equivalent with OnPathTo (such as
/// immediately before it) without introducing UB which didn't previously
-/// exist.  Note that a false result conveys no information.   
+/// exist.  Note that a false result conveys no information.
 static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
-                                          Instruction *OnPathTo,  
+                                          Instruction *OnPathTo,
                                           DominatorTree *DT) {
   // Basic approach is to assume Root is poison, propagate poison forward
   // through all users we can easily track, and then check whether any of those
@@ -2142,10 +1810,10 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
     // If we know this must trigger UB on a path leading to our target.
     if (mustTriggerUB(I, KnownPoison) && DT->dominates(I, OnPathTo))
       return true;
-     
+
     // If we can't analyze propagation through this instruction, just skip it
    // and transitive users.  Safe as false is a conservative result.
-    if (!propagatesFullPoison(I) && I != Root)
+    if (!propagatesPoison(I) && I != Root)
      continue;
 
     if (KnownPoison.insert(I).second)
@@ -2154,7 +1822,7 @@ static bool mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
   }
 
   // Might be non-UB, or might have a path we couldn't prove must execute on
-  // way to exiting bb.  
+  // the way to the exiting bb.
   return false;
 }
@@ -2221,7 +1889,7 @@ static bool isLoopCounter(PHINode* Phi, Loop *L,
                           ScalarEvolution *SE) {
   assert(Phi->getParent() == L->getHeader());
   assert(L->getLoopLatch());
-   
+
   if (!SE->isSCEVable(Phi->getType()))
     return false;
@@ -2282,7 +1950,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
     if (!hasConcreteDef(Phi)) {
       // We explicitly allow unknown phis as long as they are already used by
       // the loop exit test.  This is legal since performing LFTR could not
-      // increase the number of undef users.  
+      // increase the number of undef users.
       Value *IncPhi = Phi->getIncomingValueForBlock(LatchBlock);
       if (!isLoopExitTestBasedOn(Phi, ExitingBB) &&
           !isLoopExitTestBasedOn(IncPhi, ExitingBB))
@@ -2300,7 +1968,7 @@ static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB,
     if (!Phi->getType()->isIntegerTy() &&
         !mustExecuteUBIfPoisonOnPathTo(Phi, ExitingBB->getTerminator(), DT))
       continue;
-     
+
     const SCEV *Init = AR->getStart();
 
     if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) {
@@ -2506,14 +2174,14 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,
     // reasoning as from SimplifyIndvar::eliminateTrunc to see if we can extend
     // the other side of the comparison instead.  We still evaluate the limit
     // in the narrower bitwidth, we just prefer a zext/sext outside the loop to
-    // a truncate within in.   
+    // a truncate within it.
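For readers less familiar with LFTR, here is a source-level sketch of the overall rewrite (hypothetical C++, assuming n >= 0; the real transform works on IR and SCEV expressions, and `body` is a stand-in):

#include <cstdio>

void body(int i) { printf("%d\n", i); }

void before(int n) {
  for (int i = 0; 4 * i < n; ++i) // exit test on a linear function of the IV
    body(i);
}

void after(int n) {
  int TripCount = (n + 3) / 4;         // exit count, computed once up front
  for (int i = 0; i != TripCount; ++i) // canonical "icmp ne" exit test
    body(i);
}

int main() { before(10); after(10); return 0; } // both print 0, 1, 2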
bool Extended = false;      const SCEV *IV = SE->getSCEV(CmpIndVar);      const SCEV *TruncatedIV = SE->getTruncateExpr(SE->getSCEV(CmpIndVar),                                                    ExitCnt->getType());      const SCEV *ZExtTrunc =        SE->getZeroExtendExpr(TruncatedIV, CmpIndVar->getType()); -     +      if (ZExtTrunc == IV) {        Extended = true;        ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(), @@ -2531,7 +2199,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,      if (Extended) {        bool Discard;        L->makeLoopInvariant(ExitCnt, Discard); -    } else  +    } else        CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(),                                        "lftr.wideiv");    } @@ -2551,7 +2219,7 @@ linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB,    // update the branch to use the new comparison; in the common case this    // will make old comparison dead.    BI->setCondition(Cond); -  DeadInsts.push_back(OrigCond); +  DeadInsts.emplace_back(OrigCond);    ++NumLFTR;    return true; @@ -2685,11 +2353,10 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {    L->getExitingBlocks(ExitingBlocks);    // Remove all exits which aren't both rewriteable and analyzeable. -  auto NewEnd = llvm::remove_if(ExitingBlocks, -                                [&](BasicBlock *ExitingBB) { +  auto NewEnd = llvm::remove_if(ExitingBlocks, [&](BasicBlock *ExitingBB) {      // If our exitting block exits multiple loops, we can only rewrite the      // innermost one.  Otherwise, we're changing how many times the innermost -    // loop runs before it exits.  +    // loop runs before it exits.      if (LI->getLoopFor(ExitingBB) != L)        return true; @@ -2701,18 +2368,18 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {      // If already constant, nothing to do.      if (isa<Constant>(BI->getCondition()))        return true; -     +      const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);      if (isa<SCEVCouldNotCompute>(ExitCount))        return true;      return false; -   }); +  });    ExitingBlocks.erase(NewEnd, ExitingBlocks.end());    if (ExitingBlocks.empty())      return false; -   -  // Get a symbolic upper bound on the loop backedge taken count.   + +  // Get a symbolic upper bound on the loop backedge taken count.    const SCEV *MaxExitCount = getMaxBackedgeTakenCount(*SE, *DT, L);    if (isa<SCEVCouldNotCompute>(MaxExitCount))      return false; @@ -2720,11 +2387,12 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {    // Visit our exit blocks in order of dominance.  We know from the fact that    // all exits (left) are analyzeable that the must be a total dominance order    // between them as each must dominate the latch.  The visit order only -  // matters for the provably equal case.   +  // matters for the provably equal case.    llvm::sort(ExitingBlocks,               [&](BasicBlock *A, BasicBlock *B) {                 // std::sort sorts in ascending order, so we want the inverse of                 // the normal dominance relation. 
+               if (A == B) return false;                 if (DT->properlyDominates(A, B)) return true;                 if (DT->properlyDominates(B, A)) return false;                 llvm_unreachable("expected total dominance order!"); @@ -2734,7 +2402,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {      assert(DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]));    }  #endif -   +    auto FoldExit = [&](BasicBlock *ExitingBB, bool IsTaken) {      BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator());      bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB)); @@ -2743,7 +2411,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {                                       IsTaken ? ExitIfTrue : !ExitIfTrue);      BI->setCondition(NewCond);      if (OldCond->use_empty()) -      DeadInsts.push_back(OldCond); +      DeadInsts.emplace_back(OldCond);    };    bool Changed = false; @@ -2751,7 +2419,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {    for (BasicBlock *ExitingBB : ExitingBlocks) {      const SCEV *ExitCount = SE->getExitCount(L, ExitingBB);      assert(!isa<SCEVCouldNotCompute>(ExitCount) && "checked above"); -     +      // If we know we'd exit on the first iteration, rewrite the exit to      // reflect this.  This does not imply the loop must exit through this      // exit; there may be an earlier one taken on the first iteration. @@ -2769,13 +2437,13 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {      if (!ExitCount->getType()->isIntegerTy() ||          !MaxExitCount->getType()->isIntegerTy())        continue; -     +      Type *WiderType =        SE->getWiderType(MaxExitCount->getType(), ExitCount->getType());      ExitCount = SE->getNoopOrZeroExtend(ExitCount, WiderType);      MaxExitCount = SE->getNoopOrZeroExtend(MaxExitCount, WiderType);      assert(MaxExitCount->getType() == ExitCount->getType()); -     +      // Can we prove that some other exit must be taken strictly before this      // one?      if (SE->isLoopEntryGuardedByCond(L, CmpInst::ICMP_ULT, @@ -2788,7 +2456,7 @@ bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) {      // As we run, keep track of which exit counts we've encountered.  If we      // find a duplicate, we've found an exit which would have exited on the      // exiting iteration, but (from the visit order) strictly follows another -    // which does the same and is thus dead.  +    // which does the same and is thus dead.      if (!DominatingExitCounts.insert(ExitCount).second) {        FoldExit(ExitingBB, false);        Changed = true; @@ -2809,22 +2477,20 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {    SmallVector<BasicBlock*, 16> ExitingBlocks;    L->getExitingBlocks(ExitingBlocks); -  bool Changed = false; -    // Finally, see if we can rewrite our exit conditions into a loop invariant -  // form.  If we have a read-only loop, and we can tell that we must exit down +  // form. If we have a read-only loop, and we can tell that we must exit down    // a path which does not need any of the values computed within the loop, we    // can rewrite the loop to exit on the first iteration.  Note that this    // doesn't either a) tell us the loop exits on the first iteration (unless    // *all* exits are predicateable) or b) tell us *which* exit might be taken.    
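The `if (A == B) return false;` added to this comparator keeps it irreflexive, which the strict-weak-ordering contract of llvm::sort requires; a self-comparison must answer "not less" rather than reach the unreachable. A standalone C++ illustration, with a toy dominance depth standing in for the dominator tree:

#include <algorithm>
#include <cassert>
#include <vector>

struct Block { int DomDepth; };                 // stand-in for a BasicBlock
static bool properlyDominates(const Block *A, const Block *B) {
  return A->DomDepth < B->DomDepth;             // toy total dominance order
}

int main() {
  std::vector<Block> Storage = {{2}, {0}, {1}};
  std::vector<Block *> Blocks = {&Storage[0], &Storage[1], &Storage[2]};
  std::sort(Blocks.begin(), Blocks.end(), [](Block *A, Block *B) {
    if (A == B) return false;                   // irreflexive: self-compare is not "less"
    if (properlyDominates(A, B)) return true;
    if (properlyDominates(B, A)) return false;
    assert(false && "expected total dominance order!");
    return false;
  });
  assert(Blocks.front()->DomDepth == 0 && Blocks.back()->DomDepth == 2);
  return 0;
}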
// This transformation looks a lot like a restricted form of dead loop
   // elimination, but restricted to read-only loops and without necessarily
-  // needing to kill the loop entirely.  
+  // needing to kill the loop entirely.
   if (!LoopPredication)
-    return Changed;
+    return false;
 
   if (!SE->hasLoopInvariantBackedgeTakenCount(L))
-    return Changed;
+    return false;
 
   // Note: ExactBTC is the exact backedge taken count *iff* the loop exits
   // through *explicit* control flow.  We have to eliminate the possibility of
@@ -2833,16 +2499,16 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
   if (isa<SCEVCouldNotCompute>(ExactBTC) ||
       !SE->isLoopInvariant(ExactBTC, L) ||
       !isSafeToExpand(ExactBTC, *SE))
-    return Changed;
+    return false;
 
   // If we end up with a pointer exit count, bail.  It may be unsized.
   if (!ExactBTC->getType()->isIntegerTy())
-    return Changed;
+    return false;
 
   auto BadExit = [&](BasicBlock *ExitingBB) {
     // If our exiting block exits multiple loops, we can only rewrite the
     // innermost one.  Otherwise, we're changing how many times the innermost
-    // loop runs before it exits.  
+    // loop runs before it exits.
     if (LI->getLoopFor(ExitingBB) != L)
       return true;
 
@@ -2897,18 +2563,18 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
   // is complicated and we choose not to for now.
   for (unsigned i = 1; i < ExitingBlocks.size(); i++)
     if (!DT->dominates(ExitingBlocks[i-1], ExitingBlocks[i]))
-      return Changed;
+      return false;
 
   // Given our sorted total order, we know that exit[j] must be evaluated
   // after all exit[i] such that j > i.
   for (unsigned i = 0, e = ExitingBlocks.size(); i < e; i++)
     if (BadExit(ExitingBlocks[i])) {
-      ExitingBlocks.resize(i);   
+      ExitingBlocks.resize(i);
       break;
     }
 
   if (ExitingBlocks.empty())
-    return Changed;
+    return false;
 
   // We rely on not being able to reach an exiting block on a later iteration
   // than its statically computed exit count.  The implementation of
@@ -2930,8 +2596,9 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
     for (auto &I : *BB)
       // TODO:isGuaranteedToTransfer
       if (I.mayHaveSideEffects() || I.mayThrow())
-        return Changed;
+        return false;
 
+  bool Changed = false;
   // Finally, do the actual predication for all predicatable blocks.  A couple
   // of notes here:
   // 1) We don't bother to constant fold dominated exits with identical exit
@@ -2970,7 +2637,7 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
     Value *OldCond = BI->getCondition();
     BI->setCondition(NewCond);
     if (OldCond->use_empty())
-      DeadInsts.push_back(OldCond);
+      DeadInsts.emplace_back(OldCond);
     Changed = true;
   }
 
@@ -2985,7 +2652,6 @@ bool IndVarSimplify::run(Loop *L) {
   // We need (and expect!) the incoming loop to be in LCSSA.
   assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
          "LCSSA required to run indvars!");
-  bool Changed = false;
 
   // If LoopSimplify form is not available, stay out of trouble. Some notes:
   //  - LSR currently only supports LoopSimplify-form loops.
Indvars' @@ -3001,9 +2667,15 @@ bool IndVarSimplify::run(Loop *L) {  #ifndef NDEBUG    // Used below for a consistency check only -  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); +  // Note: Since the result returned by ScalarEvolution may depend on the order +  // in which previous results are added to its cache, the call to +  // getBackedgeTakenCount() may change following SCEV queries. +  const SCEV *BackedgeTakenCount; +  if (VerifyIndvars) +    BackedgeTakenCount = SE->getBackedgeTakenCount(L);  #endif +  bool Changed = false;    // If there are any floating-point recurrences, attempt to    // transform them to use integer recurrences.    Changed |= rewriteNonIntegerIVs(L); @@ -3027,8 +2699,13 @@ bool IndVarSimplify::run(Loop *L) {    // that are recurrent in the loop, and substitute the exit values from the    // loop into any instructions outside of the loop that use the final values    // of the current expressions. -  if (ReplaceExitValue != NeverRepl) -    Changed |= rewriteLoopExitValues(L, Rewriter); +  if (ReplaceExitValue != NeverRepl) { +    if (int Rewrites = rewriteLoopExitValues(L, LI, TLI, SE, TTI, Rewriter, DT, +                                             ReplaceExitValue, DeadInsts)) { +      NumReplaced += Rewrites; +      Changed = true; +    } +  }    // Eliminate redundant IV cycles.    NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); @@ -3039,7 +2716,7 @@ bool IndVarSimplify::run(Loop *L) {      // Given we've changed exit counts, notify SCEV      SE->forgetLoop(L);    } -   +    // Try to form loop invariant tests for loop exits by changing how many    // iterations of the loop run when that is unobservable.    if (predicateLoopExits(L, Rewriter)) { @@ -3049,8 +2726,11 @@ bool IndVarSimplify::run(Loop *L) {    }    // If we have a trip count expression, rewrite the loop's exit condition -  // using it.   +  // using it.    if (!DisableLFTR) { +    BasicBlock *PreHeader = L->getLoopPreheader(); +    BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator()); +      SmallVector<BasicBlock*, 16> ExitingBlocks;      L->getExitingBlocks(ExitingBlocks);      for (BasicBlock *ExitingBB : ExitingBlocks) { @@ -3060,10 +2740,10 @@ bool IndVarSimplify::run(Loop *L) {        // If our exitting block exits multiple loops, we can only rewrite the        // innermost one.  Otherwise, we're changing how many times the innermost -      // loop runs before it exits.  +      // loop runs before it exits.        if (LI->getLoopFor(ExitingBB) != L)          continue; -       +        if (!needsLFTR(L, ExitingBB))          continue; @@ -3077,14 +2757,15 @@ bool IndVarSimplify::run(Loop *L) {        // until stable to handle cases like this better.        if (ExitCount->isZero())          continue; -       +        PHINode *IndVar = FindLoopCounter(L, ExitingBB, ExitCount, SE, DT);        if (!IndVar)          continue; -       +        // Avoid high cost expansions.  Note: This heuristic is questionable in -      // that our definition of "high cost" is not exactly principled.   -      if (Rewriter.isHighCostExpansion(ExitCount, L)) +      // that our definition of "high cost" is not exactly principled. +      if (Rewriter.isHighCostExpansion(ExitCount, L, SCEVCheapExpansionBudget, +                                       TTI, PreHeaderBR))          continue;        // Check preconditions for proper SCEVExpander operation. 
SCEV does not @@ -3092,7 +2773,7 @@ bool IndVarSimplify::run(Loop *L) {        // any pass that uses the SCEVExpander must do it. This does not work        // well for loop passes because SCEVExpander makes assumptions about        // all loops, while LoopPassManager only forces the current loop to be -      // simplified.  +      // simplified.        //        // FIXME: SCEV expansion has no way to bail out, so the caller must        // explicitly check any assumptions made by SCEV. Brittle. @@ -3113,7 +2794,8 @@ bool IndVarSimplify::run(Loop *L) {    while (!DeadInsts.empty())      if (Instruction *Inst =              dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val())) -      Changed |= RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI); +      Changed |= +          RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI, MSSAU.get());    // The Rewriter may not be used from this point on. @@ -3127,7 +2809,7 @@ bool IndVarSimplify::run(Loop *L) {    Changed |= rewriteFirstIterationLoopExitValues(L);    // Clean up dead instructions. -  Changed |= DeleteDeadPHIs(L->getHeader(), TLI); +  Changed |= DeleteDeadPHIs(L->getHeader(), TLI, MSSAU.get());    // Check a post-condition.    assert(L->isRecursivelyLCSSAForm(*DT, *LI) && @@ -3150,6 +2832,8 @@ bool IndVarSimplify::run(Loop *L) {      assert(!SE->isKnownPredicate(ICmpInst::ICMP_ULT, BackedgeTakenCount,                                   NewBECount) && "indvars must preserve SCEV");    } +  if (VerifyMemorySSA && MSSAU) +    MSSAU->getMemorySSA()->verifyMemorySSA();  #endif    return Changed; @@ -3161,12 +2845,14 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,    Function *F = L.getHeader()->getParent();    const DataLayout &DL = F->getParent()->getDataLayout(); -  IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI); +  IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA);    if (!IVS.run(&L))      return PreservedAnalyses::all();    auto PA = getLoopPassPreservedAnalyses();    PA.preserveSet<CFGAnalyses>(); +  if (AR.MSSA) +    PA.preserve<MemorySSAAnalysis>();    return PA;  } @@ -3191,13 +2877,18 @@ struct IndVarSimplifyLegacyPass : public LoopPass {      auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();      auto *TTI = TTIP ? 
&TTIP->getTTI(*L->getHeader()->getParent()) : nullptr;      const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); +    auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); +    MemorySSA *MSSA = nullptr; +    if (MSSAAnalysis) +      MSSA = &MSSAAnalysis->getMSSA(); -    IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); +    IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA);      return IVS.run(L);    }    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.setPreservesCFG(); +    AU.addPreserved<MemorySSAWrapperPass>();      getLoopAnalysisUsage(AU);    }  }; diff --git a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 58469749600e..30e4822b6769 100644 --- a/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -47,6 +47,7 @@  #include "llvm/ADT/ArrayRef.h"  #include "llvm/ADT/None.h"  #include "llvm/ADT/Optional.h" +#include "llvm/ADT/PriorityWorklist.h"  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/StringRef.h" @@ -55,8 +56,8 @@  #include "llvm/Analysis/LoopAnalysisManager.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/PostDominators.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/CFG.h" @@ -87,6 +88,7 @@  #include "llvm/Transforms/Utils/Cloning.h"  #include "llvm/Transforms/Utils/LoopSimplify.h"  #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include "llvm/Transforms/Utils/ValueMapper.h"  #include <algorithm>  #include <cassert> @@ -242,20 +244,25 @@ public:    bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);  }; -class IRCELegacyPass : public LoopPass { +class IRCELegacyPass : public FunctionPass {  public:    static char ID; -  IRCELegacyPass() : LoopPass(ID) { +  IRCELegacyPass() : FunctionPass(ID) {      initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry());    }    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.addRequired<BranchProbabilityInfoWrapperPass>(); -    getLoopAnalysisUsage(AU); +    AU.addRequired<DominatorTreeWrapperPass>(); +    AU.addPreserved<DominatorTreeWrapperPass>(); +    AU.addRequired<LoopInfoWrapperPass>(); +    AU.addPreserved<LoopInfoWrapperPass>(); +    AU.addRequired<ScalarEvolutionWrapperPass>(); +    AU.addPreserved<ScalarEvolutionWrapperPass>();    } -  bool runOnLoop(Loop *L, LPPassManager &LPM) override; +  bool runOnFunction(Function &F) override;  };  } // end anonymous namespace @@ -265,7 +272,9 @@ char IRCELegacyPass::ID = 0;  INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce",                        "Inductive range check elimination", false, false)  INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)  INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination",                      false, false) @@ -866,7 +875,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,    const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);    const SCEV *Step = 
SE.getSCEV(StepCI);
-  ConstantInt *One = ConstantInt::get(IndVarTy, 1);
+  const SCEV *FixedRightSCEV = nullptr;
+
+  // If RightValue resides within the loop (while still being loop-invariant),
+  // regenerate it in the preheader.
+  if (auto *I = dyn_cast<Instruction>(RightValue))
+    if (L.contains(I->getParent()))
+      FixedRightSCEV = RightSCEV;
+
   if (IsIncreasing) {
     bool DecreasedRightValueByOne = false;
     if (StepCI->isOne()) {
@@ -928,10 +944,9 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
     if (LatchBrExitIdx == 0) {
       // We need to increase the right value unless we have already decreased
       // it virtually when we replaced EQ with SGT.
-      if (!DecreasedRightValueByOne) {
-        IRBuilder<> B(Preheader->getTerminator());
-        RightValue = B.CreateAdd(RightValue, One);
-      }
+      if (!DecreasedRightValueByOne)
+        FixedRightSCEV =
+            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
     } else {
       assert(!DecreasedRightValueByOne &&
              "Right value can be decreased only for LatchBrExitIdx == 0!");
@@ -995,10 +1010,9 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
     if (LatchBrExitIdx == 0) {
       // We need to decrease the right value unless we have already increased
       // it virtually when we replaced EQ with SLT.
-      if (!IncreasedRightValueByOne) {
-        IRBuilder<> B(Preheader->getTerminator());
-        RightValue = B.CreateSub(RightValue, One);
-      }
+      if (!IncreasedRightValueByOne)
+        FixedRightSCEV =
+            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
     } else {
       assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
@@ -1012,9 +1026,14 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE,
   assert(!L.contains(LatchExit) && "expected an exit block!");
   const DataLayout &DL = Preheader->getModule()->getDataLayout();
-  Value *IndVarStartV =
-      SCEVExpander(SE, DL, "irce")
-          .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator());
+  SCEVExpander Expander(SE, DL, "irce");
+  Instruction *Ins = Preheader->getTerminator();
+
+  if (FixedRightSCEV)
+    RightValue =
+        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
+
+  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
   IndVarStartV->setName("indvar.start");
 
   LoopStructure Result;
@@ -1747,27 +1766,41 @@ IntersectUnsignedRange(ScalarEvolution &SE,
   return Ret;
 }
 
-PreservedAnalyses IRCEPass::run(Loop &L, LoopAnalysisManager &AM,
-                                LoopStandardAnalysisResults &AR,
-                                LPMUpdater &U) {
-  Function *F = L.getHeader()->getParent();
-  const auto &FAM =
-      AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
-  auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
-  InductiveRangeCheckElimination IRCE(AR.SE, BPI, AR.DT, AR.LI);
-  auto LPMAddNewLoop = [&U](Loop *NL, bool IsSubloop) {
+PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) {
+  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F);
+  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
+
+  InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI);
+
+  bool Changed = false;
+
+  for (const auto &L : LI) {
+    Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr,
nullptr, +                            /*PreserveLCSSA=*/false); +    Changed |= formLCSSARecursively(*L, DT, &LI, &SE); +  } + +  SmallPriorityWorklist<Loop *, 4> Worklist; +  appendLoopsToWorklist(LI, Worklist); +  auto LPMAddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) {      if (!IsSubloop) -      U.addSiblingLoops(NL); +      appendLoopsToWorklist(*NL, Worklist);    }; -  bool Changed = IRCE.run(&L, LPMAddNewLoop); + +  while (!Worklist.empty()) { +    Loop *L = Worklist.pop_back_val(); +    Changed |= IRCE.run(L, LPMAddNewLoop); +  } +    if (!Changed)      return PreservedAnalyses::all(); -    return getLoopPassPreservedAnalyses();  } -bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { -  if (skipLoop(L)) +bool IRCELegacyPass::runOnFunction(Function &F) { +  if (skipFunction(F))      return false;    ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); @@ -1776,10 +1809,27 @@ bool IRCELegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();    InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI); -  auto LPMAddNewLoop = [&LPM](Loop *NL, bool /* IsSubLoop */) { -    LPM.addLoop(*NL); + +  bool Changed = false; + +  for (const auto &L : LI) { +    Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, +                            /*PreserveLCSSA=*/false); +    Changed |= formLCSSARecursively(*L, DT, &LI, &SE); +  } + +  SmallPriorityWorklist<Loop *, 4> Worklist; +  appendLoopsToWorklist(LI, Worklist); +  auto LPMAddNewLoop = [&](Loop *NL, bool IsSubloop) { +    if (!IsSubloop) +      appendLoopsToWorklist(*NL, Worklist);    }; -  return IRCE.run(L, LPMAddNewLoop); + +  while (!Worklist.empty()) { +    Loop *L = Worklist.pop_back_val(); +    Changed |= IRCE.run(L, LPMAddNewLoop); +  } +  return Changed;  }  bool InductiveRangeCheckElimination::run( diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp index dfb1b6bfb739..db9cc58bbfc4 100644 --- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -96,7 +96,6 @@  #include "llvm/ADT/SetVector.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h"  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/Constant.h"  #include "llvm/IR/Constants.h" @@ -116,11 +115,13 @@  #include "llvm/IR/ValueHandle.h"  #include "llvm/Pass.h"  #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h"  #include "llvm/Support/Compiler.h"  #include "llvm/Support/Debug.h"  #include "llvm/Support/ErrorHandling.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/ValueMapper.h"  #include <cassert>  #include <iterator> @@ -132,16 +133,23 @@  using namespace llvm; +static cl::opt<bool> AssumeDefaultIsFlatAddressSpace( +    "assume-default-is-flat-addrspace", cl::init(false), cl::ReallyHidden, +    cl::desc("The default address space is assumed as the flat address space. 
" +             "This is mainly for test purpose.")); +  static const unsigned UninitializedAddressSpace =      std::numeric_limits<unsigned>::max();  namespace {  using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; +using PostorderStackTy = llvm::SmallVector<PointerIntPair<Value *, 1, bool>, 4>;  /// InferAddressSpaces  class InferAddressSpaces : public FunctionPass {    const TargetTransformInfo *TTI = nullptr; +  const DataLayout *DL = nullptr;    /// Target specific address space which uses of should be replaced if    /// possible. @@ -174,6 +182,11 @@ private:    bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const; +  Value *cloneInstructionWithNewAddressSpace( +      Instruction *I, unsigned NewAddrSpace, +      const ValueToValueMapTy &ValueWithNewAddrSpace, +      SmallVectorImpl<const Use *> *UndefUsesToFix) const; +    // Changes the flat address expressions in function F to point to specific    // address spaces if InferredAddrSpace says so. Postorder is the postorder of    // all flat expressions in the use-def graph of function F. @@ -182,15 +195,14 @@ private:        const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const;    void appendsFlatAddressExpressionToPostorderStack( -    Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, -    DenseSet<Value *> &Visited) const; +      Value *V, PostorderStackTy &PostorderStack, +      DenseSet<Value *> &Visited) const;    bool rewriteIntrinsicOperands(IntrinsicInst *II,                                  Value *OldV, Value *NewV) const; -  void collectRewritableIntrinsicOperands( -    IntrinsicInst *II, -    std::vector<std::pair<Value *, bool>> &PostorderStack, -    DenseSet<Value *> &Visited) const; +  void collectRewritableIntrinsicOperands(IntrinsicInst *II, +                                          PostorderStackTy &PostorderStack, +                                          DenseSet<Value *> &Visited) const;    std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const; @@ -214,24 +226,65 @@ void initializeInferAddressSpacesPass(PassRegistry &);  INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",                  false, false) +// Check whether that's no-op pointer bicast using a pair of +// `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over +// different address spaces. +static bool isNoopPtrIntCastPair(const Operator *I2P, const DataLayout &DL, +                                 const TargetTransformInfo *TTI) { +  assert(I2P->getOpcode() == Instruction::IntToPtr); +  auto *P2I = dyn_cast<Operator>(I2P->getOperand(0)); +  if (!P2I || P2I->getOpcode() != Instruction::PtrToInt) +    return false; +  // Check it's really safe to treat that pair of `ptrtoint`/`inttoptr` as a +  // no-op cast. Besides checking both of them are no-op casts, as the +  // reinterpreted pointer may be used in other pointer arithmetic, we also +  // need to double-check that through the target-specific hook. That ensures +  // the underlying target also agrees that's a no-op address space cast and +  // pointer bits are preserved. +  // The current IR spec doesn't have clear rules on address space casts, +  // especially a clear definition for pointer bits in non-default address +  // spaces. It would be undefined if that pointer is dereferenced after an +  // invalid reinterpret cast. 
Also, due to this lack of clarity about the meaning of
+  // pointer bits in non-default address spaces in the current spec, such
+  // pointer arithmetic may also be undefined after an invalid pointer
+  // reinterpret cast.
+  // However, as we confirm through the target hooks that it's a no-op
+  // addrspacecast, it doesn't matter since the bits should be the same.
+  return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()),
+                              I2P->getOperand(0)->getType(), I2P->getType(),
+                              DL) &&
+         CastInst::isNoopCast(Instruction::CastOps(P2I->getOpcode()),
+                              P2I->getOperand(0)->getType(), P2I->getType(),
+                              DL) &&
+         TTI->isNoopAddrSpaceCast(
+             P2I->getOperand(0)->getType()->getPointerAddressSpace(),
+             I2P->getType()->getPointerAddressSpace());
+}
+
 // Returns true if V is an address expression.
 // TODO: Currently, we consider only phi, bitcast, addrspacecast, and
 // getelementptr operators.
-static bool isAddressExpression(const Value &V) {
-  if (!isa<Operator>(V))
+static bool isAddressExpression(const Value &V, const DataLayout &DL,
+                                const TargetTransformInfo *TTI) {
+  const Operator *Op = dyn_cast<Operator>(&V);
+  if (!Op)
     return false;
 
-  const Operator &Op = cast<Operator>(V);
-  switch (Op.getOpcode()) {
+  switch (Op->getOpcode()) {
   case Instruction::PHI:
-    assert(Op.getType()->isPointerTy());
+    assert(Op->getType()->isPointerTy());
     return true;
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
   case Instruction::GetElementPtr:
     return true;
   case Instruction::Select:
-    return Op.getType()->isPointerTy();
+    return Op->getType()->isPointerTy();
+  case Instruction::Call: {
+    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
+    return II && II->getIntrinsicID() == Intrinsic::ptrmask;
+  }
+  case Instruction::IntToPtr:
+    return isNoopPtrIntCastPair(Op, DL, TTI);
   default:
     return false;
   }
@@ -240,7 +293,9 @@ static bool isAddressExpression(const Value &V) {
 // Returns the pointer operands of V.
 //
 // Precondition: V is an address expression.
-static SmallVector<Value *, 2> getPointerOperands(const Value &V) {
+static SmallVector<Value *, 2>
+getPointerOperands(const Value &V, const DataLayout &DL,
+                   const TargetTransformInfo *TTI) {
   const Operator &Op = cast<Operator>(V);
   switch (Op.getOpcode()) {
   case Instruction::PHI: {
@@ -254,12 +309,22 @@ static SmallVector<Value *, 2> getPointerOperands(const Value &V) {
     return {Op.getOperand(0)};
   case Instruction::Select:
     return {Op.getOperand(1), Op.getOperand(2)};
+  case Instruction::Call: {
+    const IntrinsicInst &II = cast<IntrinsicInst>(Op);
+    assert(II.getIntrinsicID() == Intrinsic::ptrmask &&
+           "unexpected intrinsic call");
+    return {II.getArgOperand(0)};
+  }
+  case Instruction::IntToPtr: {
+    assert(isNoopPtrIntCastPair(&Op, DL, TTI));
+    auto *P2I = cast<Operator>(Op.getOperand(0));
+    return {P2I->getOperand(0)};
+  }
   default:
     llvm_unreachable("Unexpected instruction type.");
   }
 }
 
-// TODO: Move logic to TTI?
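The C++ analogue of the pair being matched may clarify why it can be a no-op (a sketch, not the pass's code): a pointer-to-integer round trip is only a reinterpretation, with every bit preserved, when the integer type is exactly pointer-sized; the extra TTI hook then confirms the two address spaces really share that bit pattern on the target.

#include <cassert>
#include <cstdint>

int main() {
  int X = 42;
  int *P = &X;
  uintptr_t Bits = reinterpret_cast<uintptr_t>(P); // "ptrtoint": full width, no-op
  int *Q = reinterpret_cast<int *>(Bits);          // "inttoptr": same bits back
  assert(Q == P && *Q == 42);                      // the round trip preserved the pointer
  return 0;
}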
bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,                                                    Value *OldV,                                                    Value *NewV) const { @@ -275,16 +340,26 @@ bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II,      II->setCalledFunction(NewDecl);      return true;    } -  default: -    return TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); +  case Intrinsic::ptrmask: +    // This is handled as an address expression, not as a use memory operation. +    return false; +  default: { +    Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV); +    if (!Rewrite) +      return false; +    if (Rewrite != II) +      II->replaceAllUsesWith(Rewrite); +    return true; +  }    }  }  void InferAddressSpaces::collectRewritableIntrinsicOperands( -    IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack, +    IntrinsicInst *II, PostorderStackTy &PostorderStack,      DenseSet<Value *> &Visited) const {    auto IID = II->getIntrinsicID();    switch (IID) { +  case Intrinsic::ptrmask:    case Intrinsic::objectsize:      appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),                                                   PostorderStack, Visited); @@ -305,7 +380,7 @@ void InferAddressSpaces::collectRewritableIntrinsicOperands(  // If V is an unvisited flat address expression, appends V to PostorderStack  // and marks it as visited.  void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack( -    Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, +    Value *V, PostorderStackTy &PostorderStack,      DenseSet<Value *> &Visited) const {    assert(V->getType()->isPointerTy()); @@ -313,21 +388,21 @@ void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack(    // expressions.    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {      // TODO: Look in non-address parts, like icmp operands. -    if (isAddressExpression(*CE) && Visited.insert(CE).second) -      PostorderStack.push_back(std::make_pair(CE, false)); +    if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second) +      PostorderStack.emplace_back(CE, false);      return;    } -  if (isAddressExpression(*V) && +  if (isAddressExpression(*V, *DL, TTI) &&        V->getType()->getPointerAddressSpace() == FlatAddrSpace) {      if (Visited.insert(V).second) { -      PostorderStack.push_back(std::make_pair(V, false)); +      PostorderStack.emplace_back(V, false);        Operator *Op = cast<Operator>(V);        for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) {          if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) { -          if (isAddressExpression(*CE) && Visited.insert(CE).second) +          if (isAddressExpression(*CE, *DL, TTI) && Visited.insert(CE).second)              PostorderStack.emplace_back(CE, false);          }        } @@ -341,7 +416,7 @@ std::vector<WeakTrackingVH>  InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {    // This function implements a non-recursive postorder traversal of a partial    // use-def graph of function F. -  std::vector<std::pair<Value *, bool>> PostorderStack; +  PostorderStackTy PostorderStack;    // The set of visited expressions.    
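As a standalone illustration of this traversal shape (plain C++ with std::pair in place of PointerIntPair; a sketch, not the pass's code): each node is pushed unexpanded, expanded exactly once, and emitted when popped the second time, which yields a postorder without recursion.

#include <set>
#include <utility>
#include <vector>

struct Node { std::vector<Node *> Operands; };

std::vector<Node *> postorder(Node *Root) {
  std::vector<std::pair<Node *, bool>> Stack{{Root, false}};
  std::set<Node *> Visited{Root};
  std::vector<Node *> Order;
  while (!Stack.empty()) {
    Node *N = Stack.back().first;
    if (Stack.back().second) {  // operands already explored: emit the node
      Order.push_back(N);
      Stack.pop_back();
      continue;
    }
    Stack.back().second = true; // mark expanded, then push unvisited operands
    for (Node *Op : N->Operands)
      if (Visited.insert(Op).second)
        Stack.emplace_back(Op, false);
  }
  return Order;
}

int main() {
  Node A, B, C;                 // A uses B and C; operands come out first
  A.Operands = {&B, &C};
  std::vector<Node *> O = postorder(&A);
  return O.back() == &A ? 0 : 1;
}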
DenseSet<Value *> Visited; @@ -383,23 +458,27 @@ InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {      } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) {        if (!ASC->getType()->isVectorTy())          PushPtrOperand(ASC->getPointerOperand()); +    } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) { +      if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI)) +        PushPtrOperand( +            cast<PtrToIntInst>(I2P->getOperand(0))->getPointerOperand());      }    }    std::vector<WeakTrackingVH> Postorder; // The resultant postorder.    while (!PostorderStack.empty()) { -    Value *TopVal = PostorderStack.back().first; +    Value *TopVal = PostorderStack.back().getPointer();      // If the operands of the expression on the top are already explored,      // adds that expression to the resultant postorder. -    if (PostorderStack.back().second) { +    if (PostorderStack.back().getInt()) {        if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace)          Postorder.push_back(TopVal);        PostorderStack.pop_back();        continue;      }      // Otherwise, adds its operands to the stack and explores them. -    PostorderStack.back().second = true; -    for (Value *PtrOperand : getPointerOperands(*TopVal)) { +    PostorderStack.back().setInt(true); +    for (Value *PtrOperand : getPointerOperands(*TopVal, *DL, TTI)) {        appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack,                                                     Visited);      } @@ -438,10 +517,13 @@ static Value *operandWithNewAddressSpaceOrCreateUndef(  // Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast  // from a pointer whose type already matches. Therefore, this function returns a  // Value* instead of an Instruction*. -static Value *cloneInstructionWithNewAddressSpace( +// +// This may also return nullptr in the case the instruction could not be +// rewritten. +Value *InferAddressSpaces::cloneInstructionWithNewAddressSpace(      Instruction *I, unsigned NewAddrSpace,      const ValueToValueMapTy &ValueWithNewAddrSpace, -    SmallVectorImpl<const Use *> *UndefUsesToFix) { +    SmallVectorImpl<const Use *> *UndefUsesToFix) const {    Type *NewPtrType =        I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); @@ -456,6 +538,23 @@ static Value *cloneInstructionWithNewAddressSpace(      return Src;    } +  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { +    // Technically the intrinsic ID is a pointer typed argument, so specially +    // handle calls early. +    assert(II->getIntrinsicID() == Intrinsic::ptrmask); +    Value *NewPtr = operandWithNewAddressSpaceOrCreateUndef( +        II->getArgOperandUse(0), NewAddrSpace, ValueWithNewAddrSpace, +        UndefUsesToFix); +    Value *Rewrite = +        TTI->rewriteIntrinsicWithAddressSpace(II, II->getArgOperand(0), NewPtr); +    if (Rewrite) { +      assert(Rewrite != II && "cannot modify this pointer operation in place"); +      return Rewrite; +    } + +    return nullptr; +  } +    // Computes the converted pointer operands.    
SmallVector<Value *, 4> NewPointerOperands;    for (const Use &OperandUse : I->operands()) { @@ -492,6 +591,14 @@ static Value *cloneInstructionWithNewAddressSpace(      assert(I->getType()->isPointerTy());      return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],                                NewPointerOperands[2], "", nullptr, I); +  case Instruction::IntToPtr: { +    assert(isNoopPtrIntCastPair(cast<Operator>(I), *DL, TTI)); +    Value *Src = cast<Operator>(I->getOperand(0))->getOperand(0); +    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); +    if (Src->getType() != NewPtrType) +      return new BitCastInst(Src, NewPtrType); +    return Src; +  }    default:      llvm_unreachable("Unexpected opcode");    } @@ -501,8 +608,9 @@ static Value *cloneInstructionWithNewAddressSpace(  // constant expression `CE` with its operands replaced as specified in  // ValueWithNewAddrSpace.  static Value *cloneConstantExprWithNewAddressSpace( -  ConstantExpr *CE, unsigned NewAddrSpace, -  const ValueToValueMapTy &ValueWithNewAddrSpace) { +    ConstantExpr *CE, unsigned NewAddrSpace, +    const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL, +    const TargetTransformInfo *TTI) {    Type *TargetType =      CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); @@ -533,6 +641,13 @@ static Value *cloneConstantExprWithNewAddressSpace(      }    } +  if (CE->getOpcode() == Instruction::IntToPtr) { +    assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI)); +    Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0); +    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); +    return ConstantExpr::getBitCast(Src, TargetType); +  } +    // Computes the operands of the new constant expression.    bool IsNew = false;    SmallVector<Constant *, 4> NewOperands; @@ -550,7 +665,7 @@ static Value *cloneConstantExprWithNewAddressSpace(      }      if (auto CExpr = dyn_cast<ConstantExpr>(Operand))        if (Value *NewOperand = cloneConstantExprWithNewAddressSpace( -              CExpr, NewAddrSpace, ValueWithNewAddrSpace)) { +              CExpr, NewAddrSpace, ValueWithNewAddrSpace, DL, TTI)) {          IsNew = true;          NewOperands.push_back(cast<Constant>(NewOperand));          continue; @@ -585,13 +700,13 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace(    const ValueToValueMapTy &ValueWithNewAddrSpace,    SmallVectorImpl<const Use *> *UndefUsesToFix) const {    // All values in Postorder are flat address expressions. 
-  assert(isAddressExpression(*V) && +  assert(isAddressExpression(*V, *DL, TTI) &&           V->getType()->getPointerAddressSpace() == FlatAddrSpace);    if (Instruction *I = dyn_cast<Instruction>(V)) {      Value *NewV = cloneInstructionWithNewAddressSpace(        I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix); -    if (Instruction *NewI = dyn_cast<Instruction>(NewV)) { +    if (Instruction *NewI = dyn_cast_or_null<Instruction>(NewV)) {        if (NewI->getParent() == nullptr) {          NewI->insertBefore(I);          NewI->takeName(I); @@ -601,7 +716,7 @@ Value *InferAddressSpaces::cloneValueWithNewAddressSpace(    }    return cloneConstantExprWithNewAddressSpace( -    cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace); +      cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace, DL, TTI);  }  // Defines the join operation on the address space lattice (see the file header @@ -625,6 +740,10 @@ bool InferAddressSpaces::runOnFunction(Function &F) {      return false;    TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); +  DL = &F.getParent()->getDataLayout(); + +  if (AssumeDefaultIsFlatAddressSpace) +    FlatAddrSpace = 0;    if (FlatAddrSpace == UninitializedAddressSpace) {      FlatAddrSpace = TTI->getFlatAddressSpace(); @@ -729,7 +848,7 @@ Optional<unsigned> InferAddressSpaces::updateAddressSpace(      else        NewAS = joinAddressSpaces(Src0AS, Src1AS);    } else { -    for (Value *PtrOperand : getPointerOperands(V)) { +    for (Value *PtrOperand : getPointerOperands(V, *DL, TTI)) {        auto I = InferredAddrSpace.find(PtrOperand);        unsigned OperandAS = I != InferredAddrSpace.end() ?          I->second : PtrOperand->getType()->getPointerAddressSpace(); @@ -879,8 +998,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces(    for (Value* V : Postorder) {      unsigned NewAddrSpace = InferredAddrSpace.lookup(V);      if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { -      ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace( -        V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); +      Value *New = cloneValueWithNewAddressSpace( +          V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); +      if (New) +        ValueWithNewAddrSpace[V] = New;      }    } @@ -890,7 +1011,10 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces(    // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace.    for (const Use *UndefUse : UndefUsesToFix) {      User *V = UndefUse->getUser(); -    User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V)); +    User *NewV = cast_or_null<User>(ValueWithNewAddrSpace.lookup(V)); +    if (!NewV) +      continue; +      unsigned OperandNo = UndefUse->getOperandNo();      assert(isa<UndefValue>(NewV->getOperand(OperandNo)));      NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get())); diff --git a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp index e8bbf2936da6..e87b622ab19f 100644 --- a/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp +++ b/llvm/lib/Transforms/Scalar/InstSimplifyPass.cpp @@ -40,7 +40,7 @@ static bool runImpl(Function &F, const SimplifyQuery &SQ,        if (!SQ.DT->isReachableFromEntry(&BB))          continue; -      SmallVector<Instruction *, 8> DeadInstsInBB; +      SmallVector<WeakTrackingVH, 8> DeadInstsInBB;        for (Instruction &I : BB) {          // The first time through the loop, ToSimplify is empty and we try to          // simplify all instructions. 
On later iterations, ToSimplify is not
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 98c2fcb3dae0..9d0500419a7f 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -13,6 +13,7 @@
 #include "llvm/Transforms/Scalar/JumpThreading.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -170,7 +171,7 @@ FunctionPass *llvm::createJumpThreadingPass(int Threshold) {
 }

 JumpThreadingPass::JumpThreadingPass(int T) {
-  BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
+  DefaultBBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
 }

 // Update branch probability information according to conditional
@@ -213,11 +214,16 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
   if (!CondBr)
     return;

-  BranchProbability BP;
   uint64_t TrueWeight, FalseWeight;
   if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight))
     return;

+  if (TrueWeight + FalseWeight == 0)
+    // Zero branch_weights do not give a hint for computing branch
+    // probabilities, and would make the computation below divide by zero,
+    // since the denominator is TrueWeight + FalseWeight.
+    return;
+
   // Returns the outgoing edge of the dominating predecessor block
   // that leads to the PhiNode's incoming block:
   auto GetPredOutEdge =
@@ -252,10 +258,11 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
     if (!CI || !CI->getType()->isIntegerTy(1))
       continue;

-    BP = (CI->isOne() ? BranchProbability::getBranchProbability(
-                            TrueWeight, TrueWeight + FalseWeight)
-                      : BranchProbability::getBranchProbability(
-                            FalseWeight, TrueWeight + FalseWeight));
+    BranchProbability BP =
+        (CI->isOne() ? BranchProbability::getBranchProbability(
+                           TrueWeight, TrueWeight + FalseWeight)
+                     : BranchProbability::getBranchProbability(
+                           FalseWeight, TrueWeight + FalseWeight));

     auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB);
     if (!PredOutEdge.first)
@@ -298,8 +305,6 @@ bool JumpThreading::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
   auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
-  // Get DT analysis before LVI. When LVI is initialized it conditionally adds
-  // DT if it's available.
   auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
@@ -316,7 +321,7 @@ bool JumpThreading::runOnFunction(Function &F) {
                               std::move(BFI), std::move(BPI));
   if (PrintLVIAfterJumpThreading) {
     dbgs() << "LVI for function '" << F.getName() << "':\n";
-    LVI->printLVI(F, *DT, dbgs());
+    LVI->printLVI(F, DTU.getDomTree(), dbgs());
   }
   return Changed;
 }

@@ -324,8 +329,6 @@ PreservedAnalyses JumpThreadingPass::run(Function &F,
                                          FunctionAnalysisManager &AM) {
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
-  // Get DT analysis before LVI. When LVI is initialized it conditionally adds
-  // DT if it's available.
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &LVI = AM.getResult<LazyValueAnalysis>(F);
   auto &AA = AM.getResult<AAManager>(F);
@@ -374,6 +377,15 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
     BFI = std::move(BFI_);
   }

+  // Reduce the number of instructions duplicated when optimizing strictly for
+  // size.
+  if (BBDuplicateThreshold.getNumOccurrences())
+    BBDupThreshold = BBDuplicateThreshold;
+  else if (F.hasFnAttribute(Attribute::MinSize))
+    BBDupThreshold = 3;
+  else
+    BBDupThreshold = DefaultBBDupThreshold;
+
   // JumpThreading must not process blocks unreachable from entry. It's a
   // waste of compute time and can potentially lead to hangs.
   SmallPtrSet<BasicBlock *, 16> Unreachable;
@@ -396,6 +408,12 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
         continue;
       while (ProcessBlock(&BB)) // Thread all of the branches we can over BB.
         Changed = true;
+
+      // Jump threading may have introduced redundant debug values into BB
+      // which should be removed.
+      if (Changed)
+        RemoveRedundantDbgInstrs(&BB);
+
       // Stop processing BB if it's the entry or is now deleted. The following
       // routines attempt to eliminate BB, and locating a suitable replacement
       // for the entry is non-trivial.
@@ -418,26 +436,27 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
       // ProcessBlock doesn't thread BBs with unconditional TIs. However, if BB
       // is "almost empty", we attempt to merge BB with its sole successor.
       auto *BI = dyn_cast<BranchInst>(BB.getTerminator());
-      if (BI && BI->isUnconditional() &&
-          // The terminator must be the only non-phi instruction in BB.
-          BB.getFirstNonPHIOrDbg()->isTerminator() &&
-          // Don't alter Loop headers and latches to ensure another pass can
-          // detect and transform nested loops later.
-          !LoopHeaders.count(&BB) && !LoopHeaders.count(BI->getSuccessor(0)) &&
-          TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) {
-        // BB is valid for cleanup here because we passed in DTU. F remains
-        // BB's parent until a DTU->getDomTree() event.
-        LVI->eraseBlock(&BB);
-        Changed = true;
+      if (BI && BI->isUnconditional()) {
+        BasicBlock *Succ = BI->getSuccessor(0);
+        if (
+            // The terminator must be the only non-phi instruction in BB.
+            BB.getFirstNonPHIOrDbg()->isTerminator() &&
+            // Don't alter Loop headers and latches to ensure another pass can
+            // detect and transform nested loops later.
+            !LoopHeaders.count(&BB) && !LoopHeaders.count(Succ) &&
+            TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) {
+          RemoveRedundantDbgInstrs(Succ);
+          // BB is valid for cleanup here because we passed in DTU. F remains
+          // BB's parent until a DTU->getDomTree() event.
+          LVI->eraseBlock(&BB);
+          Changed = true;
+        }
       }
     }
     EverChanged |= Changed;
   } while (Changed);

   LoopHeaders.clear();
-  // Flush only the Dominator Tree.
-  DTU->getDomTree();
-  LVI->enableDT();
   return EverChanged;
 }

@@ -592,20 +611,19 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) {
 /// This returns true if there were any known values.
bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
     Value *V, BasicBlock *BB, PredValueInfo &Result,
-    ConstantPreference Preference,
-    DenseSet<std::pair<Value *, BasicBlock *>> &RecursionSet,
+    ConstantPreference Preference, DenseSet<Value *> &RecursionSet,
     Instruction *CxtI) {
   // This method walks up use-def chains recursively.  Because of this, we could
   // get into an infinite loop going around loops in the use-def chain.  To
-  // prevent this, keep track of what (value, block) pairs we've already visited
-  // and terminate the search if we loop back to them
-  if (!RecursionSet.insert(std::make_pair(V, BB)).second)
+  // prevent this, keep track of what values we've already visited and
+  // terminate the search if we loop back to one of them.
+  if (!RecursionSet.insert(V).second)
     return false;

   // If V is a constant, then it is known in all predecessors.
   if (Constant *KC = getKnownConstant(V, Preference)) {
     for (BasicBlock *Pred : predecessors(BB))
-      Result.push_back(std::make_pair(KC, Pred));
+      Result.emplace_back(KC, Pred);

     return !Result.empty();
   }

@@ -627,17 +645,12 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
     // able to handle value inequalities better, for example if the compare is
     // "X < 4" and "X < 3" is known true but "X < 4" itself is not available.
     // Perhaps getConstantOnEdge should be smart enough to do this?
-
-    if (DTU->hasPendingDomTreeUpdates())
-      LVI->disableDT();
-    else
-      LVI->enableDT();
     for (BasicBlock *P : predecessors(BB)) {
       // If the value is known by LazyValueInfo to be a constant in a
       // predecessor, use that information to try to thread this block.
       Constant *PredCst = LVI->getConstantOnEdge(V, P, BB, CxtI);
       if (Constant *KC = getKnownConstant(PredCst, Preference))
-        Result.push_back(std::make_pair(KC, P));
+        Result.emplace_back(KC, P);
     }

     return !Result.empty();
@@ -645,20 +658,16 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(

   /// If I is a PHI node, then we know the incoming values for any constants.
   if (PHINode *PN = dyn_cast<PHINode>(I)) {
-    if (DTU->hasPendingDomTreeUpdates())
-      LVI->disableDT();
-    else
-      LVI->enableDT();
     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
       Value *InVal = PN->getIncomingValue(i);
       if (Constant *KC = getKnownConstant(InVal, Preference)) {
-        Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i)));
+        Result.emplace_back(KC, PN->getIncomingBlock(i));
       } else {
         Constant *CI = LVI->getConstantOnEdge(InVal,
                                               PN->getIncomingBlock(i),
                                               BB, CxtI);
         if (Constant *KC = getKnownConstant(CI, Preference))
-          Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i)));
+          Result.emplace_back(KC, PN->getIncomingBlock(i));
       }
     }

@@ -757,7 +766,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(

         Constant *Folded = ConstantExpr::get(BO->getOpcode(), V, CI);
         if (Constant *KC = getKnownConstant(Folded, WantInteger))
-          Result.push_back(std::make_pair(KC, LHSVal.second));
+          Result.emplace_back(KC, LHSVal.second);
       }
     }

@@ -779,10 +788,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
       const DataLayout &DL = PN->getModule()->getDataLayout();
       // We can do this simplification if any comparisons fold to true or false.
       // See if any do.
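The compare-over-PHI case that follows folds the comparison once per incoming edge. A hedged sketch of the core step, with an illustrative helper name:

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Substitute the constant the PHI takes on incoming edge i into the
// compare and constant-fold it; nullptr means that edge is not constant.
static Constant *foldCmpOnEdge(CmpInst *Cmp, PHINode *PN, unsigned i,
                               Constant *CmpConst) {
  auto *InC = dyn_cast<Constant>(PN->getIncomingValue(i));
  if (!InC)
    return nullptr;
  return ConstantExpr::getCompare(Cmp->getPredicate(), InC, CmpConst);
}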
-      if (DTU->hasPendingDomTreeUpdates()) -        LVI->disableDT(); -      else -        LVI->enableDT();        for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {          BasicBlock *PredBB = PN->getIncomingBlock(i);          Value *LHS, *RHS; @@ -813,7 +818,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(          }          if (Constant *KC = getKnownConstant(Res, WantInteger)) -          Result.push_back(std::make_pair(KC, PredBB)); +          Result.emplace_back(KC, PredBB);        }        return !Result.empty(); @@ -826,10 +831,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(        if (!isa<Instruction>(CmpLHS) ||            cast<Instruction>(CmpLHS)->getParent() != BB) { -        if (DTU->hasPendingDomTreeUpdates()) -          LVI->disableDT(); -        else -          LVI->enableDT();          for (BasicBlock *P : predecessors(BB)) {            // If the value is known by LazyValueInfo to be a constant in a            // predecessor, use that information to try to thread this block. @@ -840,7 +841,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(              continue;            Constant *ResC = ConstantInt::get(CmpType, Res); -          Result.push_back(std::make_pair(ResC, P)); +          Result.emplace_back(ResC, P);          }          return !Result.empty(); @@ -858,10 +859,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(              match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) {            if (!isa<Instruction>(AddLHS) ||                cast<Instruction>(AddLHS)->getParent() != BB) { -            if (DTU->hasPendingDomTreeUpdates()) -              LVI->disableDT(); -            else -              LVI->enableDT();              for (BasicBlock *P : predecessors(BB)) {                // If the value is known by LazyValueInfo to be a ConstantRange in                // a predecessor, use that information to try to thread this @@ -883,7 +880,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(                else                  continue; -              Result.push_back(std::make_pair(ResC, P)); +              Result.emplace_back(ResC, P);              }              return !Result.empty(); @@ -901,7 +898,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(          Constant *V = LHSVal.first;          Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);          if (Constant *KC = getKnownConstant(Folded, WantInteger)) -          Result.push_back(std::make_pair(KC, LHSVal.second)); +          Result.emplace_back(KC, LHSVal.second);        }        return !Result.empty(); @@ -935,7 +932,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(          // See if the select has a known constant value for this predecessor.          if (Constant *Val = KnownCond ? TrueVal : FalseVal) -          Result.push_back(std::make_pair(Val, C.second)); +          Result.emplace_back(Val, C.second);        }        return !Result.empty(); @@ -943,14 +940,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(    }    // If all else fails, see if LVI can figure out a constant value for us. 
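The fallback below uses LVI's block-level query, whereas the earlier branches asked per edge; an edge query can be stronger because it also sees the predecessor's branch condition. A hedged sketch contrasting the two (helper name illustrative):

#include "llvm/Analysis/LazyValueInfo.h"
using namespace llvm;

static Constant *bestKnownConstant(LazyValueInfo &LVI, Value *V,
                                   BasicBlock *Pred, BasicBlock *BB,
                                   Instruction *CxtI) {
  // Prefer the edge-sensitive answer; fall back to the block-level one,
  // which must hold on every path into BB.
  if (Constant *C = LVI.getConstantOnEdge(V, Pred, BB, CxtI))
    return C;
  return LVI.getConstant(V, BB, CxtI);
}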
-  if (DTU->hasPendingDomTreeUpdates()) -    LVI->disableDT(); -  else -    LVI->enableDT();    Constant *CI = LVI->getConstant(V, BB, CxtI);    if (Constant *KC = getKnownConstant(CI, Preference)) {      for (BasicBlock *Pred : predecessors(BB)) -      Result.push_back(std::make_pair(KC, Pred)); +      Result.emplace_back(KC, Pred);    }    return !Result.empty(); @@ -1106,10 +1099,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {        // threading is concerned.        assert(CondBr->isConditional() && "Threading on unconditional terminator"); -      if (DTU->hasPendingDomTreeUpdates()) -        LVI->disableDT(); -      else -        LVI->enableDT();        LazyValueInfo::Tristate Ret =          LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0),                              CondConst, CondBr); @@ -1363,7 +1352,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) {      // If so, this load is partially redundant.  Remember this info so that we      // can create a PHI node. -    AvailablePreds.push_back(std::make_pair(PredBB, PredAvailable)); +    AvailablePreds.emplace_back(PredBB, PredAvailable);    }    // If the loaded value isn't available in any predecessor, it isn't partially @@ -1430,14 +1419,14 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) {             "Can't handle critical edge here!");      LoadInst *NewVal = new LoadInst(          LoadI->getType(), LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred), -        LoadI->getName() + ".pr", false, MaybeAlign(LoadI->getAlignment()), +        LoadI->getName() + ".pr", false, LoadI->getAlign(),          LoadI->getOrdering(), LoadI->getSyncScopeID(),          UnavailablePred->getTerminator());      NewVal->setDebugLoc(LoadI->getDebugLoc());      if (AATags)        NewVal->setAAMetadata(AATags); -    AvailablePreds.push_back(std::make_pair(UnavailablePred, NewVal)); +    AvailablePreds.emplace_back(UnavailablePred, NewVal);    }    // Now we know that each predecessor of this block has a value in @@ -1496,56 +1485,70 @@ FindMostPopularDest(BasicBlock *BB,    // explicitly choose to ignore 'undef' destinations.  We prefer to thread    // blocks with known and real destinations to threading undef.  We'll handle    // them later if interesting. -  DenseMap<BasicBlock*, unsigned> DestPopularity; +  MapVector<BasicBlock *, unsigned> DestPopularity; + +  // Populate DestPopularity with the successors in the order they appear in the +  // successor list.  This way, we ensure determinism by iterating it in the +  // same order in std::max_element below.  We map nullptr to 0 so that we can +  // return nullptr when PredToDestList contains nullptr only. +  DestPopularity[nullptr] = 0; +  for (auto *SuccBB : successors(BB)) +    DestPopularity[SuccBB] = 0; +    for (const auto &PredToDest : PredToDestList)      if (PredToDest.second)        DestPopularity[PredToDest.second]++; -  if (DestPopularity.empty()) -    return nullptr; -    // Find the most popular dest. -  DenseMap<BasicBlock*, unsigned>::iterator DPI = DestPopularity.begin(); -  BasicBlock *MostPopularDest = DPI->first; -  unsigned Popularity = DPI->second; -  SmallVector<BasicBlock*, 4> SamePopularity; - -  for (++DPI; DPI != DestPopularity.end(); ++DPI) { -    // If the popularity of this entry isn't higher than the popularity we've -    // seen so far, ignore it. -    if (DPI->second < Popularity) -      ; // ignore. 
-    else if (DPI->second == Popularity) { -      // If it is the same as what we've seen so far, keep track of it. -      SamePopularity.push_back(DPI->first); -    } else { -      // If it is more popular, remember it. -      SamePopularity.clear(); -      MostPopularDest = DPI->first; -      Popularity = DPI->second; -    } +  using VT = decltype(DestPopularity)::value_type; +  auto MostPopular = std::max_element( +      DestPopularity.begin(), DestPopularity.end(), +      [](const VT &L, const VT &R) { return L.second < R.second; }); + +  // Okay, we have finally picked the most popular destination. +  return MostPopular->first; +} + +// Try to evaluate the value of V when the control flows from PredPredBB to +// BB->getSinglePredecessor() and then on to BB. +Constant *JumpThreadingPass::EvaluateOnPredecessorEdge(BasicBlock *BB, +                                                       BasicBlock *PredPredBB, +                                                       Value *V) { +  BasicBlock *PredBB = BB->getSinglePredecessor(); +  assert(PredBB && "Expected a single predecessor"); + +  if (Constant *Cst = dyn_cast<Constant>(V)) { +    return Cst;    } -  // Okay, now we know the most popular destination.  If there is more than one -  // destination, we need to determine one.  This is arbitrary, but we need -  // to make a deterministic decision.  Pick the first one that appears in the -  // successor list. -  if (!SamePopularity.empty()) { -    SamePopularity.push_back(MostPopularDest); -    Instruction *TI = BB->getTerminator(); -    for (unsigned i = 0; ; ++i) { -      assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); +  // Consult LVI if V is not an instruction in BB or PredBB. +  Instruction *I = dyn_cast<Instruction>(V); +  if (!I || (I->getParent() != BB && I->getParent() != PredBB)) { +    return LVI->getConstantOnEdge(V, PredPredBB, PredBB, nullptr); +  } -      if (!is_contained(SamePopularity, TI->getSuccessor(i))) -        continue; +  // Look into a PHI argument. +  if (PHINode *PHI = dyn_cast<PHINode>(V)) { +    if (PHI->getParent() == PredBB) +      return dyn_cast<Constant>(PHI->getIncomingValueForBlock(PredPredBB)); +    return nullptr; +  } -      MostPopularDest = TI->getSuccessor(i); -      break; +  // If we have a CmpInst, try to fold it for each incoming edge into PredBB. +  if (CmpInst *CondCmp = dyn_cast<CmpInst>(V)) { +    if (CondCmp->getParent() == BB) { +      Constant *Op0 = +          EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(0)); +      Constant *Op1 = +          EvaluateOnPredecessorEdge(BB, PredPredBB, CondCmp->getOperand(1)); +      if (Op0 && Op1) { +        return ConstantExpr::getCompare(CondCmp->getPredicate(), Op0, Op1); +      }      } +    return nullptr;    } -  // Okay, we have finally picked the most popular destination. -  return MostPopularDest; +  return nullptr;  }  bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, @@ -1557,8 +1560,12 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,      return false;    PredValueInfoTy PredValues; -  if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, CxtI)) -    return false; +  if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, +                                       CxtI)) { +    // We don't have known values in predecessors.  See if we can thread through +    // BB and its sole predecessor. 
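The FindMostPopularDest rewrite earlier in this hunk leans on two properties: MapVector iterates in insertion order, and std::max_element returns the first of several equal maxima, so ties resolve to the earliest-seeded successor instead of DenseMap's unspecified bucket order. A self-contained illustration:

#include "llvm/ADT/MapVector.h"
#include <algorithm>
#include <cassert>

int main() {
  llvm::MapVector<const char *, unsigned> Pop;
  Pop["succ0"] = 2;
  Pop["succ1"] = 2; // tied with succ0
  auto It = std::max_element(Pop.begin(), Pop.end(),
                             [](const auto &L, const auto &R) {
                               return L.second < R.second;
                             });
  assert(It == Pop.begin()); // the first-inserted entry wins the tie
}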
+    return MaybeThreadThroughTwoBasicBlocks(BB, Cond); +  }    assert(!PredValues.empty() &&           "ComputeValueKnownInPredecessors returned true with no values"); @@ -1624,7 +1631,7 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB,          isa<CallBrInst>(Pred->getTerminator()))        continue; -    PredToDestList.push_back(std::make_pair(Pred, DestBB)); +    PredToDestList.emplace_back(Pred, DestBB);    }    // If all edges were unthreadable, we fail. @@ -2015,6 +2022,205 @@ JumpThreadingPass::CloneInstructions(BasicBlock::iterator BI,    return ValueMapping;  } +/// Attempt to thread through two successive basic blocks. +bool JumpThreadingPass::MaybeThreadThroughTwoBasicBlocks(BasicBlock *BB, +                                                         Value *Cond) { +  // Consider: +  // +  // PredBB: +  //   %var = phi i32* [ null, %bb1 ], [ @a, %bb2 ] +  //   %tobool = icmp eq i32 %cond, 0 +  //   br i1 %tobool, label %BB, label ... +  // +  // BB: +  //   %cmp = icmp eq i32* %var, null +  //   br i1 %cmp, label ..., label ... +  // +  // We don't know the value of %var at BB even if we know which incoming edge +  // we take to BB.  However, once we duplicate PredBB for each of its incoming +  // edges (say, PredBB1 and PredBB2), we know the value of %var in each copy of +  // PredBB.  Then we can thread edges PredBB1->BB and PredBB2->BB through BB. + +  // Require that BB end with a Branch for simplicity. +  BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); +  if (!CondBr) +    return false; + +  // BB must have exactly one predecessor. +  BasicBlock *PredBB = BB->getSinglePredecessor(); +  if (!PredBB) +    return false; + +  // Require that PredBB end with a conditional Branch. If PredBB ends with an +  // unconditional branch, we should be merging PredBB and BB instead. For +  // simplicity, we don't deal with a switch. +  BranchInst *PredBBBranch = dyn_cast<BranchInst>(PredBB->getTerminator()); +  if (!PredBBBranch || PredBBBranch->isUnconditional()) +    return false; + +  // If PredBB has exactly one incoming edge, we don't gain anything by copying +  // PredBB. +  if (PredBB->getSinglePredecessor()) +    return false; + +  // Don't thread through PredBB if it contains a successor edge to itself, in +  // which case we would infinite loop.  Suppose we are threading an edge from +  // PredPredBB through PredBB and BB to SuccBB with PredBB containing a +  // successor edge to itself.  If we allowed jump threading in this case, we +  // could duplicate PredBB and BB as, say, PredBB.thread and BB.thread.  Since +  // PredBB.thread has a successor edge to PredBB, we would immediately come up +  // with another jump threading opportunity from PredBB.thread through PredBB +  // and BB to SuccBB.  This jump threading would repeatedly occur.  That is, we +  // would keep peeling one iteration from PredBB. +  if (llvm::is_contained(successors(PredBB), PredBB)) +    return false; + +  // Don't thread across a loop header. +  if (LoopHeaders.count(PredBB)) +    return false; + +  // Avoid complication with duplicating EH pads. +  if (PredBB->isEHPad()) +    return false; + +  // Find a predecessor that we can thread.  For simplicity, we only consider a +  // successor edge out of BB to which we thread exactly one incoming edge into +  // PredBB. 
+  unsigned ZeroCount = 0;
+  unsigned OneCount = 0;
+  BasicBlock *ZeroPred = nullptr;
+  BasicBlock *OnePred = nullptr;
+  for (BasicBlock *P : predecessors(PredBB)) {
+    if (ConstantInt *CI = dyn_cast_or_null<ConstantInt>(
+            EvaluateOnPredecessorEdge(BB, P, Cond))) {
+      if (CI->isZero()) {
+        ZeroCount++;
+        ZeroPred = P;
+      } else if (CI->isOne()) {
+        OneCount++;
+        OnePred = P;
+      }
+    }
+  }
+
+  // Disregard complicated cases where we have to thread multiple edges.
+  BasicBlock *PredPredBB;
+  if (ZeroCount == 1) {
+    PredPredBB = ZeroPred;
+  } else if (OneCount == 1) {
+    PredPredBB = OnePred;
+  } else {
+    return false;
+  }
+
+  BasicBlock *SuccBB = CondBr->getSuccessor(PredPredBB == ZeroPred);
+
+  // If threading to the same block as we come from, we would infinite loop.
+  if (SuccBB == BB) {
+    LLVM_DEBUG(dbgs() << "  Not threading across BB '" << BB->getName()
+                      << "' - would thread to self!\n");
+    return false;
+  }
+
+  // If threading this would thread across a loop header, don't thread the edge.
+  // See the comments above FindLoopHeaders for justifications and caveats.
+  if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) {
+    LLVM_DEBUG({
+      bool BBIsHeader = LoopHeaders.count(BB);
+      bool SuccIsHeader = LoopHeaders.count(SuccBB);
+      dbgs() << "  Not threading across "
+             << (BBIsHeader ? "loop header BB '" : "block BB '")
+             << BB->getName() << "' to dest "
+             << (SuccIsHeader ? "loop header BB '" : "block BB '")
+             << SuccBB->getName()
+             << "' - it might create an irreducible loop!\n";
+    });
+    return false;
+  }
+
+  // Compute the cost of duplicating BB and PredBB.
+  unsigned BBCost =
+      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
+  unsigned PredBBCost = getJumpThreadDuplicationCost(
+      PredBB, PredBB->getTerminator(), BBDupThreshold);
+
+  // Give up if costs are too high.  We need to check BBCost and PredBBCost
+  // individually before checking their sum because getJumpThreadDuplicationCost
+  // returns (unsigned)~0 for those basic blocks that cannot be duplicated.
+  if (BBCost > BBDupThreshold || PredBBCost > BBDupThreshold ||
+      BBCost + PredBBCost > BBDupThreshold) {
+    LLVM_DEBUG(dbgs() << "  Not threading BB '" << BB->getName()
+                      << "' - Cost is too high: " << PredBBCost
+                      << " for PredBB, " << BBCost << " for BB\n");
+    return false;
+  }
+
+  // Now we are ready to duplicate PredBB.
+  ThreadThroughTwoBasicBlocks(PredPredBB, PredBB, BB, SuccBB);
+  return true;
+}
+
+void JumpThreadingPass::ThreadThroughTwoBasicBlocks(BasicBlock *PredPredBB,
+                                                    BasicBlock *PredBB,
+                                                    BasicBlock *BB,
+                                                    BasicBlock *SuccBB) {
+  LLVM_DEBUG(dbgs() << "  Threading through '" << PredBB->getName() << "' and '"
+                    << BB->getName() << "'\n");
+
+  BranchInst *CondBr = cast<BranchInst>(BB->getTerminator());
+  BranchInst *PredBBBranch = cast<BranchInst>(PredBB->getTerminator());
+
+  BasicBlock *NewBB =
+      BasicBlock::Create(PredBB->getContext(), PredBB->getName() + ".thread",
+                         PredBB->getParent(), PredBB);
+  NewBB->moveAfter(PredBB);
+
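Why the cost check above tests BBCost and PredBBCost individually before testing their sum: a block that cannot be duplicated is reported as (unsigned)~0, and adding two unsigned values that large wraps around. A self-contained illustration:

#include <cassert>

int main() {
  unsigned Threshold = 6;
  unsigned BBCost = ~0u; // the "cannot be duplicated" sentinel
  unsigned PredBBCost = 2;
  // The sum alone wraps to 1 and would sneak under the threshold...
  assert(BBCost + PredBBCost < Threshold);
  // ...so the individual comparison is what actually rejects the block.
  assert(BBCost > Threshold);
}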
+  // Set the block frequency of NewBB.
+  if (HasProfileData) {
+    auto NewBBFreq = BFI->getBlockFreq(PredPredBB) *
+                     BPI->getEdgeProbability(PredPredBB, PredBB);
+    BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency());
+  }
+
+  // We are going to have to map operands from the original PredBB block to
+  // the new copy of it, 'NewBB'.  If there are PHI nodes in PredBB, evaluate
+  // them to account for entry from PredPredBB.
+  DenseMap<Instruction *, Value *> ValueMapping =
+      CloneInstructions(PredBB->begin(), PredBB->end(), NewBB, PredPredBB);
+
+  // Update the terminator of PredPredBB to jump to NewBB instead of PredBB.
+  // This eliminates PredPredBB as a predecessor of PredBB, which requires us
+  // to simplify any PHI nodes in PredBB.
+  Instruction *PredPredTerm = PredPredBB->getTerminator();
+  for (unsigned i = 0, e = PredPredTerm->getNumSuccessors(); i != e; ++i)
+    if (PredPredTerm->getSuccessor(i) == PredBB) {
+      PredBB->removePredecessor(PredPredBB, true);
+      PredPredTerm->setSuccessor(i, NewBB);
+    }
+
+  AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(0), PredBB, NewBB,
+                                  ValueMapping);
+  AddPHINodeEntriesForMappedBlock(PredBBBranch->getSuccessor(1), PredBB, NewBB,
+                                  ValueMapping);
+
+  DTU->applyUpdatesPermissive(
+      {{DominatorTree::Insert, NewBB, CondBr->getSuccessor(0)},
+       {DominatorTree::Insert, NewBB, CondBr->getSuccessor(1)},
+       {DominatorTree::Insert, PredPredBB, NewBB},
+       {DominatorTree::Delete, PredPredBB, PredBB}});
+
+  UpdateSSA(PredBB, NewBB, ValueMapping);
+
+  // Clean up things like PHI nodes with single operands, dead instructions,
+  // etc.
+  SimplifyInstructionsInBlock(NewBB, TLI);
+  SimplifyInstructionsInBlock(PredBB, TLI);
+
+  SmallVector<BasicBlock *, 1> PredsToFactor;
+  PredsToFactor.push_back(NewBB);
+  ThreadEdge(BB, PredsToFactor, SuccBB);
+}
+
 /// TryThreadEdge - Thread an edge if it's safe and profitable to do so.
 bool JumpThreadingPass::TryThreadEdge(
     BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &PredBBs,
@@ -2078,10 +2284,6 @@ void JumpThreadingPass::ThreadEdge(BasicBlock *BB,
                     << "' to '" << SuccBB->getName()
                     << "', across block:\n    " << *BB << "\n");

-  if (DTU->hasPendingDomTreeUpdates())
-    LVI->disableDT();
-  else
-    LVI->enableDT();
   LVI->threadEdge(PredBB, BB, SuccBB);

   BasicBlock *NewBB = BasicBlock::Create(BB->getContext(),
@@ -2246,8 +2448,7 @@ void JumpThreadingPass::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
   }

   // Update edge probabilities in BPI.
-  for (int I = 0, E = BBSuccProbs.size(); I < E; I++)
-    BPI->setEdgeProbability(BB, I, BBSuccProbs[I]);
+  BPI->setEdgeProbability(BB, BBSuccProbs);

   // Update the profile metadata as well.
   //
@@ -2524,10 +2725,6 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) {
     // Now check if one of the select values would allow us to constant fold the
     // terminator in BB. We don't do the transform if both sides fold, those
     // cases will be threaded in any case.
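ThreadThroughTwoBasicBlocks above hands all four edge changes to the DomTreeUpdater in one applyUpdatesPermissive call, so the dominator tree is refreshed once rather than per edge. A hedged sketch of the same pattern with an illustrative helper:

#include "llvm/Analysis/DomTreeUpdater.h"
using namespace llvm;

// Record an edge From->OldTo being retargeted to From->NewTo as a single
// batch; the permissive form tolerates updates the CFG already reflects.
static void recordRetarget(DomTreeUpdater &DTU, BasicBlock *From,
                           BasicBlock *OldTo, BasicBlock *NewTo) {
  DTU.applyUpdatesPermissive({{DominatorTree::Insert, From, NewTo},
                              {DominatorTree::Delete, From, OldTo}});
}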
-    if (DTU->hasPendingDomTreeUpdates())
-      LVI->disableDT();
-    else
-      LVI->enableDT();
     LazyValueInfo::Tristate LHSFolds =
         LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1),
                                 CondRHS, Pred, BB, CondCmp);
@@ -2565,6 +2762,16 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) {
 /// select is not jump-threaded, it will be folded again in the later
 /// optimizations.
 bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {
+  // This transform can introduce UB (a conditional branch that depends on a
+  // poison value) that was not present in the original program. See
+  // @TryToUnfoldSelectInCurrBB test in test/Transforms/JumpThreading/select.ll.
+  // Disable this transform under MemorySanitizer.
+  // FIXME: either delete it or replace with a valid transform. This issue is
+  // not limited to MemorySanitizer (but has only been observed as an MSan false
+  // positive in practice so far).
+  if (BB->getParent()->hasFnAttribute(Attribute::SanitizeMemory))
+    return false;
+
   // If threading this would thread across a loop header, don't thread the edge.
   // See the comments above FindLoopHeaders for justifications and caveats.
   if (LoopHeaders.count(BB))
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 8c33045c2380..1a22edaf8726 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -46,6 +46,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/MustExecute.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
@@ -69,6 +70,7 @@
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
@@ -151,11 +153,11 @@ static bool isSafeToExecuteUnconditionally(Instruction &Inst,
                                            const Instruction *CtxI = nullptr);
 static bool pointerInvalidatedByLoop(MemoryLocation MemLoc,
                                      AliasSetTracker *CurAST, Loop *CurLoop,
-                                     AliasAnalysis *AA);
+                                     AAResults *AA);
 static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU,
                                              Loop *CurLoop,
                                              SinkAndHoistLICMFlags &Flags);
-static Instruction *CloneInstructionInExitBlock(
+static Instruction *cloneInstructionInExitBlock(
     Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI,
     const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU);

@@ -168,27 +170,24 @@ static void moveInstructionBefore(Instruction &I, Instruction &Dest,

 namespace {
 struct LoopInvariantCodeMotion {
-  using ASTrackerMapTy = DenseMap<Loop *, std::unique_ptr<AliasSetTracker>>;
-  bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT,
+  bool runOnLoop(Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT,
                  TargetLibraryInfo *TLI, TargetTransformInfo *TTI,
                  ScalarEvolution *SE, MemorySSA
*MSSA, -                 OptimizationRemarkEmitter *ORE, bool DeleteAST); +                 OptimizationRemarkEmitter *ORE); -  ASTrackerMapTy &getLoopToAliasSetMap() { return LoopToAliasSetMap; }    LoopInvariantCodeMotion(unsigned LicmMssaOptCap,                            unsigned LicmMssaNoAccForPromotionCap)        : LicmMssaOptCap(LicmMssaOptCap),          LicmMssaNoAccForPromotionCap(LicmMssaNoAccForPromotionCap) {}  private: -  ASTrackerMapTy LoopToAliasSetMap;    unsigned LicmMssaOptCap;    unsigned LicmMssaNoAccForPromotionCap;    std::unique_ptr<AliasSetTracker> -  collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AliasAnalysis *AA); +  collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AAResults *AA);    std::unique_ptr<AliasSetTracker> -  collectAliasInfoForLoopWithMSSA(Loop *L, AliasAnalysis *AA, +  collectAliasInfoForLoopWithMSSA(Loop *L, AAResults *AA,                                    MemorySSAUpdater *MSSAU);  }; @@ -202,13 +201,8 @@ struct LegacyLICMPass : public LoopPass {    }    bool runOnLoop(Loop *L, LPPassManager &LPM) override { -    if (skipLoop(L)) { -      // If we have run LICM on a previous loop but now we are skipping -      // (because we've hit the opt-bisect limit), we need to clear the -      // loop alias information. -      LICM.getLoopToAliasSetMap().clear(); +    if (skipLoop(L))        return false; -    }      auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();      MemorySSA *MSSA = EnableMSSALoopDependency @@ -226,7 +220,7 @@ struct LegacyLICMPass : public LoopPass {                                *L->getHeader()->getParent()),                            &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(                                *L->getHeader()->getParent()), -                          SE ? &SE->getSE() : nullptr, MSSA, &ORE, false); +                          SE ? &SE->getSE() : nullptr, MSSA, &ORE);    }    /// This transformation requires natural loop information & requires that @@ -244,53 +238,21 @@ struct LegacyLICMPass : public LoopPass {      getLoopAnalysisUsage(AU);    } -  using llvm::Pass::doFinalization; - -  bool doFinalization() override { -    auto &AliasSetMap = LICM.getLoopToAliasSetMap(); -    // All loops in the AliasSetMap should be cleaned up already. The only case -    // where we fail to do so is if an outer loop gets deleted before LICM -    // visits it. -    assert(all_of(AliasSetMap, -                  [](LoopInvariantCodeMotion::ASTrackerMapTy::value_type &KV) { -                    return !KV.first->getParentLoop(); -                  }) && -           "Didn't free loop alias sets"); -    AliasSetMap.clear(); -    return false; -  } -  private:    LoopInvariantCodeMotion LICM; - -  /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info. -  void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, -                               Loop *L) override; - -  /// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias -  /// set. -  void deleteAnalysisValue(Value *V, Loop *L) override; - -  /// Simple Analysis hook. Delete loop L from alias set map. 
-  void deleteAnalysisLoop(Loop *L) override;  };  } // namespace  PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,                                  LoopStandardAnalysisResults &AR, LPMUpdater &) { -  const auto &FAM = -      AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); -  Function *F = L.getHeader()->getParent(); - -  auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); -  // FIXME: This should probably be optional rather than required. -  if (!ORE) -    report_fatal_error("LICM: OptimizationRemarkEmitterAnalysis not " -                       "cached at a higher level"); +  // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis +  // pass.  Function analyses need to be preserved across loop transformations +  // but ORE cannot be preserved (see comment before the pass definition). +  OptimizationRemarkEmitter ORE(L.getHeader()->getParent());    LoopInvariantCodeMotion LICM(LicmMssaOptCap, LicmMssaNoAccForPromotionCap);    if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.TTI, &AR.SE, -                      AR.MSSA, ORE, true)) +                      AR.MSSA, &ORE))      return PreservedAnalyses::all();    auto PA = getLoopPassPreservedAnalyses(); @@ -322,13 +284,10 @@ Pass *llvm::createLICMPass(unsigned LicmMssaOptCap,  /// Hoist expressions out of the specified loop. Note, alias info for inner  /// loop is not preserved so it is not a good idea to run LICM multiple  /// times on one loop. -/// We should delete AST for inner loops in the new pass manager to avoid -/// memory leak. -///  bool LoopInvariantCodeMotion::runOnLoop( -    Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, +    Loop *L, AAResults *AA, LoopInfo *LI, DominatorTree *DT,      TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE, -    MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST) { +    MemorySSA *MSSA, OptimizationRemarkEmitter *ORE) {    bool Changed = false;    assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); @@ -372,7 +331,7 @@ bool LoopInvariantCodeMotion::runOnLoop(    BasicBlock *Preheader = L->getLoopPreheader();    // Compute loop safety information. -  ICFLoopSafetyInfo SafetyInfo(DT); +  ICFLoopSafetyInfo SafetyInfo;    SafetyInfo.computeLoopSafetyInfo(L);    // We want to visit all of the instructions in this loop... that are not parts @@ -476,11 +435,6 @@ bool LoopInvariantCodeMotion::runOnLoop(    assert((!L->getParentLoop() || L->getParentLoop()->isLCSSAForm(*DT)) &&           "Parent loop not left in LCSSA form after LICM!"); -  // If this loop is nested inside of another one, save the alias information -  // for when we process the outer loop. -  if (!MSSAU.get() && CurAST.get() && L->getParentLoop() && !DeleteAST) -    LoopToAliasSetMap[L] = std::move(CurAST); -    if (MSSAU.get() && VerifyMemorySSA)      MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -494,7 +448,7 @@ bool LoopInvariantCodeMotion::runOnLoop(  /// first order w.r.t the DominatorTree.  This allows us to visit uses before  /// definitions, allowing us to sink a loop body in one pass without iteration.  
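The LICMPass::run change above swaps the cached-ORE requirement for a locally constructed emitter, which works because OptimizationRemarkEmitter can be built on demand from a Function. A minimal hedged sketch (the helper and remark names are illustrative):

#include "llvm/Analysis/OptimizationRemarkEmitter.h"
using namespace llvm;

static void emitHoistRemark(Function &F, Instruction &I) {
  OptimizationRemarkEmitter ORE(&F); // built on demand, no analysis cache
  ORE.emit([&]() {
    return OptimizationRemark("licm", "Hoisted", &I) << "hoisting instruction";
  });
}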
/// -bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, +bool llvm::sinkRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,                        DominatorTree *DT, TargetLibraryInfo *TLI,                        TargetTransformInfo *TTI, Loop *CurLoop,                        AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, @@ -529,6 +483,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,        // used in the loop, instead, just delete it.        if (isInstructionTriviallyDead(&I, TLI)) {          LLVM_DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); +        salvageKnowledge(&I);          salvageDebugInfo(I);          ++II;          eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); @@ -542,13 +497,14 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,        // operands of the instruction are loop invariant.        //        bool FreeInLoop = false; -      if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && +      if (!I.mayHaveSideEffects() && +          isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) &&            canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, &Flags, -                             ORE) && -          !I.mayHaveSideEffects()) { +                             ORE)) {          if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE)) {            if (!FreeInLoop) {              ++II; +            salvageDebugInfo(I);              eraseInstruction(I, *SafetyInfo, CurAST, MSSAU);            }            Changed = true; @@ -790,47 +746,12 @@ public:  };  } // namespace - -/// Return true if we know how to rewrite all uses of the given alloca after -/// hoisting it out of the loop.  The main concerns are a) potential captures -/// and b) invariant.start markers which don't capture, but are no longer -/// valid w/o a corresponding invariant.end. -static bool canRewriteUsesOfAlloca(AllocaInst &AI) { -  // TODO: This looks a lot like capture tracking, but we need to remove any -  // invariant starts if we extend the lifetime of the alloca by hoisting it. -  // We should probably refactor capture tracking into a form which allows us -  // to reuse the relevant bits and remove the duplicated logic here. - -  SmallVector<Use *, 16> Worklist; -  for (Use &U : AI.uses()) -    Worklist.push_back(&U); -   -  unsigned NumUsesExplored = 0; -  while (!Worklist.empty()) { -    Use *U = Worklist.pop_back_val(); -    Instruction *I = cast<Instruction>(U->getUser()); -    NumUsesExplored++; -    if (NumUsesExplored > DefaultMaxUsesToExplore) -      return false; -    // Non capturing, terminating uses -    if (isa<LoadInst>(I) || -        (isa<StoreInst>(I) && U->getOperandNo() == 1)) -      continue; -    // Non capturing, non-terminating -    if (!isa<BitCastInst>(I) && !isa<GetElementPtrInst>(I)) -      return false; -    for (Use &U : I->uses()) -      Worklist.push_back(&U); -  } -  return true; -} -  /// Walk the specified region of the CFG (defined by all blocks dominated by  /// the specified block, and that are in the current loop) in depth first  /// order w.r.t the DominatorTree.  This allows us to visit definitions before  /// uses, allowing us to hoist a loop body in one pass without iteration.  
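The sinkRegion hunk above adds salvageKnowledge next to salvageDebugInfo when deleting dead instructions, so facts implied by the deleted instruction survive as assume bundles. A hedged sketch of that deletion discipline:

#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static void deleteDeadInstruction(Instruction &I) {
  salvageKnowledge(&I); // retain implied facts as llvm.assume bundles
  salvageDebugInfo(I);  // re-express debug users in surviving values
  I.eraseFromParent();
}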
 ///
-bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
+bool llvm::hoistRegion(DomTreeNode *N, AAResults *AA, LoopInfo *LI,
                        DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop,
                        AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU,
                        ScalarEvolution *SE, ICFLoopSafetyInfo *SafetyInfo,
@@ -901,9 +822,8 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,

       // Attempt to remove floating point division out of the loop by
       // converting it to a reciprocal multiplication.
-      if (I.getOpcode() == Instruction::FDiv &&
-          CurLoop->isLoopInvariant(I.getOperand(1)) &&
-          I.hasAllowReciprocal()) {
+      if (I.getOpcode() == Instruction::FDiv && I.hasAllowReciprocal() &&
+          CurLoop->isLoopInvariant(I.getOperand(1))) {
         auto Divisor = I.getOperand(1);
         auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0);
         auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor);
@@ -945,16 +865,6 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
         continue;
       }

-      if (isa<AllocaInst>(&I) &&
-          SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop) &&
-          canRewriteUsesOfAlloca(cast<AllocaInst>(I))) {
-        hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo,
-              MSSAU, SE, ORE);
-        HoistedInstructions.push_back(&I);
-        Changed = true;
-        continue;
-      }
-
       if (PHINode *PN = dyn_cast<PHINode>(&I)) {
         if (CFH.canHoistPHI(PN)) {
           // Redirect incoming blocks first to ensure that we create hoisted
@@ -1081,12 +991,12 @@ namespace {
 bool isHoistableAndSinkableInst(Instruction &I) {
   // Only these instructions are hoistable/sinkable.
   return (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<CallInst>(I) ||
-          isa<FenceInst>(I) || isa<CastInst>(I) ||
-          isa<UnaryOperator>(I) || isa<BinaryOperator>(I) ||
-          isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || isa<CmpInst>(I) ||
+          isa<FenceInst>(I) || isa<CastInst>(I) || isa<UnaryOperator>(I) ||
+          isa<BinaryOperator>(I) || isa<SelectInst>(I) ||
+          isa<GetElementPtrInst>(I) || isa<CmpInst>(I) ||
           isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) ||
           isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) ||
-          isa<InsertValueInst>(I));
+          isa<InsertValueInst>(I) || isa<FreezeInst>(I));
 }

 /// Return true if all of the alias sets within this AST are known not to
 /// contain a Mod, or if MSSA knows there are no MemoryDefs in the loop.
@@ -1198,11 +1108,11 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
     FunctionModRefBehavior Behavior = AA->getModRefBehavior(CI);
     if (Behavior == FMRB_DoesNotAccessMemory)
       return true;
-    if (AliasAnalysis::onlyReadsMemory(Behavior)) {
+    if (AAResults::onlyReadsMemory(Behavior)) {
       // A readonly argmemonly function only reads from memory pointed to by
       // its arguments with arbitrary offsets.  If we can prove there are no
       // writes to this memory in the loop, we can hoist or sink.
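The FDiv hunk above only reorders the tests so the cheap flag check runs first; the rewrite stays gated on allow-reciprocal because x * (1.0 / d) may round differently from x / d. A self-contained illustration:

#include <cstdio>

int main() {
  float X = 3.0f, D = 0.7f;
  float Direct = X / D;
  float Recip = X * (1.0f / D); // two roundings instead of one
  // The results are not guaranteed to be bit-identical, which is why the
  // transform requires the allow-reciprocal fast-math flag.
  std::printf("%a vs %a\n", Direct, Recip);
}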
-      if (AliasAnalysis::onlyAccessesArgPointees(Behavior)) { +      if (AAResults::onlyAccessesArgPointees(Behavior)) {          // TODO: expand to writeable arguments          for (Value *Op : CI->arg_operands())            if (Op->getType()->isPointerTy()) { @@ -1351,7 +1261,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,                           const TargetTransformInfo *TTI) {    if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) { -    if (TTI->getUserCost(GEP) != TargetTransformInfo::TCC_Free) +    if (TTI->getUserCost(GEP, TargetTransformInfo::TCK_SizeAndLatency) != +        TargetTransformInfo::TCC_Free)        return false;      // For a GEP, we cannot simply use getUserCost because currently it      // optimistically assume that a GEP will fold into addressing mode @@ -1366,7 +1277,8 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,      }      return true;    } else -    return TTI->getUserCost(&I) == TargetTransformInfo::TCC_Free; +    return TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency) == +           TargetTransformInfo::TCC_Free;  }  /// Return true if the only users of this instruction are outside of @@ -1407,7 +1319,7 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,    return true;  } -static Instruction *CloneInstructionInExitBlock( +static Instruction *cloneInstructionInExitBlock(      Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI,      const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU) {    Instruction *New; @@ -1520,7 +1432,7 @@ static Instruction *sinkThroughTriviallyReplaceablePHI(    if (It != SunkCopies.end())      New = It->second;    else -    New = SunkCopies[ExitBlock] = CloneInstructionInExitBlock( +    New = SunkCopies[ExitBlock] = cloneInstructionInExitBlock(          *I, *ExitBlock, *TPN, LI, SafetyInfo, MSSAU);    return New;  } @@ -1537,7 +1449,8 @@ static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) {      return false;    for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {      BasicBlock *BBPred = *PI; -    if (isa<IndirectBrInst>(BBPred->getTerminator())) +    if (isa<IndirectBrInst>(BBPred->getTerminator()) || +        isa<CallBrInst>(BBPred->getTerminator()))        return false;    }    return true; @@ -1857,7 +1770,7 @@ public:        StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);        if (UnorderedAtomic)          NewSI->setOrdering(AtomicOrdering::Unordered); -      NewSI->setAlignment(MaybeAlign(Alignment)); +      NewSI->setAlignment(Align(Alignment));        NewSI->setDebugLoc(DL);        if (AATags)          NewSI->setAAMetadata(AATags); @@ -1981,7 +1894,7 @@ bool llvm::promoteLoopAccessesToScalars(    // We start with an alignment of one and try to find instructions that allow    // us to prove better alignment. 
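The promotion code below moves from raw unsigned alignments to the Align type, whose invariant (always a power of two, never zero) retires the old "0 means unknown" convention. A minimal sketch:

#include "llvm/Support/Alignment.h"
#include <cassert>

int main() {
  llvm::Align Best;        // default-constructed Align is 1, never 0
  llvm::Align FromLoad(8); // must be a power of two
  if (FromLoad > Best)
    Best = FromLoad;       // raising the known alignment is a comparison
  assert(Best.value() == 8);
}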
-  unsigned Alignment = 1; +  Align Alignment;    // Keep track of which types of access we see    bool SawUnorderedAtomic = false;    bool SawNotAtomic = false; @@ -2029,10 +1942,7 @@ bool llvm::promoteLoopAccessesToScalars(          SawUnorderedAtomic |= Load->isAtomic();          SawNotAtomic |= !Load->isAtomic(); -        unsigned InstAlignment = Load->getAlignment(); -        if (!InstAlignment) -          InstAlignment = -              MDL.getABITypeAlignment(Load->getType()); +        Align InstAlignment = Load->getAlign();          // Note that proving a load safe to speculate requires proving          // sufficient alignment at the target location.  Proving it guaranteed @@ -2060,10 +1970,7 @@ bool llvm::promoteLoopAccessesToScalars(          // already know that promotion is safe, since it may have higher          // alignment than any other guaranteed stores, in which case we can          // raise the alignment on the promoted store. -        unsigned InstAlignment = Store->getAlignment(); -        if (!InstAlignment) -          InstAlignment = -              MDL.getABITypeAlignment(Store->getValueOperand()->getType()); +        Align InstAlignment = Store->getAlign();          if (!DereferenceableInPH || !SafeToInsertStore ||              (InstAlignment > Alignment)) { @@ -2090,8 +1997,7 @@ bool llvm::promoteLoopAccessesToScalars(          if (!DereferenceableInPH) {            DereferenceableInPH = isDereferenceableAndAlignedPointer(                Store->getPointerOperand(), Store->getValueOperand()->getType(), -              MaybeAlign(Store->getAlignment()), MDL, -              Preheader->getTerminator(), DT); +              Store->getAlign(), MDL, Preheader->getTerminator(), DT);          }        } else          return false; // Not a load or store. @@ -2156,18 +2062,19 @@ bool llvm::promoteLoopAccessesToScalars(    });    ++NumPromoted; -  // Grab a debug location for the inserted loads/stores; given that the -  // inserted loads/stores have little relation to the original loads/stores, -  // this code just arbitrarily picks a location from one, since any debug -  // location is better than none. -  DebugLoc DL = LoopUses[0]->getDebugLoc(); +  // Look at all the loop uses, and try to merge their locations. +  std::vector<const DILocation *> LoopUsesLocs; +  for (auto U : LoopUses) +    LoopUsesLocs.push_back(U->getDebugLoc().get()); +  auto DL = DebugLoc(DILocation::getMergedLocations(LoopUsesLocs));    // We use the SSAUpdater interface to insert phi nodes as required.    SmallVector<PHINode *, 16> NewPHIs;    SSAUpdater SSA(&NewPHIs);    LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks,                          InsertPts, MSSAInsertPts, PIC, *CurAST, MSSAU, *LI, DL, -                        Alignment, SawUnorderedAtomic, AATags, *SafetyInfo); +                        Alignment.value(), SawUnorderedAtomic, AATags, +                        *SafetyInfo);    // Set up the preheader to have a definition of the value.  It is the live-out    // value from the preheader that uses in the loop will use. 
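The DebugLoc change above merges the locations of all promoted accesses instead of borrowing one arbitrarily; when the inputs disagree, DILocation::getMergedLocations degrades to a conservative location rather than inventing one. A hedged sketch (helper name illustrative):

#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DebugLoc.h"
using namespace llvm;

static DebugLoc mergeAccessLocations(ArrayRef<const DILocation *> Locs) {
  // Returns a location all inputs agree on, or a degraded merged one.
  return DebugLoc(DILocation::getMergedLocations(Locs));
}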
@@ -2176,8 +2083,8 @@ bool llvm::promoteLoopAccessesToScalars(        SomePtr->getName() + ".promoted", Preheader->getTerminator());    if (SawUnorderedAtomic)      PreheaderLoad->setOrdering(AtomicOrdering::Unordered); -  PreheaderLoad->setAlignment(MaybeAlign(Alignment)); -  PreheaderLoad->setDebugLoc(DL); +  PreheaderLoad->setAlignment(Alignment); +  PreheaderLoad->setDebugLoc(DebugLoc());    if (AATags)      PreheaderLoad->setAAMetadata(AATags);    SSA.AddAvailableValue(Preheader, PreheaderLoad); @@ -2206,41 +2113,13 @@ bool llvm::promoteLoopAccessesToScalars(  /// Returns an owning pointer to an alias set which incorporates aliasing info  /// from L and all subloops of L. -/// FIXME: In new pass manager, there is no helper function to handle loop -/// analysis such as cloneBasicBlockAnalysis, so the AST needs to be recomputed -/// from scratch for every loop. Hook up with the helper functions when -/// available in the new pass manager to avoid redundant computation.  std::unique_ptr<AliasSetTracker>  LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, -                                                 AliasAnalysis *AA) { -  std::unique_ptr<AliasSetTracker> CurAST; -  SmallVector<Loop *, 4> RecomputeLoops; -  for (Loop *InnerL : L->getSubLoops()) { -    auto MapI = LoopToAliasSetMap.find(InnerL); -    // If the AST for this inner loop is missing it may have been merged into -    // some other loop's AST and then that loop unrolled, and so we need to -    // recompute it. -    if (MapI == LoopToAliasSetMap.end()) { -      RecomputeLoops.push_back(InnerL); -      continue; -    } -    std::unique_ptr<AliasSetTracker> InnerAST = std::move(MapI->second); +                                                 AAResults *AA) { +  auto CurAST = std::make_unique<AliasSetTracker>(*AA); -    if (CurAST) { -      // What if InnerLoop was modified by other passes ? -      // Once we've incorporated the inner loop's AST into ours, we don't need -      // the subloop's anymore. -      CurAST->add(*InnerAST); -    } else { -      CurAST = std::move(InnerAST); -    } -    LoopToAliasSetMap.erase(MapI); -  } -  if (!CurAST) -    CurAST = std::make_unique<AliasSetTracker>(*AA); - -  // Add everything from the sub loops that are no longer directly available. -  for (Loop *InnerL : RecomputeLoops) +  // Add everything from all the sub loops. +  for (Loop *InnerL : L->getSubLoops())      for (BasicBlock *BB : InnerL->blocks())        CurAST->add(*BB); @@ -2254,46 +2133,16 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI,  std::unique_ptr<AliasSetTracker>  LoopInvariantCodeMotion::collectAliasInfoForLoopWithMSSA( -    Loop *L, AliasAnalysis *AA, MemorySSAUpdater *MSSAU) { +    Loop *L, AAResults *AA, MemorySSAUpdater *MSSAU) {    auto *MSSA = MSSAU->getMemorySSA();    auto CurAST = std::make_unique<AliasSetTracker>(*AA, MSSA, L);    CurAST->addAllInstructionsInLoopUsingMSSA();    return CurAST;  } -/// Simple analysis hook. Clone alias set info. -/// -void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, -                                             Loop *L) { -  auto ASTIt = LICM.getLoopToAliasSetMap().find(L); -  if (ASTIt == LICM.getLoopToAliasSetMap().end()) -    return; - -  ASTIt->second->copyValue(From, To); -} - -/// Simple Analysis hook. 
Delete value V from alias set -/// -void LegacyLICMPass::deleteAnalysisValue(Value *V, Loop *L) { -  auto ASTIt = LICM.getLoopToAliasSetMap().find(L); -  if (ASTIt == LICM.getLoopToAliasSetMap().end()) -    return; - -  ASTIt->second->deleteValue(V); -} - -/// Simple Analysis hook. Delete value L from alias set map. -/// -void LegacyLICMPass::deleteAnalysisLoop(Loop *L) { -  if (!LICM.getLoopToAliasSetMap().count(L)) -    return; - -  LICM.getLoopToAliasSetMap().erase(L); -} -  static bool pointerInvalidatedByLoop(MemoryLocation MemLoc,                                       AliasSetTracker *CurAST, Loop *CurLoop, -                                     AliasAnalysis *AA) { +                                     AAResults *AA) {    // First check to see if any of the basic blocks in CurLoop invalidate *V.    bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod(); diff --git a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp index ab65f56d088f..687e14d6d7d2 100644 --- a/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -21,7 +21,6 @@  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/Analysis/TargetTransformInfo.h"  #include "llvm/IR/CFG.h" @@ -32,6 +31,7 @@  #include "llvm/Support/Debug.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include "llvm/Transforms/Utils/ValueMapper.h"  using namespace llvm; @@ -61,10 +61,10 @@ namespace {  /// Loop prefetch implementation class.  class LoopDataPrefetch {  public: -  LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE, -                   const TargetTransformInfo *TTI, +  LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI, +                   ScalarEvolution *SE, const TargetTransformInfo *TTI,                     OptimizationRemarkEmitter *ORE) -      : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {} +      : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}    bool run(); @@ -73,12 +73,16 @@ private:    /// Check if the stride of the accesses is large enough to    /// warrant a prefetch. 
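getMinPrefetchStride and doPrefetchWrites below follow a common LLVM pattern: a cl::opt overrides the TTI hook only when the user actually passed the flag, which getNumOccurrences() distinguishes from the default value. A self-contained sketch with an illustrative flag name:

#include "llvm/Support/CommandLine.h"
using namespace llvm;

static cl::opt<unsigned>
    ExampleMinStride("example-min-stride", cl::init(1),
                     cl::desc("illustrative flag, not a real LLVM option"));

static unsigned minStride(unsigned TargetHookValue) {
  // cl::init(1) alone does not override the target hook; only an explicit
  // command-line occurrence does.
  if (ExampleMinStride.getNumOccurrences() > 0)
    return ExampleMinStride;
  return TargetHookValue;
}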
-  bool isStrideLargeEnough(const SCEVAddRecExpr *AR); +  bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride); -  unsigned getMinPrefetchStride() { +  unsigned getMinPrefetchStride(unsigned NumMemAccesses, +                                unsigned NumStridedMemAccesses, +                                unsigned NumPrefetches, +                                bool HasCall) {      if (MinPrefetchStride.getNumOccurrences() > 0)        return MinPrefetchStride; -    return TTI->getMinPrefetchStride(); +    return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, +                                     NumPrefetches, HasCall);    }    unsigned getPrefetchDistance() { @@ -93,7 +97,14 @@ private:      return TTI->getMaxPrefetchIterationsAhead();    } +  bool doPrefetchWrites() { +    if (PrefetchWrites.getNumOccurrences() > 0) +      return PrefetchWrites; +    return TTI->enableWritePrefetching(); +  } +    AssumptionCache *AC; +  DominatorTree *DT;    LoopInfo *LI;    ScalarEvolution *SE;    const TargetTransformInfo *TTI; @@ -110,6 +121,7 @@ public:    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.addRequired<AssumptionCacheTracker>(); +    AU.addRequired<DominatorTreeWrapperPass>();      AU.addPreserved<DominatorTreeWrapperPass>();      AU.addRequired<LoopInfoWrapperPass>();      AU.addPreserved<LoopInfoWrapperPass>(); @@ -138,8 +150,8 @@ FunctionPass *llvm::createLoopDataPrefetchPass() {    return new LoopDataPrefetchLegacyPass();  } -bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { -  unsigned TargetMinStride = getMinPrefetchStride(); +bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR, +                                           unsigned TargetMinStride) {    // No need to check if any stride goes.    if (TargetMinStride <= 1)      return true; @@ -156,6 +168,7 @@ bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) {  PreservedAnalyses LoopDataPrefetchPass::run(Function &F,                                              FunctionAnalysisManager &AM) { +  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);    LoopInfo *LI = &AM.getResult<LoopAnalysis>(F);    ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);    AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); @@ -163,7 +176,7 @@ PreservedAnalyses LoopDataPrefetchPass::run(Function &F,        &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);    const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F); -  LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); +  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);    bool Changed = LDP.run();    if (Changed) { @@ -180,6 +193,7 @@ bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {    if (skipFunction(F))      return false; +  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();    LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();    ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();    AssumptionCache *AC = @@ -189,7 +203,7 @@ bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) {    const TargetTransformInfo *TTI =        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); -  LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); +  LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE);    return LDP.run();  } @@ -210,6 +224,49 @@ bool LoopDataPrefetch::run() {    return MadeChange;  } +/// A record for a potential prefetch made during the initial scan of the +/// loop. 
This is used to let a single prefetch target multiple memory accesses. +struct Prefetch { +  /// The address formula for this prefetch as returned by ScalarEvolution. +  const SCEVAddRecExpr *LSCEVAddRec; +  /// The point of insertion for the prefetch instruction. +  Instruction *InsertPt; +  /// True if targeting a write memory access. +  bool Writes; +  /// The (first seen) prefetched instruction. +  Instruction *MemI; + +  /// Constructor to create a new Prefetch for \p I. +  Prefetch(const SCEVAddRecExpr *L, Instruction *I) +      : LSCEVAddRec(L), InsertPt(nullptr), Writes(false), MemI(nullptr) { +    addInstruction(I); +  }; + +  /// Add the instruction \param I to this prefetch. If it's not the first +  /// one, 'InsertPt' and 'Writes' will be updated as required. +  /// \param PtrDiff the known constant address difference to the first added +  /// instruction. +  void addInstruction(Instruction *I, DominatorTree *DT = nullptr, +                      int64_t PtrDiff = 0) { +    if (!InsertPt) { +      MemI = I; +      InsertPt = I; +      Writes = isa<StoreInst>(I); +    } else { +      BasicBlock *PrefBB = InsertPt->getParent(); +      BasicBlock *InsBB = I->getParent(); +      if (PrefBB != InsBB) { +        BasicBlock *DomBB = DT->findNearestCommonDominator(PrefBB, InsBB); +        if (DomBB != PrefBB) +          InsertPt = DomBB->getTerminator(); +      } + +      if (isa<StoreInst>(I) && PtrDiff == 0) +        Writes = true; +    } +  } +}; +  bool LoopDataPrefetch::runOnLoop(Loop *L) {    bool MadeChange = false; @@ -222,15 +279,22 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {    // Calculate the number of iterations ahead to prefetch    CodeMetrics Metrics; +  bool HasCall = false;    for (const auto BB : L->blocks()) {      // If the loop already has prefetches, then assume that the user knows      // what they are doing and don't add any more. -    for (auto &I : *BB) -      if (CallInst *CI = dyn_cast<CallInst>(&I)) -        if (Function *F = CI->getCalledFunction()) +    for (auto &I : *BB) { +      if (isa<CallInst>(&I) || isa<InvokeInst>(&I)) { +        if (const Function *F = cast<CallBase>(I).getCalledFunction()) {            if (F->getIntrinsicID() == Intrinsic::prefetch)              return MadeChange; - +          if (TTI->isLoweredToCall(F)) +            HasCall = true; +        } else { // indirect call. 
+          HasCall = true; +        } +      } +    }      Metrics.analyzeBasicBlock(BB, *TTI, EphValues);    }    unsigned LoopSize = Metrics.NumInsts; @@ -244,12 +308,14 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {    if (ItersAhead > getMaxPrefetchIterationsAhead())      return MadeChange; -  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead -                    << " iterations ahead (loop size: " << LoopSize << ") in " -                    << L->getHeader()->getParent()->getName() << ": " << *L); +  unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L); +  if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1) +    return MadeChange; -  SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads; -  for (const auto BB : L->blocks()) { +  unsigned NumMemAccesses = 0; +  unsigned NumStridedMemAccesses = 0; +  SmallVector<Prefetch, 16> Prefetches; +  for (const auto BB : L->blocks())      for (auto &I : *BB) {        Value *PtrValue;        Instruction *MemI; @@ -258,7 +324,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {          MemI = LMemI;          PtrValue = LMemI->getPointerOperand();        } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) { -        if (!PrefetchWrites) continue; +        if (!doPrefetchWrites()) continue;          MemI = SMemI;          PtrValue = SMemI->getPointerOperand();        } else continue; @@ -266,7 +332,7 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {        unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();        if (PtrAddrSpace)          continue; - +      NumMemAccesses++;        if (L->isLoopInvariant(PtrValue))          continue; @@ -274,62 +340,79 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) {        const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);        if (!LSCEVAddRec)          continue; +      NumStridedMemAccesses++; -      // Check if the stride of the accesses is large enough to warrant a -      // prefetch. -      if (!isStrideLargeEnough(LSCEVAddRec)) -        continue; - -      // We don't want to double prefetch individual cache lines. If this load -      // is known to be within one cache line of some other load that has -      // already been prefetched, then don't prefetch this one as well. +      // We don't want to double prefetch individual cache lines. If this +      // access is known to be within one cache line of some other one that +      // has already been prefetched, then don't prefetch this one as well.        
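The duplicate-prefetch check that follows compares the SCEV difference between two access functions against the cache-line size. Reduced to plain byte offsets, the test looks roughly like this; PendingPrefetch, coveredByExisting, and the offset representation are stand-ins for the SCEV machinery:

#include <cstdint>
#include <cstdlib>
#include <vector>

struct PendingPrefetch { int64_t ByteOffset; };

// True if Offset lands within one cache line of an already-planned prefetch,
// in which case emitting another prefetch would only waste bandwidth.
static bool coveredByExisting(const std::vector<PendingPrefetch> &Planned,
                              int64_t Offset, int64_t CacheLineSize) {
  for (const auto &P : Planned)
    if (std::abs(Offset - P.ByteOffset) < CacheLineSize)
      return true;
  return false;
}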
bool DupPref = false; -      for (const auto &PrefLoad : PrefLoads) { -        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second); +      for (auto &Pref : Prefetches) { +        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, Pref.LSCEVAddRec);          if (const SCEVConstant *ConstPtrDiff =              dyn_cast<SCEVConstant>(PtrDiff)) {            int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());            if (PD < (int64_t) TTI->getCacheLineSize()) { +            Pref.addInstruction(MemI, DT, PD);              DupPref = true;              break;            }          }        } -      if (DupPref) -        continue; +      if (!DupPref) +        Prefetches.push_back(Prefetch(LSCEVAddRec, MemI)); +    } -      const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr( -        SE->getConstant(LSCEVAddRec->getType(), ItersAhead), -        LSCEVAddRec->getStepRecurrence(*SE))); -      if (!isSafeToExpand(NextLSCEV, *SE)) -        continue; +  unsigned TargetMinStride = +    getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, +                         Prefetches.size(), HasCall); -      PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec)); - -      Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace); -      SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr"); -      Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI); - -      IRBuilder<> Builder(MemI); -      Module *M = BB->getParent()->getParent(); -      Type *I32 = Type::getInt32Ty(BB->getContext()); -      Function *PrefetchFunc = Intrinsic::getDeclaration( -          M, Intrinsic::prefetch, PrefPtrValue->getType()); -      Builder.CreateCall( -          PrefetchFunc, -          {PrefPtrValue, -           ConstantInt::get(I32, MemI->mayReadFromMemory() ? 0 : 1), -           ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); -      ++NumPrefetches; -      LLVM_DEBUG(dbgs() << "  Access: " << *PtrValue << ", SCEV: " << *LSCEV -                        << "\n"); -      ORE->emit([&]() { -        return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) -               << "prefetched memory access"; +  LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead +             << " iterations ahead (loop size: " << LoopSize << ") in " +             << L->getHeader()->getParent()->getName() << ": " << *L); +  LLVM_DEBUG(dbgs() << "Loop has: " +             << NumMemAccesses << " memory accesses, " +             << NumStridedMemAccesses << " strided memory accesses, " +             << Prefetches.size() << " potential prefetch(es), " +             << "a minimum stride of " << TargetMinStride << ", " +             << (HasCall ? "calls" : "no calls") << ".\n"); + +  for (auto &P : Prefetches) { +    // Check if the stride of the accesses is large enough to warrant a +    // prefetch. 
+    if (!isStrideLargeEnough(P.LSCEVAddRec, TargetMinStride)) +      continue; + +    const SCEV *NextLSCEV = SE->getAddExpr(P.LSCEVAddRec, SE->getMulExpr( +      SE->getConstant(P.LSCEVAddRec->getType(), ItersAhead), +      P.LSCEVAddRec->getStepRecurrence(*SE))); +    if (!isSafeToExpand(NextLSCEV, *SE)) +      continue; + +    BasicBlock *BB = P.InsertPt->getParent(); +    Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), 0/*PtrAddrSpace*/); +    SCEVExpander SCEVE(*SE, BB->getModule()->getDataLayout(), "prefaddr"); +    Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, P.InsertPt); + +    IRBuilder<> Builder(P.InsertPt); +    Module *M = BB->getParent()->getParent(); +    Type *I32 = Type::getInt32Ty(BB->getContext()); +    Function *PrefetchFunc = Intrinsic::getDeclaration( +        M, Intrinsic::prefetch, PrefPtrValue->getType()); +    Builder.CreateCall( +        PrefetchFunc, +        {PrefPtrValue, +         ConstantInt::get(I32, P.Writes), +         ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); +    ++NumPrefetches; +    LLVM_DEBUG(dbgs() << "  Access: " +               << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1) +               << ", SCEV: " << *P.LSCEVAddRec << "\n"); +    ORE->emit([&]() { +        return OptimizationRemark(DEBUG_TYPE, "Prefetched", P.MemI) +          << "prefetched memory access";        }); -      MadeChange = true; -    } +    MadeChange = true;    }    return MadeChange; diff --git a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp index 2451572d6171..be209d34be42 100644 --- a/llvm/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -18,6 +18,8 @@  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/PatternMatch.h"  #include "llvm/InitializePasses.h" @@ -134,7 +136,9 @@ static bool isLoopNeverExecuted(Loop *L) {  /// is unable to delete it due to hoisting trivially loop invariant  /// instructions out of the loop.  
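At source level, the @llvm.prefetch calls emitted above amount to touching the same access stream a fixed number of iterations ahead. A sketch using Clang/GCC's __builtin_prefetch; the array, stride, and distance parameters are invented for illustration:

static double strideLoopWithPrefetch(const double *A, long N, long Stride,
                                     long ItersAhead) {
  double Sum = 0.0;
  for (long I = 0; I < N; ++I) {
    // rw=0 (read) and locality=3 mirror the operands of the @llvm.prefetch
    // call created above; the address is ItersAhead iterations further on.
    __builtin_prefetch(&A[(I + ItersAhead) * Stride], /*rw=*/0, /*locality=*/3);
    Sum += A[I * Stride];
  }
  return Sum;
}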
static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, -                                           ScalarEvolution &SE, LoopInfo &LI) { +                                           ScalarEvolution &SE, LoopInfo &LI, +                                           MemorySSA *MSSA, +                                           OptimizationRemarkEmitter &ORE) {    assert(L->isLCSSAForm(DT) && "Expected LCSSA!");    // We can only remove the loop if there is a preheader that we can branch from @@ -164,7 +168,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,        std::fill(P.incoming_values().begin(), P.incoming_values().end(),                  UndefValue::get(P.getType()));      } -    deleteDeadLoop(L, &DT, &SE, &LI); +    ORE.emit([&]() { +      return OptimizationRemark(DEBUG_TYPE, "NeverExecutes", L->getStartLoc(), +                                L->getHeader()) +             << "Loop deleted because it never executes"; +    }); +    deleteDeadLoop(L, &DT, &SE, &LI, MSSA);      ++NumDeleted;      return LoopDeletionResult::Deleted;    } @@ -200,7 +209,12 @@ static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT,    }    LLVM_DEBUG(dbgs() << "Loop is invariant, delete it!"); -  deleteDeadLoop(L, &DT, &SE, &LI); +  ORE.emit([&]() { +    return OptimizationRemark(DEBUG_TYPE, "Invariant", L->getStartLoc(), +                              L->getHeader()) +           << "Loop deleted because it is invariant"; +  }); +  deleteDeadLoop(L, &DT, &SE, &LI, MSSA);    ++NumDeleted;    return LoopDeletionResult::Deleted; @@ -212,15 +226,22 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM,    LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: ");    LLVM_DEBUG(L.dump()); -  std::string LoopName = L.getName(); -  auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI); +  std::string LoopName = std::string(L.getName()); +  // For the new PM, we can't use OptimizationRemarkEmitter as an analysis +  // pass. Function analyses need to be preserved across loop transformations +  // but ORE cannot be preserved (see comment before the pass definition). +  OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); +  auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, AR.MSSA, ORE);    if (Result == LoopDeletionResult::Unmodified)      return PreservedAnalyses::all();    if (Result == LoopDeletionResult::Deleted)      Updater.markLoopAsDeleted(L, LoopName); -  return getLoopPassPreservedAnalyses(); +  auto PA = getLoopPassPreservedAnalyses(); +  if (AR.MSSA) +    PA.preserve<MemorySSAAnalysis>(); +  return PA;  }  namespace { @@ -235,6 +256,7 @@ public:    bool runOnLoop(Loop *L, LPPassManager &) override;    void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addPreserved<MemorySSAWrapperPass>();      getLoopAnalysisUsage(AU);    }  }; @@ -255,11 +277,19 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {    DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();    ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +  auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); +  MemorySSA *MSSA = nullptr; +  if (MSSAAnalysis) +    MSSA = &MSSAAnalysis->getMSSA(); +  // For the old PM, we can't use OptimizationRemarkEmitter as an analysis +  // pass.  
Function analyses need to be preserved across loop transformations +  // but ORE cannot be preserved (see comment before the pass definition). +  OptimizationRemarkEmitter ORE(L->getHeader()->getParent());    LLVM_DEBUG(dbgs() << "Analyzing Loop for deletion: ");    LLVM_DEBUG(L->dump()); -  LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI); +  LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI, MSSA, ORE);    if (Result == LoopDeletionResult::Deleted)      LPM.markLoopAsDeleted(*L); diff --git a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp index 8e04e6e0ffe8..7867a5468891 100644 --- a/llvm/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -789,12 +789,6 @@ public:      // instructions to partitions.      Partitions.setupPartitionIdOnInstructions(); -    // To keep things simple have an empty preheader before we version or clone -    // the loop.  (Also split if this has no predecessor, i.e. entry, because we -    // rely on PH having a predecessor.) -    if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) -      SplitBlock(PH, PH->getTerminator(), DT, LI); -      // If we need run-time checks, version the loop now.      auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI);      const auto *RtPtrChecking = LAI->getRuntimePointerChecking(); @@ -807,6 +801,12 @@ public:                    "may not insert runtime check with convergent operation");      } +    // To keep things simple have an empty preheader before we version or clone +    // the loop.  (Also split if this has no predecessor, i.e. entry, because we +    // rely on PH having a predecessor.) +    if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) +      SplitBlock(PH, PH->getTerminator(), DT, LI); +      if (!Pred.isAlwaysTrue() || !Checks.empty()) {        assert(!LAI->hasConvergentOp() && "inserting illegal loop versioning"); @@ -903,15 +903,14 @@ private:    /// \p PtrToPartition contains the partition number for pointers.  Partition    /// number -1 means that the pointer is used in multiple partitions.  In this    /// case we can't safely omit the check. 
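The remarks added to LoopDeletion above use the usual deferred-emission pattern: the remark is constructed inside a lambda so no string formatting happens unless remark output is enabled. A standalone sketch; the pass name "my-loop-pass", the remark name "Deleted", and emitLoopRemark are invented:

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
using namespace llvm;

static void emitLoopRemark(OptimizationRemarkEmitter &ORE, Loop *L,
                           StringRef Reason) {
  // The lambda is only invoked when -Rpass or remark serialization asks for
  // it, keeping the common path free of remark construction cost.
  ORE.emit([&]() {
    return OptimizationRemark("my-loop-pass", "Deleted", L->getStartLoc(),
                              L->getHeader())
           << "loop deleted: " << Reason;
  });
}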
-  SmallVector<RuntimePointerChecking::PointerCheck, 4> -  includeOnlyCrossPartitionChecks( -      const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks, +  SmallVector<RuntimePointerCheck, 4> includeOnlyCrossPartitionChecks( +      const SmallVectorImpl<RuntimePointerCheck> &AllChecks,        const SmallVectorImpl<int> &PtrToPartition,        const RuntimePointerChecking *RtPtrChecking) { -    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; +    SmallVector<RuntimePointerCheck, 4> Checks;      copy_if(AllChecks, std::back_inserter(Checks), -            [&](const RuntimePointerChecking::PointerCheck &Check) { +            [&](const RuntimePointerCheck &Check) {                for (unsigned PtrIdx1 : Check.first->Members)                  for (unsigned PtrIdx2 : Check.second->Members)                    // Only include this check if there is a pair of pointers diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp index e1738f08eb23..20edc8699d79 100644 --- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp @@ -86,11 +86,15 @@ STATISTIC(UnknownTripCount, "Loop has unknown trip count");  STATISTIC(UncomputableTripCount, "SCEV cannot compute trip count of loop");  STATISTIC(NonEqualTripCount, "Loop trip counts are not the same");  STATISTIC(NonAdjacent, "Loops are not adjacent"); -STATISTIC(NonEmptyPreheader, "Loop has a non-empty preheader"); +STATISTIC( +    NonEmptyPreheader, +    "Loop has a non-empty preheader with instructions that cannot be moved");  STATISTIC(FusionNotBeneficial, "Fusion is not beneficial");  STATISTIC(NonIdenticalGuards, "Candidates have different guards"); -STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block"); -STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block"); +STATISTIC(NonEmptyExitBlock, "Candidate has a non-empty exit block with " +                             "instructions that cannot be moved"); +STATISTIC(NonEmptyGuardBlock, "Candidate has a non-empty guard block with " +                              "instructions that cannot be moved");  STATISTIC(NotRotated, "Candidate is not rotated");  enum FusionDependenceAnalysisChoice { @@ -738,33 +742,40 @@ private:              continue;            } -          // The following three checks look for empty blocks in FC0 and FC1. If -          // any of these blocks are non-empty, we do not fuse. This is done -          // because we currently do not have the safety checks to determine if -          // it is safe to move the blocks past other blocks in the loop. Once -          // these checks are added, these conditions can be relaxed. -          if (!isEmptyPreheader(*FC1)) { -            LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty " -                                 "preheader. Not fusing.\n"); +          if (!isSafeToMoveBefore(*FC1->Preheader, +                                  *FC0->Preheader->getTerminator(), DT, &PDT, +                                  &DI)) { +            LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe " +                                 "instructions in preheader. 
Not fusing.\n");              reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1,                                                         NonEmptyPreheader);              continue;            } -          if (FC0->GuardBranch && !isEmptyExitBlock(*FC0)) { -            LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty exit " -                                 "block. Not fusing.\n"); -            reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, -                                                       NonEmptyExitBlock); -            continue; -          } +          if (FC0->GuardBranch) { +            assert(FC1->GuardBranch && "Expecting valid FC1 guard branch"); + +            if (!isSafeToMoveBefore(*FC0->ExitBlock, +                                    *FC1->ExitBlock->getFirstNonPHIOrDbg(), DT, +                                    &PDT, &DI)) { +              LLVM_DEBUG(dbgs() << "Fusion candidate contains unsafe " +                                   "instructions in exit block. Not fusing.\n"); +              reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, +                                                         NonEmptyExitBlock); +              continue; +            } -          if (FC1->GuardBranch && !isEmptyGuardBlock(*FC1)) { -            LLVM_DEBUG(dbgs() << "Fusion candidate does not have empty guard " -                                 "block. Not fusing.\n"); -            reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, -                                                       NonEmptyGuardBlock); -            continue; +            if (!isSafeToMoveBefore( +                    *FC1->GuardBranch->getParent(), +                    *FC0->GuardBranch->getParent()->getTerminator(), DT, &PDT, +                    &DI)) { +              LLVM_DEBUG(dbgs() +                         << "Fusion candidate contains unsafe " +                            "instructions in guard block. Not fusing.\n"); +              reportLoopFusion<OptimizationRemarkMissed>(*FC0, *FC1, +                                                         NonEmptyGuardBlock); +              continue; +            }            }            // Check the dependencies across the loops and do not fuse if it would @@ -1075,38 +1086,6 @@ private:        return (FC1.GuardBranch->getSuccessor(1) == FC1.Preheader);    } -  /// Check that the guard for \p FC *only* contains the cmp/branch for the -  /// guard. -  /// Once we are able to handle intervening code, any code in the guard block -  /// for FC1 will need to be treated as intervening code and checked whether -  /// it can safely move around the loops. -  bool isEmptyGuardBlock(const FusionCandidate &FC) const { -    assert(FC.GuardBranch && "Expecting a fusion candidate with guard branch."); -    if (auto *CmpInst = dyn_cast<Instruction>(FC.GuardBranch->getCondition())) { -      auto *GuardBlock = FC.GuardBranch->getParent(); -      // If the generation of the cmp value is in GuardBlock, then the size of -      // the guard block should be 2 (cmp + branch). If the generation of the -      // cmp value is in a different block, then the size of the guard block -      // should only be 1. 
-      if (CmpInst->getParent() == GuardBlock) -        return GuardBlock->size() == 2; -      else -        return GuardBlock->size() == 1; -    } - -    return false; -  } - -  bool isEmptyPreheader(const FusionCandidate &FC) const { -    assert(FC.Preheader && "Expecting a valid preheader"); -    return FC.Preheader->size() == 1; -  } - -  bool isEmptyExitBlock(const FusionCandidate &FC) const { -    assert(FC.ExitBlock && "Expecting a valid exit block"); -    return FC.ExitBlock->size() == 1; -  } -    /// Simplify the condition of the latch branch of \p FC to true, when both of    /// its successors are the same.    void simplifyLatchBranch(const FusionCandidate &FC) const { @@ -1123,7 +1102,7 @@ private:    /// Move instructions from FC0.Latch to FC1.Latch. If FC0.Latch has an unique    /// successor, then merge FC0.Latch with its unique successor.    void mergeLatch(const FusionCandidate &FC0, const FusionCandidate &FC1) { -    moveInstsBottomUp(*FC0.Latch, *FC1.Latch, DT, PDT, DI); +    moveInstructionsToTheBeginning(*FC0.Latch, *FC1.Latch, DT, PDT, DI);      if (BasicBlock *Succ = FC0.Latch->getUniqueSuccessor()) {        MergeBlockIntoPredecessor(Succ, &DTU, &LI);        DTU.flush(); @@ -1166,6 +1145,10 @@ private:      LLVM_DEBUG(dbgs() << "Fusion Candidate 0: \n"; FC0.dump();                 dbgs() << "Fusion Candidate 1: \n"; FC1.dump();); +    // Move instructions from the preheader of FC1 to the end of the preheader +    // of FC0. +    moveInstructionsToTheEnd(*FC1.Preheader, *FC0.Preheader, DT, PDT, DI); +      // Fusing guarded loops is handled slightly differently than non-guarded      // loops and has been broken out into a separate method instead of trying to      // intersperse the logic within a single method. @@ -1382,6 +1365,14 @@ private:      BasicBlock *FC0NonLoopBlock = FC0.getNonLoopBlock();      BasicBlock *FC1NonLoopBlock = FC1.getNonLoopBlock(); +    // Move instructions from the exit block of FC0 to the beginning of the exit +    // block of FC1. +    moveInstructionsToTheBeginning(*FC0.ExitBlock, *FC1.ExitBlock, DT, PDT, DI); + +    // Move instructions from the guard block of FC1 to the end of the guard +    // block of FC0. +    moveInstructionsToTheEnd(*FC1GuardBlock, *FC0GuardBlock, DT, PDT, DI); +      assert(FC0NonLoopBlock == FC1GuardBlock && "Loops are not adjacent");      SmallVector<DominatorTree::UpdateType, 8> TreeUpdates; @@ -1394,6 +1385,7 @@ private:      // Thus, one path from the guard goes to the preheader for FC0 (and thus      // executes the new fused loop) and the other path goes to the NonLoopBlock      // for FC1 (where FC1 guard would have gone if FC1 was not executed). 
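Rather than requiring empty blocks, the fusion checks above ask CodeMoverUtils whether the contents can legally be relocated. A sketch of the query/move pairing for the preheader case, assuming dominator, post-dominator, and dependence analyses are available; tryHoistPreheader is an invented helper:

#include "llvm/Analysis/DependenceAnalysis.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/CodeMoverUtils.h"
using namespace llvm;

static bool tryHoistPreheader(BasicBlock &FC1Preheader,
                              BasicBlock &FC0Preheader, DominatorTree &DT,
                              PostDominatorTree &PDT, DependenceInfo &DI) {
  Instruction *InsertPt = FC0Preheader.getTerminator();
  // Bail if any instruction has a control or data dependence that would be
  // violated by hoisting it above FC0's terminator.
  if (!isSafeToMoveBefore(FC1Preheader, *InsertPt, DT, &PDT, &DI))
    return false;
  moveInstructionsToTheEnd(FC1Preheader, FC0Preheader, DT, PDT, DI);
  return true;
}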
+    FC1NonLoopBlock->replacePhiUsesWith(FC1GuardBlock, FC0GuardBlock);      FC0.GuardBranch->replaceUsesOfWith(FC0NonLoopBlock, FC1NonLoopBlock);      FC0.ExitBlock->getTerminator()->replaceUsesOfWith(FC1GuardBlock,                                                        FC1.Header); @@ -1545,7 +1537,10 @@ private:      // Update DT/PDT      DTU.applyUpdates(TreeUpdates); +    LI.removeBlock(FC1GuardBlock);      LI.removeBlock(FC1.Preheader); +    LI.removeBlock(FC0.ExitBlock); +    DTU.deleteBB(FC1GuardBlock);      DTU.deleteBB(FC1.Preheader);      DTU.deleteBB(FC0.ExitBlock);      DTU.flush(); diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index b77843d7cd71..3cb4df12e9b0 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -51,9 +51,11 @@  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/MustExecute.h"  #include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/Analysis/TargetTransformInfo.h" @@ -91,6 +93,7 @@  #include "llvm/Transforms/Utils/BuildLibCalls.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include <algorithm>  #include <cassert>  #include <cstdint> @@ -123,15 +126,19 @@ class LoopIdiomRecognize {    const DataLayout *DL;    OptimizationRemarkEmitter &ORE;    bool ApplyCodeSizeHeuristics; +  std::unique_ptr<MemorySSAUpdater> MSSAU;  public:    explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,                                LoopInfo *LI, ScalarEvolution *SE,                                TargetLibraryInfo *TLI, -                              const TargetTransformInfo *TTI, +                              const TargetTransformInfo *TTI, MemorySSA *MSSA,                                const DataLayout *DL,                                OptimizationRemarkEmitter &ORE) -      : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) {} +      : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL), ORE(ORE) { +    if (MSSA) +      MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); +  }    bool runOnLoop(Loop *L); @@ -224,13 +231,17 @@ public:          &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(              *L->getHeader()->getParent());      const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); +    auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); +    MemorySSA *MSSA = nullptr; +    if (MSSAAnalysis) +      MSSA = &MSSAAnalysis->getMSSA();      // For the old PM, we can't use OptimizationRemarkEmitter as an analysis      // pass.  Function analyses need to be preserved across loop transformations      // but ORE cannot be preserved (see comment before the pass definition).      
OptimizationRemarkEmitter ORE(L->getHeader()->getParent()); -    LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL, ORE); +    LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, MSSA, DL, ORE);      return LIR.runOnLoop(L);    } @@ -239,6 +250,7 @@ public:    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.addRequired<TargetLibraryInfoWrapperPass>();      AU.addRequired<TargetTransformInfoWrapperPass>(); +    AU.addPreserved<MemorySSAWrapperPass>();      getLoopAnalysisUsage(AU);    }  }; @@ -252,23 +264,20 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,                                                LPMUpdater &) {    const auto *DL = &L.getHeader()->getModule()->getDataLayout(); -  const auto &FAM = -      AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); -  Function *F = L.getHeader()->getParent(); - -  auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); -  // FIXME: This should probably be optional rather than required. -  if (!ORE) -    report_fatal_error( -        "LoopIdiomRecognizePass: OptimizationRemarkEmitterAnalysis not cached " -        "at a higher level"); +  // For the new PM, we also can't use OptimizationRemarkEmitter as an analysis +  // pass.  Function analyses need to be preserved across loop transformations +  // but ORE cannot be preserved (see comment before the pass definition). +  OptimizationRemarkEmitter ORE(L.getHeader()->getParent()); -  LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL, -                         *ORE); +  LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, +                         AR.MSSA, DL, ORE);    if (!LIR.runOnLoop(&L))      return PreservedAnalyses::all(); -  return getLoopPassPreservedAnalyses(); +  auto PA = getLoopPassPreservedAnalyses(); +  if (AR.MSSA) +    PA.preserve<MemorySSAAnalysis>(); +  return PA;  }  INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom", @@ -339,14 +348,14 @@ bool LoopIdiomRecognize::runOnCountableLoop() {                      << "] Countable Loop %" << CurLoop->getHeader()->getName()                      << "\n"); -  bool MadeChange = false; -    // The following transforms hoist stores/memsets into the loop pre-header. -  // Give up if the loop has instructions may throw. +  // Give up if the loop has instructions that may throw.    SimpleLoopSafetyInfo SafetyInfo;    SafetyInfo.computeLoopSafetyInfo(CurLoop);    if (SafetyInfo.anyBlockMayThrow()) -    return MadeChange; +    return false; + +  bool MadeChange = false;    // Scan all the blocks in the loop that are not in subloops.    
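The MemorySSA plumbing in this patch follows one pattern throughout: keep a MemorySSAUpdater only when MemorySSA was supplied, and register each newly created memory instruction with it. A sketch of the registration step used for the new memset/memcpy calls, assuming NewCall was just inserted before a block terminator; recordNewDef is an invented name:

#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static void recordNewDef(MemorySSAUpdater *MSSAU, Instruction *NewCall) {
  if (!MSSAU)
    return; // no MemorySSA available, nothing to maintain
  // Create the defining access at the insertion point and rename uses so
  // later queries see the new call as a clobber.
  MemoryAccess *NewAcc = MSSAU->createMemoryAccessInBB(
      NewCall, /*Definition=*/nullptr, NewCall->getParent(),
      MemorySSA::BeforeTerminator);
  MSSAU->insertDef(cast<MemoryDef>(NewAcc), /*RenameUses=*/true);
}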
for (auto *BB : CurLoop->getBlocks()) { @@ -968,11 +977,17 @@ bool LoopIdiomRecognize::processLoopStridedStore(      Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy);      NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes});    } +  NewCall->setDebugLoc(TheStore->getDebugLoc()); + +  if (MSSAU) { +    MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( +        NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator); +    MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); +  }    LLVM_DEBUG(dbgs() << "  Formed memset: " << *NewCall << "\n"                      << "    from store to: " << *Ev << " at: " << *TheStore                      << "\n"); -  NewCall->setDebugLoc(TheStore->getDebugLoc());    ORE.emit([&]() {      return OptimizationRemark(DEBUG_TYPE, "ProcessLoopStridedStore", @@ -984,12 +999,40 @@ bool LoopIdiomRecognize::processLoopStridedStore(    // Okay, the memset has been formed.  Zap the original store and anything that    // feeds into it. -  for (auto *I : Stores) +  for (auto *I : Stores) { +    if (MSSAU) +      MSSAU->removeMemoryAccess(I, true);      deleteDeadInstruction(I); +  } +  if (MSSAU && VerifyMemorySSA) +    MSSAU->getMemorySSA()->verifyMemorySSA();    ++NumMemSet;    return true;  } +class ExpandedValuesCleaner { +  SCEVExpander &Expander; +  TargetLibraryInfo *TLI; +  SmallVector<Value *, 4> ExpandedValues; +  bool Commit = false; + +public: +  ExpandedValuesCleaner(SCEVExpander &Expander, TargetLibraryInfo *TLI) +      : Expander(Expander), TLI(TLI) {} + +  void add(Value *V) { ExpandedValues.push_back(V); } + +  void commit() { Commit = true; } + +  ~ExpandedValuesCleaner() { +    if (!Commit) { +      Expander.clear(); +      for (auto *V : ExpandedValues) +        RecursivelyDeleteTriviallyDeadInstructions(V, TLI); +    } +  } +}; +  /// If the stored value is a strided load in the same loop with the same stride  /// this may be transformable into a memcpy.  This kicks in for stuff like  /// for (i) A[i] = B[i]; @@ -1020,6 +1063,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,    IRBuilder<> Builder(Preheader->getTerminator());    SCEVExpander Expander(*SE, *DL, "loop-idiom"); +  ExpandedValuesCleaner EVC(Expander, TLI); +    const SCEV *StrStart = StoreEv->getStart();    unsigned StrAS = SI->getPointerAddressSpace();    Type *IntIdxTy = Builder.getIntNTy(DL->getIndexSizeInBits(StrAS)); @@ -1036,16 +1081,13 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,    // checking everything.    Value *StoreBasePtr = Expander.expandCodeFor(        StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator()); +  EVC.add(StoreBasePtr);    SmallPtrSet<Instruction *, 1> Stores;    Stores.insert(SI);    if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, -                            StoreSize, *AA, Stores)) { -    Expander.clear(); -    // If we generated new code for the base pointer, clean up. -    RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); +                            StoreSize, *AA, Stores))      return false; -  }    const SCEV *LdStart = LoadEv->getStart();    unsigned LdAS = LI->getPointerAddressSpace(); @@ -1058,15 +1100,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,    // mutated by the loop.    
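ExpandedValuesCleaner above collapses several copies of the same early-return cleanup into one scoped object. The same commit/rollback idiom in a generic, library-free form (ScopedRollback is illustrative):

#include <functional>
#include <utility>

class ScopedRollback {
  std::function<void()> Cleanup;
  bool Committed = false;

public:
  explicit ScopedRollback(std::function<void()> F) : Cleanup(std::move(F)) {}
  ScopedRollback(const ScopedRollback &) = delete;
  void commit() { Committed = true; }
  ~ScopedRollback() {
    // Every early return that skips commit() lands here and undoes the
    // speculative work, mirroring the Expander.clear() path above.
    if (!Committed)
      Cleanup();
  }
};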
Value *LoadBasePtr = Expander.expandCodeFor(        LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator()); +  EVC.add(LoadBasePtr);    if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, -                            StoreSize, *AA, Stores)) { -    Expander.clear(); -    // If we generated new code for the base pointer, clean up. -    RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); -    RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); +                            StoreSize, *AA, Stores))      return false; -  }    if (avoidLIRForMultiBlockLoop())      return false; @@ -1078,6 +1116,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,    Value *NumBytes =        Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator()); +  EVC.add(NumBytes);    CallInst *NewCall = nullptr;    // Check whether to generate an unordered atomic memcpy: @@ -1089,8 +1128,9 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,    else {      // We cannot allow unaligned ops for unordered load/store, so reject      // anything where the alignment isn't at least the element size. -    unsigned Align = std::min(SI->getAlignment(), LI->getAlignment()); -    if (Align < StoreSize) +    const Align StoreAlign = SI->getAlign(); +    const Align LoadAlign = LI->getAlign(); +    if (StoreAlign < StoreSize || LoadAlign < StoreSize)        return false;      // If the element.atomic memcpy is not lowered into explicit @@ -1104,11 +1144,17 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,      // Note that unordered atomic loads/stores are *required* by the spec to      // have an alignment but non-atomic loads/stores may not.      NewCall = Builder.CreateElementUnorderedAtomicMemCpy( -        StoreBasePtr, SI->getAlignment(), LoadBasePtr, LI->getAlignment(), -        NumBytes, StoreSize); +        StoreBasePtr, StoreAlign, LoadBasePtr, LoadAlign, NumBytes, +        StoreSize);    }    NewCall->setDebugLoc(SI->getDebugLoc()); +  if (MSSAU) { +    MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( +        NewCall, nullptr, NewCall->getParent(), MemorySSA::BeforeTerminator); +    MSSAU->insertDef(cast<MemoryDef>(NewMemAcc), true); +  } +    LLVM_DEBUG(dbgs() << "  Formed memcpy: " << *NewCall << "\n"                      << "    from load ptr=" << *LoadEv << " at: " << *LI << "\n"                      << "    from store ptr=" << *StoreEv << " at: " << *SI @@ -1124,8 +1170,13 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI,    // Okay, the memcpy has been formed.  Zap the original store and anything that    // feeds into it. +  if (MSSAU) +    MSSAU->removeMemoryAccess(SI, true);    deleteDeadInstruction(SI); +  if (MSSAU && VerifyMemorySSA) +    MSSAU->getMemorySSA()->verifyMemorySSA();    ++NumMemCpy; +  EVC.commit();    return true;  } @@ -1502,18 +1553,20 @@ bool LoopIdiomRecognize::recognizeAndInsertFFS() {    //  %inc = add nsw %i.0, 1    //  br i1 %tobool -  const Value *Args[] = -      {InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) -                        : ConstantInt::getFalse(InitX->getContext())}; +  const Value *Args[] = { +      InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) +                       : ConstantInt::getFalse(InitX->getContext())};    // @llvm.dbg doesn't count as they have no semantic effect.    
auto InstWithoutDebugIt = CurLoop->getHeader()->instructionsWithoutDebug();    uint32_t HeaderSize =        std::distance(InstWithoutDebugIt.begin(), InstWithoutDebugIt.end()); +  IntrinsicCostAttributes Attrs(IntrinID, InitX->getType(), Args); +  int Cost = +    TTI->getIntrinsicInstrCost(Attrs, TargetTransformInfo::TCK_SizeAndLatency);    if (HeaderSize != IdiomCanonicalSize && -      TTI->getIntrinsicCost(IntrinID, InitX->getType(), Args) > -          TargetTransformInfo::TCC_Basic) +      Cost > TargetTransformInfo::TCC_Basic)      return false;    transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX, diff --git a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp index 901204181a7c..3153a8721193 100644 --- a/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -68,7 +68,7 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI,    // While simplifying we may discover dead code or cause code to become dead.    // Keep track of all such instructions and we will delete them at the end. -  SmallVector<Instruction *, 8> DeadInsts; +  SmallVector<WeakTrackingVH, 8> DeadInsts;    // First we want to create an RPO traversal of the loop body. By processing in    // RPO we can ensure that definitions are processed prior to uses (for non PHI diff --git a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp index 6ce2d06058cf..7787c0bccd4c 100644 --- a/llvm/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/llvm/lib/Transforms/Scalar/LoopInterchange.cpp @@ -412,7 +412,6 @@ public:  private:    bool adjustLoopLinks(); -  void adjustLoopPreheaders();    bool adjustLoopBranches();    Loop *OuterLoop; @@ -580,6 +579,12 @@ struct LoopInterchange : public LoopPass {      LIT.transform();      LLVM_DEBUG(dbgs() << "Loops interchanged.\n");      LoopsInterchanged++; + +    assert(InnerLoop->isLCSSAForm(*DT) && +           "Inner loop not left in LCSSA form after loop interchange!"); +    assert(OuterLoop->isLCSSAForm(*DT) && +           "Outer loop not left in LCSSA form after loop interchange!"); +      return true;    }  }; @@ -689,7 +694,7 @@ bool LoopInterchangeLegality::findInductionAndReductions(        // PHIs in inner loops need to be part of a reduction in the outer loop,        // discovered when checking the PHIs of the outer loop earlier.        if (!InnerLoop) { -        if (OuterInnerReductions.find(&PHI) == OuterInnerReductions.end()) { +        if (!OuterInnerReductions.count(&PHI)) {            LLVM_DEBUG(dbgs() << "Inner loop PHI is not part of reductions "                                 "across the outer loop.\n");            return false; @@ -903,8 +908,8 @@ areInnerLoopExitPHIsSupported(Loop *InnerL, Loop *OuterL,        return false;      if (any_of(PHI.users(), [&Reductions, OuterL](User *U) {            PHINode *PN = dyn_cast<PHINode>(U); -          return !PN || (Reductions.find(PN) == Reductions.end() && -                         OuterL->contains(PN->getParent())); +          return !PN || +                 (!Reductions.count(PN) && OuterL->contains(PN->getParent()));          })) {        return false;      } @@ -1319,6 +1324,23 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) {                  FromBB->getTerminator()->getIterator());  } +/// Swap instructions between \p BB1 and \p BB2 but keep terminators intact. 
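The LoopInstSimplify change above swaps raw Instruction pointers for WeakTrackingVH in the dead-instruction list, so entries null themselves out if an earlier deletion erases a later one. A sketch of the consuming side, assuming the utilities in llvm/Transforms/Utils/Local.h; deleteCollected is an invented wrapper:

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;

static void deleteCollected(SmallVectorImpl<WeakTrackingVH> &DeadInsts,
                            const TargetLibraryInfo *TLI) {
  // Handles invalidated by earlier deletions are skipped inside the utility
  // instead of being dereferenced as stale pointers.
  RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, TLI);
}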
+static void swapBBContents(BasicBlock *BB1, BasicBlock *BB2) { +  // Save all non-terminator instructions of BB1 into TempInstrs and unlink them +  // from BB1 afterwards. +  auto Iter = map_range(*BB1, [](Instruction &I) { return &I; }); +  SmallVector<Instruction *, 4> TempInstrs(Iter.begin(), std::prev(Iter.end())); +  for (Instruction *I : TempInstrs) +    I->removeFromParent(); + +  // Move instructions from BB2 to BB1. +  moveBBContents(BB2, BB1->getTerminator()); + +  // Move instructions from TempInstrs to BB2. +  for (Instruction *I : TempInstrs) +    I->insertBefore(BB2->getTerminator()); +} +  // Update BI to jump to NewBB instead of OldBB. Records updates to the  // dominator tree in DTUpdates. If \p MustUpdateOnce is true, assert that  // \p OldBB  is exactly once in BI's successor list. @@ -1560,13 +1582,11 @@ bool LoopInterchangeTransform::adjustLoopBranches() {    // outer loop and all the remains to do is and updating the incoming blocks.    for (PHINode *PHI : OuterLoopPHIs) {      PHI->moveBefore(InnerLoopHeader->getFirstNonPHI()); -    assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() && -           "Expected a reduction PHI node"); +    assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node");    }    for (PHINode *PHI : InnerLoopPHIs) {      PHI->moveBefore(OuterLoopHeader->getFirstNonPHI()); -    assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() && -           "Expected a reduction PHI node"); +    assert(OuterInnerReductions.count(PHI) && "Expected a reduction PHI node");    }    // Update the incoming blocks for moved PHI nodes. @@ -1578,30 +1598,17 @@ bool LoopInterchangeTransform::adjustLoopBranches() {    return true;  } -void LoopInterchangeTransform::adjustLoopPreheaders() { -  // We have interchanged the preheaders so we need to interchange the data in -  // the preheader as well. -  // This is because the content of inner preheader was previously executed -  // inside the outer loop. -  BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); -  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); -  BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); -  BranchInst *InnerTermBI = -      cast<BranchInst>(InnerLoopPreHeader->getTerminator()); - -  // These instructions should now be executed inside the loop. -  // Move instruction into a new block after outer header. -  moveBBContents(InnerLoopPreHeader, OuterLoopHeader->getTerminator()); -  // These instructions were not executed previously in the loop so move them to -  // the older inner loop preheader. -  moveBBContents(OuterLoopPreHeader, InnerTermBI); -} -  bool LoopInterchangeTransform::adjustLoopLinks() {    // Adjust all branches in the inner and outer loop.    bool Changed = adjustLoopBranches(); -  if (Changed) -    adjustLoopPreheaders(); +  if (Changed) { +    // We have interchanged the preheaders so we need to interchange the data in +    // the preheaders as well. This is because the content of the inner +    // preheader was previously executed inside the outer loop. 
+    BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); +    BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); +    swapBBContents(OuterLoopPreHeader, InnerLoopPreHeader); +  }    return Changed;  } diff --git a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp index 4e1b4e87ebc9..4412b3079461 100644 --- a/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -38,7 +38,6 @@  #include "llvm/Analysis/MemorySSA.h"  #include "llvm/Analysis/ProfileSummaryInfo.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/Analysis/TargetLibraryInfo.h"  #include "llvm/Analysis/TargetTransformInfo.h" @@ -58,6 +57,7 @@  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils.h"  #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include "llvm/Transforms/Utils/SizeOpts.h"  #include <algorithm>  #include <cassert> @@ -377,7 +377,7 @@ public:    /// Determine the pointer alias checks to prove that there are no    /// intervening stores. -  SmallVector<RuntimePointerChecking::PointerCheck, 4> collectMemchecks( +  SmallVector<RuntimePointerCheck, 4> collectMemchecks(        const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {      SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath = @@ -391,10 +391,10 @@ public:                     std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr));      const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); -    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; +    SmallVector<RuntimePointerCheck, 4> Checks;      copy_if(AllChecks, std::back_inserter(Checks), -            [&](const RuntimePointerChecking::PointerCheck &Check) { +            [&](const RuntimePointerCheck &Check) {                for (auto PtrIdx1 : Check.first->Members)                  for (auto PtrIdx2 : Check.second->Members)                    if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, @@ -432,12 +432,12 @@ public:      Value *Ptr = Cand.Load->getPointerOperand();      auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(Ptr));      auto *PH = L->getLoopPreheader(); +    assert(PH && "Preheader should exist!");      Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(),                                            PH->getTerminator());      Value *Initial = new LoadInst(          Cand.Load->getType(), InitialPtr, "load_initial", -        /* isVolatile */ false, MaybeAlign(Cand.Load->getAlignment()), -        PH->getTerminator()); +        /* isVolatile */ false, Cand.Load->getAlign(), PH->getTerminator());      PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded",                                     &L->getHeader()->front()); @@ -520,8 +520,7 @@ public:      // Check intervening may-alias stores.  These need runtime checks for alias      // disambiguation. -    SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks = -        collectMemchecks(Candidates); +    SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);      // Too many checks are likely to outweigh the benefits of forwarding.      
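For orientation, the store-to-load forwarding rewrite above (a "load_initial" in the preheader feeding a "store_forwarded" PHI) corresponds at source level to carrying the stored value in a scalar. An illustrative sketch only:

static void forwardedLoop(int *A, int N) {
  int Fwd = A[0]; // the "load_initial" materialized in the preheader
  for (int I = 1; I < N; ++I) {
    int V = Fwd + 1; // uses the forwarded value instead of reloading A[I-1]
    A[I] = V;
    Fwd = V;         // next iteration's "store_forwarded" PHI input
  }
}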
if (Checks.size() > Candidates.size() * CheckPerElim) {
@@ -535,6 +534,11 @@ public:
       return false;
     }
 
+    if (!L->isLoopSimplifyForm()) {
+      LLVM_DEBUG(dbgs() << "Loop is not in loop-simplify form");
+      return false;
+    }
+
     if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) {
       if (LAI.hasConvergentOp()) {
         LLVM_DEBUG(dbgs() << "Versioning is needed but not allowed with "
@@ -554,11 +558,6 @@ public:
         return false;
       }
 
-      if (!L->isLoopSimplifyForm()) {
-        LLVM_DEBUG(dbgs() << "Loop is not is loop-simplify form");
-        return false;
-      }
-
       // Point of no-return, start the transformation.  First, version the loop
       // if necessary.
 
@@ -697,8 +696,8 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F,
   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
   auto &AA = AM.getResult<AAManager>(F);
   auto &AC = AM.getResult<AssumptionAnalysis>(F);
-  auto &MAM = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
-  auto *PSI = MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+  auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+  auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
   auto *BFI = (PSI && PSI->hasProfileSummary()) ?
       &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
   MemorySSA *MSSA = EnableMSSALoopDependency
diff --git a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
index f3bfbd3564ab..98889a9df116 100644
--- a/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPassManager.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Support/TimeProfiler.h"
 #include "llvm/Transforms/Scalar/LoopPassManager.h"
 #include "llvm/Analysis/LoopInfo.h"
 
@@ -33,15 +34,19 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &,
   // instrumenting callbacks for the passes later.
   PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR);
   for (auto &Pass : Passes) {
-    if (DebugLogging)
-      dbgs() << "Running pass: " << Pass->name() << " on " << L;
-
     // Check the PassInstrumentation's BeforePass callbacks before running the
     // pass, skip its execution completely if asked to (callback returns false).
if (!PI.runBeforePass<Loop>(*Pass, L))        continue; -    PreservedAnalyses PassPA = Pass->run(L, AM, AR, U); +    if (DebugLogging) +      dbgs() << "Running pass: " << Pass->name() << " on " << L; + +    PreservedAnalyses PassPA; +    { +      TimeTraceScope TimeScope(Pass->name(), L.getName()); +      PassPA = Pass->run(L, AM, AR, U); +    }      // do not pass deleted Loop into the instrumentation      if (U.skipCurrentLoop()) diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 1a42f6b23443..edde22d6708f 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -184,7 +184,6 @@  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/GlobalValue.h" @@ -199,6 +198,7 @@  #include "llvm/Transforms/Utils/GuardUtils.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #define DEBUG_TYPE "loop-predication" @@ -268,7 +268,7 @@ class LoopPredication {    /// Return an insertion point suitable for inserting a safe to speculate    /// instruction whose only user will be 'User' which has operands 'Ops'.  A    /// trivial result would be the at the User itself, but we try to return a -  /// loop invariant location if possible.   +  /// loop invariant location if possible.    Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);    /// Same as above, *except* that this uses the SCEV definition of invariant    /// which is that an expression *can be made* invariant via SCEVExpander. @@ -278,7 +278,7 @@ class LoopPredication {    /// Return true if the value is known to produce a single fixed value across    /// all iterations on which it executes.  Note that this does not imply -  /// speculation safety.  That must be established seperately.   +  /// speculation safety.  That must be established separately.    bool isLoopInvariantValue(const SCEV* S);    Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, @@ -342,7 +342,7 @@ public:  };  char LoopPredicationLegacyPass::ID = 0; -} // end namespace llvm +} // end namespace  INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",                        "Loop predication", false, false) @@ -358,11 +358,12 @@ Pass *llvm::createLoopPredicationPass() {  PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,                                             LoopStandardAnalysisResults &AR,                                             LPMUpdater &U) { -  const auto &FAM = -      AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();    Function *F = L.getHeader()->getParent(); -  auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); -  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, BPI); +  // For the new PM, we also can't use BranchProbabilityInfo as an analysis +  // pass. Function analyses need to be preserved across loop transformations +  // but BPI is not preserved, hence a newly built one is needed. 
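The TimeTraceScope introduced above gives every loop pass its own section in -ftime-trace output, scoped exactly around run(). The wrapper shape as a generic sketch; runTimed is an invented name and the pass type is assumed to expose name() and run():

#include "llvm/Support/TimeProfiler.h"
#include <utility>
using namespace llvm;

template <typename PassT, typename... ArgTs>
static auto runTimed(PassT &Pass, StringRef Detail, ArgTs &&...Args) {
  // The scope opens a trace section on construction and closes it on
  // destruction, attributing the elapsed time to the pass's name.
  TimeTraceScope Scope(Pass.name(), Detail);
  return Pass.run(std::forward<ArgTs>(Args)...);
}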
+  BranchProbabilityInfo BPI(*F, AR.LI, &AR.TLI);
+  LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, &BPI);
   if (!LP.runOnLoop(&L))
     return PreservedAnalyses::all();
@@ -397,7 +398,7 @@ LoopPredication::parseLoopICmp(ICmpInst *ICI) {
 }
 
 Value *LoopPredication::expandCheck(SCEVExpander &Expander,
-                                    Instruction *Guard, 
+                                    Instruction *Guard,
                                     ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS) {
   Type *Ty = LHS->getType();
@@ -521,7 +522,7 @@ Instruction *LoopPredication::findInsertPt(Instruction *Use,
   return Preheader->getTerminator();
 }
 
-bool LoopPredication::isLoopInvariantValue(const SCEV* S) { 
+bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
   // Handling expressions which produce invariant results, but *haven't* yet
   // been removed from the loop serves two important purposes.
   // 1) Most importantly, it resolves a pass ordering cycle which would
@@ -534,12 +535,12 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
   // much more obviously in the IR.  Otherwise, the cost modeling for other
   // transforms would end up needing to duplicate all of this logic to model a
   // check which becomes predictable based on a modeled peel or unswitch.
-  //  
+  //
   // The cost of doing so in the worst case is an extra fill from the stack in
   // the loop to materialize the loop invariant test value instead of checking
   // against the original IV which is presumably in a register inside the loop.
   // Such cases are presumably rare, and hint at missing opportunities for
-  // other passes. 
+  // other passes.
 
   if (SE->isLoopInvariant(S, L))
     // Note: This is the SCEV variant, so the original Value* may be within the
@@ -547,7 +548,7 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
 
   // Handle a particularly important case which SCEV doesn't yet know about which
-  // shows up in range checks on arrays with immutable lengths.  
+  // shows up in range checks on arrays with immutable lengths.
   // TODO: This should be sunk inside SCEV.
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
     if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
@@ -574,7 +575,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
   const SCEV *LatchLimit = LatchCheck.Limit;
   // Subtlety: We need all the values to be *invariant* across all iterations,
   // but we only need to check expansion safety for those which *aren't*
-  // already guaranteed to dominate the guard.  
+  // already guaranteed to dominate the guard.
if (!isLoopInvariantValue(GuardStart) ||        !isLoopInvariantValue(GuardLimit) ||        !isLoopInvariantValue(LatchStart) || @@ -598,7 +599,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(    LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");    LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");    LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); -  +    auto *LimitCheck =        expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);    auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred, @@ -617,7 +618,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(    const SCEV *LatchLimit = LatchCheck.Limit;    // Subtlety: We need all the values to be *invariant* across all iterations,    // but we only need to check expansion safety for those which *aren't* -  // already guaranteed to dominate the guard.   +  // already guaranteed to dominate the guard.    if (!isLoopInvariantValue(GuardStart) ||        !isLoopInvariantValue(GuardLimit) ||        !isLoopInvariantValue(LatchStart) || @@ -658,7 +659,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(  static void normalizePredicate(ScalarEvolution *SE, Loop *L,                                 LoopICmp& RC) {    // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the -  // ULT/UGE form for ease of handling by our caller.  +  // ULT/UGE form for ease of handling by our caller.    if (ICmpInst::isEquality(RC.Pred) &&        RC.IV->getStepRecurrence(*SE)->isOne() &&        SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit)) @@ -1020,17 +1021,6 @@ static const SCEV *getMinAnalyzeableBackedgeTakenCount(ScalarEvolution &SE,    return SE.getUMinFromMismatchedTypes(ExitCounts);  } -/// Return true if we can be fairly sure that executing block BB will probably -/// lead to executing an __llvm_deoptimize.  This is a profitability heuristic, -/// not a legality constraint. -static bool isVeryLikelyToDeopt(BasicBlock *BB) { -  while (BB->getUniqueSuccessor()) -    // Will skip side effects, that's okay -    BB = BB->getUniqueSuccessor(); - -  return BB->getTerminatingDeoptimizeCall(); -} -  /// This implements an analogous, but entirely distinct transform from the main  /// loop predication transform.  This one is phrased in terms of using a  /// widenable branch *outside* the loop to allow us to simplify loop exits in a @@ -1054,7 +1044,7 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {    // inserting a branch on the value which can be either poison or undef.  In    // this case, the branch can legally go either way; we just need to avoid    // introducing UB.  This is achieved through the use of the freeze -  // instruction.   +  // instruction.    SmallVector<BasicBlock *, 16> ExitingBlocks;    L->getExitingBlocks(ExitingBlocks); @@ -1082,7 +1072,7 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {    // analyzeable after dropping widenability.    {      bool Invalidate = false; -     +      for (auto *ExitingBB : ExitingBlocks) {        if (LI->getLoopFor(ExitingBB) != L)          continue; @@ -1150,10 +1140,13 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {      const bool ExitIfTrue = !L->contains(*succ_begin(ExitingBB));      BasicBlock *ExitBB = BI->getSuccessor(ExitIfTrue ? 
0 : 1);
-    if (!isVeryLikelyToDeopt(ExitBB))
-      // Profitability: indicator of rarely/never taken exit
+    if (!ExitBB->getPostdominatingDeoptimizeCall())
       continue;
+    // Here we can be fairly sure that executing this exit will most likely
+    // lead to executing llvm.experimental.deoptimize.
+    // This is a profitability heuristic, not a legality constraint.
+
     // If we found a widenable exit condition, do two things:
     // 1) fold the widened exit test into the widenable condition
     // 2) fold the branch to untaken - avoids infinite looping
diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
index da13a342ae12..3542d0a4ee73 100644
--- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -24,7 +24,6 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -55,6 +54,7 @@
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -880,6 +880,12 @@ bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) {
   if (DRS.Roots.empty())
     return false;
+  // If the value of the base instruction is used outside the loop, we cannot
+  // reroll the loop. Checking the other root instructions is unnecessary
+  // because they don't match any base instruction if their values are used
+  // outside the loop.
+  if (hasUsesOutsideLoop(DRS.BaseInst, L))
+    return false;
+
   // Consider a DAGRootSet with N-1 roots (so N different values including
   //   BaseInst).
   // Define d = Roots[0] - BaseInst, which should be the same as
@@ -1126,7 +1132,7 @@ static bool isIgnorableInst(const Instruction *I) {
     case Intrinsic::annotation:
     case Intrinsic::ptr_annotation:
     case Intrinsic::var_annotation:
-    // TODO: the following intrinsics may also be whitelisted:
+    // TODO: the following intrinsics may also be allowed:
     //   lifetime_start, lifetime_end, invariant_start, invariant_end
       return true;
   }
diff --git a/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/llvm/lib/Transforms/Scalar/LoopRotation.cpp
index 0868e742f4ee..f92566ba77ce 100644
--- a/llvm/lib/Transforms/Scalar/LoopRotation.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRotation.cpp
@@ -81,10 +81,8 @@ public:
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
-    if (EnableMSSALoopDependency) {
-      AU.addRequired<MemorySSAWrapperPass>();
+    if (EnableMSSALoopDependency)
       AU.addPreserved<MemorySSAWrapperPass>();
-    }
     getLoopAnalysisUsage(AU);
   }
@@ -101,15 +99,18 @@ public:
     const SimplifyQuery SQ = getBestSimplifyQuery(*this, F);
     Optional<MemorySSAUpdater> MSSAU;
     if (EnableMSSALoopDependency) {
-      MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
-      MSSAU = MemorySSAUpdater(MSSA);
+      // Not requiring MemorySSA and getting it only if available will split
+      // the loop pass pipeline when LoopRotate is being run first.
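Back in the LoopPredication hunk above, BasicBlock::getPostdominatingDeoptimizeCall() takes over the job of the deleted file-local isVeryLikelyToDeopt(). A simplified sketch of the walk the removed helper performed (the member function is more general):

// Follow unique successors; if the straight-line path ends in a call to
// @llvm.experimental.deoptimize, this exit almost certainly deopts.
// This is a profitability heuristic, not a legality check.
static CallInst *findStraightLineDeopt(BasicBlock *BB) {
  while (BasicBlock *Succ = BB->getUniqueSuccessor())
    BB = Succ; // skipping side effects along the way is acceptable here
  return BB->getTerminatingDeoptimizeCall();
}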
+      auto *MSSAA = getAnalysisIfAvailable<MemorySSAWrapperPass>(); +      if (MSSAA) +        MSSAU = MemorySSAUpdater(&MSSAA->getMSSA());      }      return LoopRotation(L, LI, TTI, AC, &DT, &SE,                          MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ,                          false, MaxHeaderSize, false);    }  }; -} +} // end namespace  char LoopRotateLegacyPass::ID = 0;  INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", diff --git a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index b27e65e0adb7..031e5b9c1d2c 100644 --- a/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -23,6 +23,7 @@  #include "llvm/Analysis/DomTreeUpdater.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h"  #include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/MemorySSA.h"  #include "llvm/Analysis/MemorySSAUpdater.h" @@ -30,6 +31,7 @@  #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"  #include "llvm/Analysis/TargetTransformInfo.h"  #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h"  #include "llvm/InitializePasses.h"  #include "llvm/Support/CommandLine.h"  #include "llvm/Transforms/Scalar.h" @@ -673,13 +675,13 @@ static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT,  static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI,                              ScalarEvolution &SE, MemorySSAUpdater *MSSAU, -                            bool &isLoopDeleted) { +                            bool &IsLoopDeleted) {    bool Changed = false;    // Constant-fold terminators with known constant conditions. -  Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, isLoopDeleted); +  Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU, IsLoopDeleted); -  if (isLoopDeleted) +  if (IsLoopDeleted)      return true;    // Eliminate unconditional branches by merging blocks into their predecessors. 
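The LoopRotate hunk just above shows the legacy-pass-manager idiom for optional analyses: drop addRequired<> and consume MemorySSA only when an earlier pass already computed it, so requiring it cannot split the loop pass pipeline. Restated as a sketch, inside a LoopPass::runOnLoop, with runOnTheLoop as a hypothetical callee:

bool runOnTheLoop(Loop *L, MemorySSAUpdater *MSSAU); // hypothetical

// Preserve MemorySSA if present, never demand it.
Optional<MemorySSAUpdater> MSSAU;
if (auto *WP = getAnalysisIfAvailable<MemorySSAWrapperPass>())
  MSSAU = MemorySSAUpdater(&WP->getMSSA());
return runOnTheLoop(L, MSSAU ? MSSAU.getPointer() : nullptr);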
@@ -752,7 +754,7 @@ public:      getLoopAnalysisUsage(AU);    }  }; -} +} // end namespace  char LoopSimplifyCFGLegacyPass::ID = 0;  INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index e9f368628a08..cf02ef1e83f3 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -65,12 +65,14 @@  #include "llvm/ADT/SmallSet.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AssumptionCache.h"  #include "llvm/Analysis/IVUsers.h"  #include "llvm/Analysis/LoopAnalysisManager.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h"  #include "llvm/Analysis/ScalarEvolutionNormalization.h"  #include "llvm/Analysis/TargetTransformInfo.h" @@ -109,6 +111,7 @@  #include "llvm/Transforms/Utils.h"  #include "llvm/Transforms/Utils/BasicBlockUtils.h"  #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"  #include <algorithm>  #include <cassert>  #include <cstddef> @@ -807,9 +810,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI,      switch (II->getIntrinsicID()) {      case Intrinsic::memset:      case Intrinsic::prefetch: +    case Intrinsic::masked_load:        if (II->getArgOperand(0) == OperandVal)          isAddress = true;        break; +    case Intrinsic::masked_store: +      if (II->getArgOperand(1) == OperandVal) +        isAddress = true; +      break;      case Intrinsic::memmove:      case Intrinsic::memcpy:        if (II->getArgOperand(0) == OperandVal || @@ -859,6 +867,15 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI,        AccessTy.AddrSpace = OperandVal->getType()->getPointerAddressSpace();        AccessTy.MemTy = OperandVal->getType();        break; +    case Intrinsic::masked_load: +      AccessTy.AddrSpace = +          II->getArgOperand(0)->getType()->getPointerAddressSpace(); +      break; +    case Intrinsic::masked_store: +      AccessTy.MemTy = II->getOperand(0)->getType(); +      AccessTy.AddrSpace = +          II->getArgOperand(1)->getType()->getPointerAddressSpace(); +      break;      default: {        MemIntrinsicInfo IntrInfo;        if (TTI.getTgtMemIntrinsic(II, IntrInfo) && IntrInfo.PtrVal) { @@ -962,33 +979,6 @@ static bool isHighCostExpansion(const SCEV *S,    return true;  } -/// If any of the instructions in the specified set are trivially dead, delete -/// them and see if this makes any of their operands subsequently dead. 
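On the masked_load/masked_store cases added to LoopStrengthReduce above: the operand indices follow the intrinsic signatures, llvm.masked.load(ptr, alignment, mask, passthru) and llvm.masked.store(value, ptr, alignment, mask), which is also why getAccessType reads MemTy from operand 0 of a masked store. A small helper restating the rule (the name is hypothetical):

// Which argument of a masked memory intrinsic is the address being accessed?
static Value *getMaskedPointerOperand(IntrinsicInst *II) {
  switch (II->getIntrinsicID()) {
  case Intrinsic::masked_load:  // (ptr, alignment, mask, passthru)
    return II->getArgOperand(0);
  case Intrinsic::masked_store: // (value, ptr, alignment, mask)
    return II->getArgOperand(1); // operand 0 is the stored value
  default:
    return nullptr;
  }
}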
-static bool -DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakTrackingVH> &DeadInsts) { -  bool Changed = false; - -  while (!DeadInsts.empty()) { -    Value *V = DeadInsts.pop_back_val(); -    Instruction *I = dyn_cast_or_null<Instruction>(V); - -    if (!I || !isInstructionTriviallyDead(I)) -      continue; - -    for (Use &O : I->operands()) -      if (Instruction *U = dyn_cast<Instruction>(O)) { -        O = nullptr; -        if (U->use_empty()) -          DeadInsts.emplace_back(U); -      } - -    I->eraseFromParent(); -    Changed = true; -  } - -  return Changed; -} -  namespace {  class LSRUse; @@ -1242,7 +1232,7 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,      // for now LSR only handles innermost loops).      if (AR->getLoop() != L) {        // If the AddRec exists, consider it's register free and leave it alone. -      if (isExistingPhi(AR, *SE)) +      if (isExistingPhi(AR, *SE) && !TTI->shouldFavorPostInc())          return;        // It is bad to allow LSR for current loop to add induction variables @@ -1913,9 +1903,10 @@ class LSRInstance {    DominatorTree &DT;    LoopInfo &LI;    AssumptionCache &AC; -  TargetLibraryInfo &LibInfo; +  TargetLibraryInfo &TLI;    const TargetTransformInfo &TTI;    Loop *const L; +  MemorySSAUpdater *MSSAU;    bool FavorBackedgeIndex = false;    bool Changed = false; @@ -2018,6 +2009,7 @@ class LSRInstance {    void NarrowSearchSpaceByCollapsingUnrolledCode();    void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters();    void NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); +  void NarrowSearchSpaceByFilterPostInc();    void NarrowSearchSpaceByDeletingCostlyFormulas();    void NarrowSearchSpaceByPickingWinnerRegs();    void NarrowSearchSpaceUsingHeuristics(); @@ -2053,7 +2045,7 @@ class LSRInstance {  public:    LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT,                LoopInfo &LI, const TargetTransformInfo &TTI, AssumptionCache &AC, -              TargetLibraryInfo &LibInfo); +              TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU);    bool getChanged() const { return Changed; } @@ -2830,9 +2822,10 @@ bool IVChain::isProfitableIncrement(const SCEV *OperExpr,  /// increments can be computed in fewer registers when chained.  ///  /// TODO: Consider IVInc free if it's already used in another chains. 
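The isProfitableChain hunk below threads TargetTransformInfo into the profitability check so a target can force-keep a chain whose users benefit from chained addressing, via TTI.isProfitableLSRChainElement. A hedged sketch of a target-side override (FooTTIImpl and the predicate are hypothetical):

// In a target's TTIImpl: declare which instructions profit from staying in
// an IV chain, e.g. memory ops the target can post-increment cheaply.
bool FooTTIImpl::isProfitableLSRChainElement(Instruction *I) {
  if (auto *II = dyn_cast<IntrinsicInst>(I))
    return isChainFriendlyMemIntrinsic(II); // hypothetical predicate
  return false;
}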
-static bool
-isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users,
-                  ScalarEvolution &SE) {
+static bool isProfitableChain(IVChain &Chain,
+                              SmallPtrSetImpl<Instruction *> &Users,
+                              ScalarEvolution &SE,
+                              const TargetTransformInfo &TTI) {
   if (StressIVChain)
     return true;
@@ -2861,7 +2854,14 @@ isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users,
   unsigned NumConstIncrements = 0;
   unsigned NumVarIncrements = 0;
   unsigned NumReusedIncrements = 0;
+
+  if (TTI.isProfitableLSRChainElement(Chain.Incs[0].UserInst))
+    return true;
+
   for (const IVInc &Inc : Chain) {
+    if (TTI.isProfitableLSRChainElement(Inc.UserInst))
+      return true;
+
     if (Inc.IncExpr->isZero())
       continue;
@@ -3092,7 +3092,7 @@ void LSRInstance::CollectChains() {
   for (unsigned UsersIdx = 0, NChains = IVChainVec.size();
        UsersIdx < NChains; ++UsersIdx) {
     if (!isProfitableChain(IVChainVec[UsersIdx],
-                           ChainUsersVec[UsersIdx].FarUsers, SE))
+                           ChainUsersVec[UsersIdx].FarUsers, SE, TTI))
       continue;
     // Preserve the chain at UsesIdx.
     if (ChainIdx != UsersIdx)
@@ -3212,7 +3212,8 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
       IVOper = Builder.CreateTruncOrBitCast(IVOper, OperTy, "lsr.chain");
     }
     Inc.UserInst->replaceUsesOfWith(Inc.IVOperand, IVOper);
-    DeadInsts.emplace_back(Inc.IVOperand);
+    if (auto *OperandIsInstr = dyn_cast<Instruction>(Inc.IVOperand))
+      DeadInsts.emplace_back(OperandIsInstr);
   }
   // If LSR created a new, wider phi, we may also replace its postinc. We only
   // do this if we also found a wide value for the head of the chain.
@@ -3240,7 +3241,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
 void LSRInstance::CollectFixupsAndInitialFormulae() {
   BranchInst *ExitBranch = nullptr;
-  bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &LibInfo);
+  bool SaveCmp = TTI.canSaveCmp(L, &ExitBranch, &SE, &LI, &DT, &AC, &TLI);
   for (const IVStrideUse &U : IU) {
     Instruction *UserInst = U.getUser();
@@ -3553,9 +3554,6 @@ static bool mayUsePostIncMode(const TargetTransformInfo &TTI,
   const SCEV *LoopStep = AR->getStepRecurrence(SE);
   if (!isa<SCEVConstant>(LoopStep))
     return false;
-  if (LU.AccessTy.getType()->getScalarSizeInBits() !=
-      LoopStep->getType()->getScalarSizeInBits())
-    return false;
   // Check if a post-indexed load/store can be used.
   if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) ||
       TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) {
@@ -4673,6 +4671,54 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() {
   });
 }
+/// If we are over the complexity limit, filter out any post-inc preferring
+/// variables to only post-inc values.
+void LSRInstance::NarrowSearchSpaceByFilterPostInc() {
+  if (!TTI.shouldFavorPostInc())
+    return;
+  if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+    return;
+
+  LLVM_DEBUG(dbgs() << "The search space is too complex.\n"
+                       "Narrowing the search space by choosing the lowest "
+                       "register Formula for PostInc Uses.\n");
+
+  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+    LSRUse &LU = Uses[LUIdx];
+
+    if (LU.Kind != LSRUse::Address)
+      continue;
+    if (!TTI.isIndexedLoadLegal(TTI.MIM_PostInc, LU.AccessTy.getType()) &&
+        !TTI.isIndexedStoreLegal(TTI.MIM_PostInc, LU.AccessTy.getType()))
+      continue;
+
+    size_t MinRegs = std::numeric_limits<size_t>::max();
+    for (const Formula &F : LU.Formulae)
+      MinRegs = std::min(F.getNumRegs(), MinRegs);
+
+    bool Any = false;
+    for (size_t FIdx = 0, NumForms = LU.Formulae.size(); FIdx != NumForms;
+         ++FIdx) {
+      Formula &F = LU.Formulae[FIdx];
+      if (F.getNumRegs() > MinRegs) {
+        LLVM_DEBUG(dbgs() << "  Filtering out formula "; F.print(dbgs());
+                   dbgs() << "\n");
+        LU.DeleteFormula(F);
+        --FIdx;
+        --NumForms;
+        Any = true;
+      }
+    }
+    if (Any)
+      LU.RecomputeRegs(LUIdx, RegUses);
+
+    if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+      break;
+  }
+
+  LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs()));
+}
+
 /// The function delete formulas with high registers number expectation.
 /// Assuming we don't know the value of each formula (already delete
 /// all inefficient), generate probability of not selecting for each
@@ -4883,6 +4929,7 @@ void LSRInstance::NarrowSearchSpaceUsingHeuristics() {
   NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters();
   if (FilterSameScaledReg)
     NarrowSearchSpaceByFilterFormulaWithSameScaledReg();
+  NarrowSearchSpaceByFilterPostInc();
   if (LSRExpNarrow)
     NarrowSearchSpaceByDeletingCostlyFormulas();
   else
@@ -4923,19 +4970,24 @@ void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution,
     // Ignore formulae which may not be ideal in terms of register reuse of
     // ReqRegs.  The formula should use all required registers before
     // introducing new ones.
-    int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size());
-    for (const SCEV *Reg : ReqRegs) {
-      if ((F.ScaledReg && F.ScaledReg == Reg) ||
-          is_contained(F.BaseRegs, Reg)) {
-        --NumReqRegsToFind;
-        if (NumReqRegsToFind == 0)
-          break;
+    // This can sometimes (notably when trying to favour postinc) lead to
+    // sub-optimal decisions. In those cases it is best left to the cost
+    // modelling to get right.
+    if (!TTI.shouldFavorPostInc() || LU.Kind != LSRUse::Address) {
+      int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size());
+      for (const SCEV *Reg : ReqRegs) {
+        if ((F.ScaledReg && F.ScaledReg == Reg) ||
+            is_contained(F.BaseRegs, Reg)) {
+          --NumReqRegsToFind;
+          if (NumReqRegsToFind == 0)
+            break;
+        }
+      }
+      if (NumReqRegsToFind != 0) {
+        // If none of the formulae satisfied the required registers, then we
+        // could clear ReqRegs and try again. Currently, we simply give up in
+        // this case.
+        continue;        } -    } -    if (NumReqRegsToFind != 0) { -      // If none of the formulae satisfied the required registers, then we could -      // clear ReqRegs and try again. Currently, we simply give up in this case. -      continue;      }      // Evaluate the cost of the current formula. If it's already worse than @@ -5268,7 +5320,8 @@ Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF,    // form, update the ICmp's other operand.    if (LU.Kind == LSRUse::ICmpZero) {      ICmpInst *CI = cast<ICmpInst>(LF.UserInst); -    DeadInsts.emplace_back(CI->getOperand(1)); +    if (auto *OperandIsInstr = dyn_cast<Instruction>(CI->getOperand(1))) +      DeadInsts.emplace_back(OperandIsInstr);      assert(!F.BaseGV && "ICmp does not support folding a global value and "                             "a scale at the same time!");      if (F.Scale == -1) { @@ -5449,7 +5502,8 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF,        LF.UserInst->replaceUsesOfWith(LF.OperandValToReplace, FullV);    } -  DeadInsts.emplace_back(LF.OperandValToReplace); +  if (auto *OperandIsInstr = dyn_cast<Instruction>(LF.OperandValToReplace)) +    DeadInsts.emplace_back(OperandIsInstr);  }  /// Rewrite all the fixup locations with new values, following the chosen @@ -5490,16 +5544,17 @@ void LSRInstance::ImplementSolution(    // instructions.    Rewriter.clear(); -  Changed |= DeleteTriviallyDeadInstructions(DeadInsts); +  Changed |= RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, +                                                                  &TLI, MSSAU);  }  LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE,                           DominatorTree &DT, LoopInfo &LI,                           const TargetTransformInfo &TTI, AssumptionCache &AC, -                         TargetLibraryInfo &LibInfo) -    : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), LibInfo(LibInfo), TTI(TTI), L(L), -      FavorBackedgeIndex(EnableBackedgeIndexing && -                         TTI.shouldFavorBackedgeIndex(L)) { +                         TargetLibraryInfo &TLI, MemorySSAUpdater *MSSAU) +    : IU(IU), SE(SE), DT(DT), LI(LI), AC(AC), TLI(TLI), TTI(TTI), L(L), +      MSSAU(MSSAU), FavorBackedgeIndex(EnableBackedgeIndexing && +                                       TTI.shouldFavorBackedgeIndex(L)) {    // If LoopSimplify form is not available, stay out of trouble.    if (!L->isLoopSimplifyForm())      return; @@ -5702,21 +5757,26 @@ void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const {    AU.addRequired<IVUsersWrapperPass>();    AU.addPreserved<IVUsersWrapperPass>();    AU.addRequired<TargetTransformInfoWrapperPass>(); +  AU.addPreserved<MemorySSAWrapperPass>();  }  static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,                                 DominatorTree &DT, LoopInfo &LI,                                 const TargetTransformInfo &TTI, -                               AssumptionCache &AC, -                               TargetLibraryInfo &LibInfo) { +                               AssumptionCache &AC, TargetLibraryInfo &TLI, +                               MemorySSA *MSSA) {    bool Changed = false; +  std::unique_ptr<MemorySSAUpdater> MSSAU; +  if (MSSA) +    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);    // Run the main LSR transformation. 
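The ImplementSolution and ReduceLoopStrength hunks around here replace LSR's private dead-code sweeper with RecursivelyDeleteTriviallyDeadInstructionsPermissive from llvm/Transforms/Utils/Local.h and thread an optional MemorySSAUpdater through the transform. A sketch of the combined pattern (cleanupLoop is hypothetical):

#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h" // DeleteDeadPHIs
#include "llvm/Transforms/Utils/Local.h"
#include <memory>

bool cleanupLoop(Loop *L, SmallVectorImpl<WeakTrackingVH> &DeadInsts,
                 TargetLibraryInfo &TLI, MemorySSA *MSSA) {
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA) // the analysis may legitimately be absent
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
  // The "Permissive" variant ignores entries that are not trivially dead,
  // so callers may queue values speculatively.
  bool Changed = RecursivelyDeleteTriviallyDeadInstructionsPermissive(
      DeadInsts, &TLI, MSSAU.get());
  Changed |= DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
  return Changed;
}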
-  Changed |= LSRInstance(L, IU, SE, DT, LI, TTI, AC, LibInfo).getChanged(); +  Changed |= +      LSRInstance(L, IU, SE, DT, LI, TTI, AC, TLI, MSSAU.get()).getChanged();    // Remove any extra phis created by processing inner loops. -  Changed |= DeleteDeadPHIs(L->getHeader()); +  Changed |= DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());    if (EnablePhiElim && L->isLoopSimplifyForm()) {      SmallVector<WeakTrackingVH, 16> DeadInsts;      const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); @@ -5727,8 +5787,9 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,      unsigned numFolded = Rewriter.replaceCongruentIVs(L, &DT, DeadInsts, &TTI);      if (numFolded) {        Changed = true; -      DeleteTriviallyDeadInstructions(DeadInsts); -      DeleteDeadPHIs(L->getHeader()); +      RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts, &TLI, +                                                           MSSAU.get()); +      DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());      }    }    return Changed; @@ -5746,19 +5807,26 @@ bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {        *L->getHeader()->getParent());    auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(        *L->getHeader()->getParent()); -  auto &LibInfo = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( +  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(        *L->getHeader()->getParent()); -  return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, LibInfo); +  auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); +  MemorySSA *MSSA = nullptr; +  if (MSSAAnalysis) +    MSSA = &MSSAAnalysis->getMSSA(); +  return ReduceLoopStrength(L, IU, SE, DT, LI, TTI, AC, TLI, MSSA);  }  PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM,                                                LoopStandardAnalysisResults &AR,                                                LPMUpdater &) {    if (!ReduceLoopStrength(&L, AM.getResult<IVUsersAnalysis>(L, AR), AR.SE, -                          AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI)) +                          AR.DT, AR.LI, AR.TTI, AR.AC, AR.TLI, AR.MSSA))      return PreservedAnalyses::all(); -  return getLoopPassPreservedAnalyses(); +  auto PA = getLoopPassPreservedAnalyses(); +  if (AR.MSSA) +    PA.preserve<MemorySSAAnalysis>(); +  return PA;  }  char LoopStrengthReduce::ID = 0; diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 92ad8dafa5ab..285cba6ee205 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -11,8 +11,10 @@  //===----------------------------------------------------------------------===//  #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h" +#include "llvm/ADT/ArrayRef.h"  #include "llvm/ADT/None.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/PriorityWorklist.h"  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/ADT/StringRef.h"  #include "llvm/Analysis/AssumptionCache.h" @@ -20,37 +22,36 @@  #include "llvm/Analysis/DependenceAnalysis.h"  #include "llvm/Analysis/LoopAnalysisManager.h"  #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/ScalarEvolution.h"  #include "llvm/Analysis/TargetTransformInfo.h"  #include "llvm/IR/BasicBlock.h" -#include 
"llvm/IR/CFG.h" -#include "llvm/IR/Constant.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h" -#include "llvm/IR/Instruction.h"  #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h"  #include "llvm/IR/Metadata.h"  #include "llvm/IR/PassManager.h"  #include "llvm/InitializePasses.h"  #include "llvm/Pass.h" +#include "llvm/PassRegistry.h"  #include "llvm/Support/Casting.h"  #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h"  #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h"  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" -#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/LoopSimplify.h"  #include "llvm/Transforms/Utils/LoopUtils.h"  #include "llvm/Transforms/Utils/UnrollLoop.h" -#include <algorithm>  #include <cassert>  #include <cstdint> -#include <string> +#include <vector> + +namespace llvm { +class Instruction; +class Value; +} // namespace llvm  using namespace llvm; @@ -91,7 +92,7 @@ static cl::opt<unsigned> PragmaUnrollAndJamThreshold(  // Returns the loop hint metadata node with the given name (for example,  // "llvm.loop.unroll.count").  If no such metadata node exists, then nullptr is  // returned. -static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) { +static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) {    if (MDNode *LoopID = L->getLoopID())      return GetUnrollMetadata(LoopID, Name);    return nullptr; @@ -99,14 +100,14 @@ static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {  // Returns true if the loop has any metadata starting with Prefix. For example a  // Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata. -static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) { +static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) {    if (MDNode *LoopID = L->getLoopID()) {      // First operand should refer to the loop id itself.      assert(LoopID->getNumOperands() > 0 && "requires at least one operand");      assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); -    for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { -      MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); +    for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) { +      MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(I));        if (!MD)          continue; @@ -122,14 +123,14 @@ static bool HasAnyUnrollPragma(const Loop *L, StringRef Prefix) {  }  // Returns true if the loop has an unroll_and_jam(enable) pragma. -static bool HasUnrollAndJamEnablePragma(const Loop *L) { -  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable"); +static bool hasUnrollAndJamEnablePragma(const Loop *L) { +  return getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable");  }  // If loop has an unroll_and_jam_count pragma return the (necessarily  // positive) value from the pragma.  Otherwise return 0. 
-static unsigned UnrollAndJamCountPragmaValue(const Loop *L) { -  MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count"); +static unsigned unrollAndJamCountPragmaValue(const Loop *L) { +  MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.count");    if (MD) {      assert(MD->getNumOperands() == 2 &&             "Unroll count hint metadata should have two operands."); @@ -157,7 +158,8 @@ static bool computeUnrollAndJamCount(      const SmallPtrSetImpl<const Value *> &EphValues,      OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,      unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount, -    unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP) { +    unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP, +    TargetTransformInfo::PeelingPreferences &PP) {    // First up use computeUnrollCount from the loop unroller to get a count    // for unrolling the outer loop, plus any loops requiring explicit    // unrolling we leave to the unroller. This uses UP.Threshold / @@ -167,7 +169,8 @@ static bool computeUnrollAndJamCount(    bool UseUpperBound = false;    bool ExplicitUnroll = computeUnrollCount(        L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, -      /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); +      /*MaxOrZero*/ false, OuterTripMultiple, OuterLoopSize, UP, PP, +      UseUpperBound);    if (ExplicitUnroll || UseUpperBound) {      // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it      // for the unroller instead. @@ -190,7 +193,7 @@ static bool computeUnrollAndJamCount(    }    // Check for unroll_and_jam pragmas -  unsigned PragmaCount = UnrollAndJamCountPragmaValue(L); +  unsigned PragmaCount = unrollAndJamCountPragmaValue(L);    if (PragmaCount > 0) {      UP.Count = PragmaCount;      UP.Runtime = true; @@ -202,7 +205,7 @@ static bool computeUnrollAndJamCount(        return true;    } -  bool PragmaEnableUnroll = HasUnrollAndJamEnablePragma(L); +  bool PragmaEnableUnroll = hasUnrollAndJamEnablePragma(L);    bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount;    bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount; @@ -279,24 +282,11 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,                        ScalarEvolution &SE, const TargetTransformInfo &TTI,                        AssumptionCache &AC, DependenceInfo &DI,                        OptimizationRemarkEmitter &ORE, int OptLevel) { -  // Quick checks of the correct loop form -  if (!L->isLoopSimplifyForm() || L->getSubLoops().size() != 1) -    return LoopUnrollResult::Unmodified; -  Loop *SubLoop = L->getSubLoops()[0]; -  if (!SubLoop->isLoopSimplifyForm()) -    return LoopUnrollResult::Unmodified; - -  BasicBlock *Latch = L->getLoopLatch(); -  BasicBlock *Exit = L->getExitingBlock(); -  BasicBlock *SubLoopLatch = SubLoop->getLoopLatch(); -  BasicBlock *SubLoopExit = SubLoop->getExitingBlock(); - -  if (Latch != Exit || SubLoopLatch != SubLoopExit) -    return LoopUnrollResult::Unmodified; -    TargetTransformInfo::UnrollingPreferences UP =        gatherUnrollingPreferences(L, SE, TTI, nullptr, nullptr, OptLevel, None, -                                 None, None, None, None, None, None, None); +                                 None, None, None, None, None); +  TargetTransformInfo::PeelingPreferences PP = +      gatherPeelingPreferences(L, SE, TTI, None, None);    if 
(AllowUnrollAndJam.getNumOccurrences() > 0)      UP.UnrollAndJam = AllowUnrollAndJam;    if (UnrollAndJamThreshold.getNumOccurrences() > 0) @@ -317,13 +307,13 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,    // the unroller, so long as it does not explicitly have unroll_and_jam    // metadata. This means #pragma nounroll will disable unroll and jam as well    // as unrolling -  if (HasAnyUnrollPragma(L, "llvm.loop.unroll.") && -      !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) { +  if (hasAnyUnrollPragma(L, "llvm.loop.unroll.") && +      !hasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) {      LLVM_DEBUG(dbgs() << "  Disabled due to pragma.\n");      return LoopUnrollResult::Unmodified;    } -  if (!isSafeToUnrollAndJam(L, SE, DT, DI)) { +  if (!isSafeToUnrollAndJam(L, SE, DT, DI, *LI)) {      LLVM_DEBUG(dbgs() << "  Disabled due to not being safe.\n");      return LoopUnrollResult::Unmodified;    } @@ -334,6 +324,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,    bool Convergent;    SmallPtrSet<const Value *, 32> EphValues;    CodeMetrics::collectEphemeralValues(L, &AC, EphValues); +  Loop *SubLoop = L->getSubLoops()[0];    unsigned InnerLoopSize =        ApproximateLoopSize(SubLoop, NumInlineCandidates, NotDuplicatable,                            Convergent, TTI, EphValues, UP.BEInsns); @@ -371,6 +362,8 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,      SubLoop->setLoopID(NewInnerEpilogueLoopID.getValue());    // Find trip count and trip multiple +  BasicBlock *Latch = L->getLoopLatch(); +  BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();    unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch);    unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch);    unsigned InnerTripCount = SE.getSmallConstantTripCount(SubLoop, SubLoopLatch); @@ -378,7 +371,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,    // Decide if, and by how much, to unroll    bool IsCountSetExplicitly = computeUnrollAndJamCount(        L, SubLoop, TTI, DT, LI, SE, EphValues, &ORE, OuterTripCount, -      OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP); +      OuterTripMultiple, OuterLoopSize, InnerTripCount, InnerLoopSize, UP, PP);    if (UP.Count <= 1)      return LoopUnrollResult::Unmodified;    // Unroll factor (Count) must be less or equal to TripCount. @@ -388,7 +381,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,    Loop *EpilogueOuterLoop = nullptr;    LoopUnrollResult UnrollResult = UnrollAndJamLoop(        L, UP.Count, OuterTripCount, OuterTripMultiple, UP.UnrollRemainder, LI, -      &SE, &DT, &AC, &ORE, &EpilogueOuterLoop); +      &SE, &DT, &AC, &TTI, &ORE, &EpilogueOuterLoop);    // Assign new loop attributes.    if (EpilogueOuterLoop) { @@ -435,22 +428,23 @@ static bool tryToUnrollAndJamLoop(Function &F, DominatorTree &DT, LoopInfo &LI,                                    int OptLevel) {    bool DidSomething = false; -  // The loop unroll and jam pass requires loops to be in simplified form, and also needs LCSSA. -  // Since simplification may add new inner loops, it has to run before the -  // legality and profitability checks. This means running the loop unroll and jam pass -  // will simplify all loops, regardless of whether anything end up being -  // unroll and jammed. +  // The loop unroll and jam pass requires loops to be in simplified form, and +  // also needs LCSSA. 
Since simplification may add new inner loops, it has to
+  // run before the legality and profitability checks. This means running the
+  // loop unroll and jam pass will simplify all loops, regardless of whether
+  // anything ends up being unroll and jammed.
   for (auto &L : LI) {
     DidSomething |=
         simplifyLoop(L, &DT, &LI, &SE, &AC, nullptr, false /* PreserveLCSSA */);
     DidSomething |= formLCSSARecursively(*L, DT, &LI, &SE);
   }
+  // Add the loop nests in the reverse order of LoopInfo. See method
+  // declaration.
   SmallPriorityWorklist<Loop *, 4> Worklist;
-  internal::appendLoopsToWorklist(reverse(LI), Worklist);
+  appendLoopsToWorklist(LI, Worklist);
   while (!Worklist.empty()) {
     Loop *L = Worklist.pop_back_val();
-    formLCSSA(*L, DT, &LI, &SE);
     LoopUnrollResult Result =
         tryToUnrollAndJamLoop(L, DT, &LI, SE, TTI, AC, DI, ORE, OptLevel);
     if (Result != LoopUnrollResult::Unmodified)
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
index 4c2b079c6bb5..87f40bb7ba85 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -154,6 +154,10 @@ static cl::opt<bool>
                        cl::desc("Allows loops to be peeled when the dynamic "
                                 "trip count is known to be low."));
+static cl::opt<bool> UnrollAllowLoopNestsPeeling(
+    "unroll-allow-loop-nests-peeling", cl::init(false), cl::Hidden,
+    cl::desc("Allows loop nests to be peeled."));
+
 static cl::opt<bool> UnrollUnrollRemainder(
   "unroll-remainder", cl::Hidden,
   cl::desc("Allow the loop remainder to be unrolled."));
@@ -167,6 +171,16 @@ static cl::opt<bool> UnrollRevisitChildLoops(
              "This shouldn't typically be needed as child loops (or their "
              "clones) were already visited."));
+static cl::opt<unsigned> UnrollThresholdAggressive(
+    "unroll-threshold-aggressive", cl::init(300), cl::Hidden,
+    cl::desc("Threshold (max size of unrolled loop) to use in aggressive (O3) "
+             "optimizations"));
+static cl::opt<unsigned>
+    UnrollThresholdDefault("unroll-threshold-default", cl::init(150),
+                           cl::Hidden,
+                           cl::desc("Default threshold (max size of unrolled "
+                                    "loop), used in all but O3 optimizations"));
+
 /// A magic value for use with the Threshold parameter to indicate
 /// that the loop unroll should be performed regardless of how much
 /// code expansion would result.
@@ -179,19 +193,17 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
     BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, int OptLevel,
     Optional<unsigned> UserThreshold, Optional<unsigned> UserCount,
     Optional<bool> UserAllowPartial, Optional<bool> UserRuntime,
-    Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling,
-    Optional<bool> UserAllowProfileBasedPeeling,
-    Optional<unsigned> UserFullUnrollMaxCount) {
+    Optional<bool> UserUpperBound, Optional<unsigned> UserFullUnrollMaxCount) {
   TargetTransformInfo::UnrollingPreferences UP;
   // Set up the defaults
-  UP.Threshold = OptLevel > 2 ? 300 : 150;
+  UP.Threshold =
+      OptLevel > 2 ? UnrollThresholdAggressive : UnrollThresholdDefault;
   UP.MaxPercentThresholdBoost = 400;
   UP.OptSizeThreshold = 0;
   UP.PartialThreshold = 150;
   UP.PartialOptSizeThreshold = 0;
   UP.Count = 0;
-  UP.PeelCount = 0;
   UP.DefaultUnrollRuntimeCount = 8;
   UP.MaxCount = std::numeric_limits<unsigned>::max();
   UP.FullUnrollMaxCount = std::numeric_limits<unsigned>::max();
@@ -203,10 +215,9 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
   UP.AllowExpensiveTripCount = false;
   UP.Force = false;
   UP.UpperBound = false;
-  UP.AllowPeeling = true;
   UP.UnrollAndJam = false;
-  UP.PeelProfiledIterations = true;
   UP.UnrollAndJamInnerLoopThreshold = 60;
+  UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze;
   // Override with any target specific settings
   TTI.getUnrollingPreferences(L, SE, UP);
@@ -232,8 +243,6 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
     UP.MaxCount = UnrollMaxCount;
   if (UnrollFullMaxCount.getNumOccurrences() > 0)
     UP.FullUnrollMaxCount = UnrollFullMaxCount;
-  if (UnrollPeelCount.getNumOccurrences() > 0)
-    UP.PeelCount = UnrollPeelCount;
   if (UnrollAllowPartial.getNumOccurrences() > 0)
     UP.Partial = UnrollAllowPartial;
   if (UnrollAllowRemainder.getNumOccurrences() > 0)
@@ -242,10 +251,10 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
     UP.Runtime = UnrollRuntime;
   if (UnrollMaxUpperBound == 0)
     UP.UpperBound = false;
-  if (UnrollAllowPeeling.getNumOccurrences() > 0)
-    UP.AllowPeeling = UnrollAllowPeeling;
   if (UnrollUnrollRemainder.getNumOccurrences() > 0)
     UP.UnrollRemainder = UnrollUnrollRemainder;
+  if (UnrollMaxIterationsCountToAnalyze.getNumOccurrences() > 0)
+    UP.MaxIterationsCountToAnalyze = UnrollMaxIterationsCountToAnalyze;
   // Apply user values provided by argument
   if (UserThreshold.hasValue()) {
@@ -260,16 +269,45 @@ TargetTransformInfo::UnrollingPreferences llvm::gatherUnrollingPreferences(
     UP.Runtime = *UserRuntime;
   if (UserUpperBound.hasValue())
     UP.UpperBound = *UserUpperBound;
-  if (UserAllowPeeling.hasValue())
-    UP.AllowPeeling = *UserAllowPeeling;
-  if (UserAllowProfileBasedPeeling.hasValue())
-    UP.PeelProfiledIterations = *UserAllowProfileBasedPeeling;
   if (UserFullUnrollMaxCount.hasValue())
     UP.FullUnrollMaxCount = *UserFullUnrollMaxCount;
   return UP;
 }
+TargetTransformInfo::PeelingPreferences
+llvm::gatherPeelingPreferences(Loop *L, ScalarEvolution &SE,
+                               const TargetTransformInfo &TTI,
+                               Optional<bool> UserAllowPeeling,
+                               Optional<bool> UserAllowProfileBasedPeeling) {
+  TargetTransformInfo::PeelingPreferences PP;
+
+  // Default values
+  PP.PeelCount = 0;
+  PP.AllowPeeling = true;
+  PP.AllowLoopNestsPeeling = false;
+  PP.PeelProfiledIterations = true;
+
+  // Get target-specific values
+  TTI.getPeelingPreferences(L, SE, PP);
+
+  // User-specified values using cl::opt
+  if (UnrollPeelCount.getNumOccurrences() > 0)
+    PP.PeelCount = UnrollPeelCount;
+  if (UnrollAllowPeeling.getNumOccurrences() > 0)
+    PP.AllowPeeling = UnrollAllowPeeling;
+  if (UnrollAllowLoopNestsPeeling.getNumOccurrences() > 0)
+    PP.AllowLoopNestsPeeling = UnrollAllowLoopNestsPeeling;
+
+  // User-specified values provided by argument
+  if (UserAllowPeeling.hasValue())
+    PP.AllowPeeling = *UserAllowPeeling;
+  if (UserAllowProfileBasedPeeling.hasValue())
+    
PP.PeelProfiledIterations = *UserAllowProfileBasedPeeling; + +  return PP; +} +  namespace {  /// A struct to densely store the state of an instruction after unrolling at @@ -335,11 +373,12 @@ struct EstimatedUnrollCost {  static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(      const Loop *L, unsigned TripCount, DominatorTree &DT, ScalarEvolution &SE,      const SmallPtrSetImpl<const Value *> &EphValues, -    const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize) { +    const TargetTransformInfo &TTI, unsigned MaxUnrolledLoopSize, +    unsigned MaxIterationsCountToAnalyze) {    // We want to be able to scale offsets by the trip count and add more offsets    // to them without checking for overflows, and we already don't want to    // analyze *massive* trip counts, so we force the max to be reasonably small. -  assert(UnrollMaxIterationsCountToAnalyze < +  assert(MaxIterationsCountToAnalyze <               (unsigned)(std::numeric_limits<int>::max() / 2) &&           "The unroll iterations max is too large!"); @@ -349,8 +388,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(      return None;    // Don't simulate loops with a big or unknown tripcount -  if (!UnrollMaxIterationsCountToAnalyze || !TripCount || -      TripCount > UnrollMaxIterationsCountToAnalyze) +  if (!TripCount || TripCount > MaxIterationsCountToAnalyze)      return None;    SmallSetVector<BasicBlock *, 16> BBWorklist; @@ -428,7 +466,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(          // First accumulate the cost of this instruction.          if (!Cost.IsFree) { -          UnrolledCost += TTI.getUserCost(I); +          UnrolledCost += TTI.getUserCost(I, TargetTransformInfo::TCK_CodeSize);            LLVM_DEBUG(dbgs() << "Adding cost of instruction (iteration "                              << Iteration << "): ");            LLVM_DEBUG(I->dump()); @@ -521,7 +559,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost(          // Track this instruction's expected baseline cost when executing the          // rolled loop form. -        RolledDynamicCost += TTI.getUserCost(&I); +        RolledDynamicCost += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);          // Visit the instruction to analyze its loop cost after unrolling,          // and if the visitor returns true, mark the instruction as free after @@ -665,32 +703,32 @@ unsigned llvm::ApproximateLoopSize(  // Returns the loop hint metadata node with the given name (for example,  // "llvm.loop.unroll.count").  If no such metadata node exists, then nullptr is  // returned. -static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) { +static MDNode *getUnrollMetadataForLoop(const Loop *L, StringRef Name) {    if (MDNode *LoopID = L->getLoopID())      return GetUnrollMetadata(LoopID, Name);    return nullptr;  }  // Returns true if the loop has an unroll(full) pragma. -static bool HasUnrollFullPragma(const Loop *L) { -  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.full"); +static bool hasUnrollFullPragma(const Loop *L) { +  return getUnrollMetadataForLoop(L, "llvm.loop.unroll.full");  }  // Returns true if the loop has an unroll(enable) pragma. This metadata is used  // for both "#pragma unroll" and "#pragma clang loop unroll(enable)" directives. 
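For readers following the pragma plumbing above and below: a loop's ID is a self-referential MDNode; operand 0 points back at the node itself and each later operand is a hint node such as !{!"llvm.loop.unroll.count", i32 4}. The lookups renamed in this file all reduce to a scan of that list, roughly (findLoopHint is a hypothetical name):

static MDNode *findLoopHint(const Loop *L, StringRef Name) {
  MDNode *LoopID = L->getLoopID();
  if (!LoopID)
    return nullptr;
  assert(LoopID->getOperand(0) == LoopID && "loop ID must be self-referential");
  for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) {
    auto *MD = dyn_cast<MDNode>(LoopID->getOperand(I));
    if (!MD || MD->getNumOperands() == 0)
      continue;
    auto *S = dyn_cast<MDString>(MD->getOperand(0));
    if (S && S->getString().equals(Name))
      return MD; // e.g. operand 1 holds the i32 count for *.count hints
  }
  return nullptr;
}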
-static bool HasUnrollEnablePragma(const Loop *L) { -  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.enable"); +static bool hasUnrollEnablePragma(const Loop *L) { +  return getUnrollMetadataForLoop(L, "llvm.loop.unroll.enable");  }  // Returns true if the loop has an runtime unroll(disable) pragma. -static bool HasRuntimeUnrollDisablePragma(const Loop *L) { -  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable"); +static bool hasRuntimeUnrollDisablePragma(const Loop *L) { +  return getUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable");  }  // If loop has an unroll_count pragma return the (necessarily  // positive) value from the pragma.  Otherwise return 0. -static unsigned UnrollCountPragmaValue(const Loop *L) { -  MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll.count"); +static unsigned unrollCountPragmaValue(const Loop *L) { +  MDNode *MD = getUnrollMetadataForLoop(L, "llvm.loop.unroll.count");    if (MD) {      assert(MD->getNumOperands() == 2 &&             "Unroll count hint metadata should have two operands."); @@ -740,7 +778,8 @@ bool llvm::computeUnrollCount(      ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues,      OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount,      bool MaxOrZero, unsigned &TripMultiple, unsigned LoopSize, -    TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { +    TargetTransformInfo::UnrollingPreferences &UP, +    TargetTransformInfo::PeelingPreferences &PP, bool &UseUpperBound) {    // Check for explicit Count.    // 1st priority is unroll count set by "unroll-count" option. @@ -754,7 +793,7 @@ bool llvm::computeUnrollCount(    }    // 2nd priority is unroll count set by pragma. -  unsigned PragmaCount = UnrollCountPragmaValue(L); +  unsigned PragmaCount = unrollCountPragmaValue(L);    if (PragmaCount > 0) {      UP.Count = PragmaCount;      UP.Runtime = true; @@ -764,14 +803,14 @@ bool llvm::computeUnrollCount(          getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold)        return true;    } -  bool PragmaFullUnroll = HasUnrollFullPragma(L); +  bool PragmaFullUnroll = hasUnrollFullPragma(L);    if (PragmaFullUnroll && TripCount != 0) {      UP.Count = TripCount;      if (getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold)        return false;    } -  bool PragmaEnableUnroll = HasUnrollEnablePragma(L); +  bool PragmaEnableUnroll = hasUnrollEnablePragma(L);    bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll ||                          PragmaEnableUnroll || UserUnrollCount; @@ -827,7 +866,8 @@ bool llvm::computeUnrollCount(        // To check that, run additional analysis on the loop.        if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost(                L, FullUnrollTripCount, DT, SE, EphValues, TTI, -              UP.Threshold * UP.MaxPercentThresholdBoost / 100)) { +              UP.Threshold * UP.MaxPercentThresholdBoost / 100, +              UP.MaxIterationsCountToAnalyze)) {          unsigned Boost =              getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost);          if (Cost->UnrolledCost < UP.Threshold * Boost / 100) { @@ -841,8 +881,8 @@ bool llvm::computeUnrollCount(    }    // 4th priority is loop peeling. 
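The computeUnrollCount signature change above reflects the larger refactor in this patch: the peeling knobs (PeelCount, AllowPeeling, AllowLoopNestsPeeling, PeelProfiledIterations) move out of UnrollingPreferences into the new TargetTransformInfo::PeelingPreferences, so call sites now gather both structs side by side, as the tryToUnrollLoop hunk below also shows:

TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
    L, SE, TTI, BFI, PSI, OptLevel, /*UserThreshold=*/None, /*UserCount=*/None,
    /*UserAllowPartial=*/None, /*UserRuntime=*/None, /*UserUpperBound=*/None,
    /*UserFullUnrollMaxCount=*/None);
TargetTransformInfo::PeelingPreferences PP = gatherPeelingPreferences(
    L, SE, TTI, /*UserAllowPeeling=*/None,
    /*UserAllowProfileBasedPeeling=*/None);
// Unroll decisions read UP; peel decisions read PP.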
-  computePeelCount(L, LoopSize, UP, TripCount, SE); -  if (UP.PeelCount) { +  computePeelCount(L, LoopSize, UP, PP, TripCount, SE); +  if (PP.PeelCount) {      UP.Runtime = false;      UP.Count = 1;      return ExplicitUnroll; @@ -925,7 +965,7 @@ bool llvm::computeUnrollCount(    // 6th priority is runtime unrolling.    // Don't unroll a runtime trip count loop when it is disabled. -  if (HasRuntimeUnrollDisablePragma(L)) { +  if (hasRuntimeUnrollDisablePragma(L)) {      UP.Count = 0;      return false;    } @@ -1045,8 +1085,9 @@ static LoopUnrollResult tryToUnrollLoop(    TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(        L, SE, TTI, BFI, PSI, OptLevel, ProvidedThreshold, ProvidedCount,        ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, -      ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling,        ProvidedFullUnrollMaxCount); +  TargetTransformInfo::PeelingPreferences PP = gatherPeelingPreferences( +      L, SE, TTI, ProvidedAllowPeeling, ProvidedAllowProfileBasedPeeling);    // Exit early if unrolling is disabled. For OptForSize, we pick the loop size    // as threshold later on. @@ -1120,7 +1161,7 @@ static LoopUnrollResult tryToUnrollLoop(    bool UseUpperBound = false;    bool IsCountSetExplicitly = computeUnrollCount(        L, TTI, DT, LI, SE, EphValues, &ORE, TripCount, MaxTripCount, MaxOrZero, -      TripMultiple, LoopSize, UP, UseUpperBound); +      TripMultiple, LoopSize, UP, PP, UseUpperBound);    if (!UP.Count)      return LoopUnrollResult::Unmodified;    // Unroll factor (Count) must be less or equal to TripCount. @@ -1135,9 +1176,9 @@ static LoopUnrollResult tryToUnrollLoop(    LoopUnrollResult UnrollResult = UnrollLoop(        L,        {UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, -       UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder, +       UseUpperBound, MaxOrZero, TripMultiple, PP.PeelCount, UP.UnrollRemainder,         ForgetAllSCEV}, -      LI, &SE, &DT, &AC, &ORE, PreserveLCSSA, &RemainderLoop); +      LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop);    if (UnrollResult == LoopUnrollResult::Unmodified)      return LoopUnrollResult::Unmodified; @@ -1167,7 +1208,7 @@ static LoopUnrollResult tryToUnrollLoop(    // If the loop was peeled, we already "used up" the profile information    // we had, so we don't want to unroll or peel again.    if (UnrollResult != LoopUnrollResult::FullyUnrolled && -      (IsCountSetExplicitly || (UP.PeelProfiledIterations && UP.PeelCount))) +      (IsCountSetExplicitly || (PP.PeelProfiledIterations && PP.PeelCount)))      L->setLoopAlreadyUnrolled();    return UnrollResult; @@ -1296,16 +1337,10 @@ Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced,  PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,                                            LoopStandardAnalysisResults &AR,                                            LPMUpdater &Updater) { -  const auto &FAM = -      AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); -  Function *F = L.getHeader()->getParent(); - -  auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); -  // FIXME: This should probably be optional rather than required. -  if (!ORE) -    report_fatal_error( -        "LoopFullUnrollPass: OptimizationRemarkEmitterAnalysis not " -        "cached at a higher level"); +  // For the new PM, we can't use OptimizationRemarkEmitter as an analysis +  // pass. 
Function analyses need to be preserved across loop transformations +  // but ORE cannot be preserved (see comment before the pass definition). +  OptimizationRemarkEmitter ORE(L.getHeader()->getParent());    // Keep track of the previous loop structure so we can identify new loops    // created by unrolling. @@ -1316,9 +1351,9 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,    else      OldLoops.insert(AR.LI.begin(), AR.LI.end()); -  std::string LoopName = L.getName(); +  std::string LoopName = std::string(L.getName()); -  bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, +  bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, ORE,                                   /*BFI*/ nullptr, /*PSI*/ nullptr,                                   /*PreserveLCSSA*/ true, OptLevel,                                   OnlyWhenForced, ForgetSCEV, /*Count*/ None, @@ -1384,30 +1419,6 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM,    return getLoopPassPreservedAnalyses();  } -template <typename RangeT> -static SmallVector<Loop *, 8> appendLoopsToWorklist(RangeT &&Loops) { -  SmallVector<Loop *, 8> Worklist; -  // We use an internal worklist to build up the preorder traversal without -  // recursion. -  SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist; - -  for (Loop *RootL : Loops) { -    assert(PreOrderLoops.empty() && "Must start with an empty preorder walk."); -    assert(PreOrderWorklist.empty() && -           "Must start with an empty preorder walk worklist."); -    PreOrderWorklist.push_back(RootL); -    do { -      Loop *L = PreOrderWorklist.pop_back_val(); -      PreOrderWorklist.append(L->begin(), L->end()); -      PreOrderLoops.push_back(L); -    } while (!PreOrderWorklist.empty()); - -    Worklist.append(PreOrderLoops.begin(), PreOrderLoops.end()); -    PreOrderLoops.clear(); -  } -  return Worklist; -} -  PreservedAnalyses LoopUnrollPass::run(Function &F,                                        FunctionAnalysisManager &AM) {    auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); @@ -1421,10 +1432,9 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,    if (auto *LAMProxy = AM.getCachedResult<LoopAnalysisManagerFunctionProxy>(F))      LAM = &LAMProxy->getManager(); -  const ModuleAnalysisManager &MAM = -      AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); +  auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);    ProfileSummaryInfo *PSI = -      MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); +      MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());    auto *BFI = (PSI && PSI->hasProfileSummary()) ?        &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; @@ -1441,7 +1451,10 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,      Changed |= formLCSSARecursively(*L, DT, &LI, &SE);    } -  SmallVector<Loop *, 8> Worklist = appendLoopsToWorklist(LI); +  // Add the loop nests in the reverse order of LoopInfo. See method +  // declaration. 
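appendLoopsToWorklist, declared in llvm/Transforms/Utils/LoopUtils.h, replaces the file-local template deleted above; together with SmallPriorityWorklist it deduplicates loops that get re-queued. The shape of the traversal, as also used below in LoopUnrollPass::run:

#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/Transforms/Utils/LoopUtils.h"

SmallPriorityWorklist<Loop *, 4> Worklist;
appendLoopsToWorklist(LI, Worklist); // nests in reverse order of LoopInfo;
                                     // see the declaration for details
while (!Worklist.empty()) {
  Loop *L = Worklist.pop_back_val();
  // ... attempt the transform on L; new loops may be pushed back on ...
}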
+  SmallPriorityWorklist<Loop *, 4> Worklist; +  appendLoopsToWorklist(LI, Worklist);    while (!Worklist.empty()) {      // Because the LoopInfo stores the loops in RPO, we walk the worklist @@ -1459,7 +1472,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,      Optional<bool> LocalAllowPeeling = UnrollOpts.AllowPeeling;      if (PSI && PSI->hasHugeWorkingSetSize())        LocalAllowPeeling = false; -    std::string LoopName = L.getName(); +    std::string LoopName = std::string(L.getName());      // The API here is quite complex to call and we allow to select some      // flavors of unrolling during construction time (by setting UnrollOpts).      LoopUnrollResult Result = tryToUnrollLoop( diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp index 915e053704b2..645a89bbd0ff 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -38,11 +38,11 @@  #include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/MemorySSA.h"  #include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/MustExecute.h"  #include "llvm/Analysis/ScalarEvolution.h"  #include "llvm/Analysis/TargetTransformInfo.h"  #include "llvm/IR/Attributes.h"  #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h"  #include "llvm/IR/Constant.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/DerivedTypes.h" @@ -158,7 +158,7 @@ namespace {      // Returns true if another unswitching could be done within the cost      // threshold. -    bool CostAllowsUnswitching(); +    bool costAllowsUnswitching();      // Clone all loop-unswitch related loop properties.      // Redistribute unswitching quotas. @@ -173,20 +173,20 @@ namespace {      AssumptionCache *AC;      // Used to check if second loop needs processing after -    // RewriteLoopBodyWithConditionConstant rewrites first loop. +    // rewriteLoopBodyWithConditionConstant rewrites first loop.      std::vector<Loop*> LoopProcessWorklist;      LUAnalysisCache BranchesInfo;      bool OptimizeForSize; -    bool redoLoop = false; +    bool RedoLoop = false; -    Loop *currentLoop = nullptr; +    Loop *CurrentLoop = nullptr;      DominatorTree *DT = nullptr;      MemorySSA *MSSA = nullptr;      std::unique_ptr<MemorySSAUpdater> MSSAU; -    BasicBlock *loopHeader = nullptr; -    BasicBlock *loopPreheader = nullptr; +    BasicBlock *LoopHeader = nullptr; +    BasicBlock *LoopPreheader = nullptr;      bool SanitizeMemory;      SimpleLoopSafetyInfo SafetyInfo; @@ -198,15 +198,15 @@ namespace {      // NewBlocks contained cloned copy of basic blocks from LoopBlocks.      
std::vector<BasicBlock*> NewBlocks; -    bool hasBranchDivergence; +    bool HasBranchDivergence;    public:      static char ID; // Pass ID, replacement for typeid -    explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) +    explicit LoopUnswitch(bool Os = false, bool HasBranchDivergence = false)          : LoopPass(ID), OptimizeForSize(Os), -          hasBranchDivergence(hasBranchDivergence) { -        initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); +          HasBranchDivergence(HasBranchDivergence) { +      initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());      }      bool runOnLoop(Loop *L, LPPassManager &LPM) override; @@ -223,48 +223,46 @@ namespace {          AU.addRequired<MemorySSAWrapperPass>();          AU.addPreserved<MemorySSAWrapperPass>();        } -      if (hasBranchDivergence) +      if (HasBranchDivergence)          AU.addRequired<LegacyDivergenceAnalysis>();        getLoopAnalysisUsage(AU);      }    private: -    void releaseMemory() override { -      BranchesInfo.forgetLoop(currentLoop); -    } +    void releaseMemory() override { BranchesInfo.forgetLoop(CurrentLoop); }      void initLoopData() { -      loopHeader = currentLoop->getHeader(); -      loopPreheader = currentLoop->getLoopPreheader(); +      LoopHeader = CurrentLoop->getHeader(); +      LoopPreheader = CurrentLoop->getLoopPreheader();      }      /// Split all of the edges from inside the loop to their exit blocks.      /// Update the appropriate Phi nodes as we do so. -    void SplitExitEdges(Loop *L, +    void splitExitEdges(Loop *L,                          const SmallVectorImpl<BasicBlock *> &ExitBlocks); -    bool TryTrivialLoopUnswitch(bool &Changed); +    bool tryTrivialLoopUnswitch(bool &Changed); -    bool UnswitchIfProfitable(Value *LoopCond, Constant *Val, +    bool unswitchIfProfitable(Value *LoopCond, Constant *Val,                                Instruction *TI = nullptr); -    void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, +    void unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,                                    BasicBlock *ExitBlock, Instruction *TI); -    void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L, +    void unswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L,                                       Instruction *TI); -    void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, -                                              Constant *Val, bool isEqual); +    void rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, +                                              Constant *Val, bool IsEqual); -    void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, +    void emitPreheaderBranchOnCondition(Value *LIC, Constant *Val,                                          BasicBlock *TrueDest,                                          BasicBlock *FalseDest,                                          BranchInst *OldBranch, Instruction *TI); -    void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); +    void simplifyCode(std::vector<Instruction *> &Worklist, Loop *L);      /// Given that the Invariant is not equal to Val. Simplify instructions      /// in the loop. 
-    Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, +    Value *simplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant,                                             Constant *Val);    }; @@ -347,7 +345,7 @@ bool LUAnalysisCache::isUnswitched(const SwitchInst *SI, const Value *V) {    return (*CurLoopInstructions)[SI].count(V);  } -bool LUAnalysisCache::CostAllowsUnswitching() { +bool LUAnalysisCache::costAllowsUnswitching() {    return CurrentLoopProperties->CanBeUnswitchedCount > 0;  } @@ -396,8 +394,8 @@ INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)  INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",                        false, false) -Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) { -  return new LoopUnswitch(Os, hasBranchDivergence); +Pass *llvm::createLoopUnswitchPass(bool Os, bool HasBranchDivergence) { +  return new LoopUnswitch(Os, HasBranchDivergence);  }  /// Operator chain lattice. @@ -411,15 +409,15 @@ enum OperatorChain {  /// Cond is a condition that occurs in L. If it is invariant in the loop, or has  /// an invariant piece, return the invariant. Otherwise, return null.  // -/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a -/// mixed operator chain, as we can not reliably find a value which will simplify -/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 -/// to simplify the chain. +/// NOTE: findLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will +/// simplify the operator chain. If the chain is AND-only or OR-only, we can use +/// 0 or ~0 to simplify the chain.  ///  /// NOTE: In case a partial LIV and a mixed operator chain, we may be able to  /// simplify the condition itself to a loop variant condition, but at the  /// cost of creating an entirely new loop. -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, +static Value *findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,                                     OperatorChain &ParentChain,                                     DenseMap<Value *, Value *> &Cache,                                     MemorySSAUpdater *MSSAU) { @@ -479,7 +477,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,          // If either the left or right side is invariant, we can unswitch on this,          // which will cause the branch to go away in one loop and the condition to          // simplify in the other one. -        if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, +        if (Value *LHS = findLIVLoopCondition(BO->getOperand(0), L, Changed,                                                ParentChain, Cache, MSSAU)) {            Cache[Cond] = LHS;            return LHS; @@ -487,7 +485,7 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,          // We did not manage to find a partial LIV in operand(0). Backtrack and try          // operand(1).          
ParentChain = NewChain; -        if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, +        if (Value *RHS = findLIVLoopCondition(BO->getOperand(1), L, Changed,                                                ParentChain, Cache, MSSAU)) {            Cache[Cond] = RHS;            return RHS; @@ -503,11 +501,11 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,  /// an invariant piece, return the invariant along with the operator chain type.  /// Otherwise, return null.  static std::pair<Value *, OperatorChain> -FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, +findLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,                       MemorySSAUpdater *MSSAU) {    DenseMap<Value *, Value *> Cache;    OperatorChain OpChain = OC_OpChainNone; -  Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU); +  Value *FCond = findLIVLoopCondition(Cond, L, Changed, OpChain, Cache, MSSAU);    // In case we do find a LIV, it can not be obtained by walking up a mixed    // operator chain. @@ -516,22 +514,22 @@ FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,    return {FCond, OpChain};  } -bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { +bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPMRef) {    if (skipLoop(L))      return false;    AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(        *L->getHeader()->getParent());    LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); -  LPM = &LPM_Ref; +  LPM = &LPMRef;    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();    if (EnableMSSALoopDependency) {      MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();      MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);      assert(DT && "Cannot update MemorySSA without a valid DomTree.");    } -  currentLoop = L; -  Function *F = currentLoop->getHeader()->getParent(); +  CurrentLoop = L; +  Function *F = CurrentLoop->getHeader()->getParent();    SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory);    if (SanitizeMemory) @@ -542,12 +540,12 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {    bool Changed = false;    do { -    assert(currentLoop->isLCSSAForm(*DT)); +    assert(CurrentLoop->isLCSSAForm(*DT));      if (MSSA && VerifyMemorySSA)        MSSA->verifyMemorySSA(); -    redoLoop = false; +    RedoLoop = false;      Changed |= processCurrentLoop(); -  } while(redoLoop); +  } while (RedoLoop);    if (MSSA && VerifyMemorySSA)      MSSA->verifyMemorySSA(); @@ -560,7 +558,7 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {  bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) {    auto *Node = DT->getNode(BB)->getIDom();    BasicBlock *DomBB = Node->getBlock(); -  while (currentLoop->contains(DomBB)) { +  while (CurrentLoop->contains(DomBB)) {      BranchInst *BInst = dyn_cast<BranchInst>(DomBB->getTerminator());      Node = DT->getNode(DomBB)->getIDom(); @@ -591,7 +589,7 @@ bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) {  /// causing problems. Detail could be found in PR31652. Note if the  /// func returns true, it is unsafe. But if it is false, it doesn't mean  /// it is necessarily safe. 
-static bool EqualityPropUnSafe(Value &LoopCond) { +static bool equalityPropUnSafe(Value &LoopCond) {    ICmpInst *CI = dyn_cast<ICmpInst>(&LoopCond);    if (!CI || !CI->isEquality())      return false; @@ -601,7 +599,7 @@ static bool EqualityPropUnSafe(Value &LoopCond) {    if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS))      return true; -  auto hasUndefInPHI = [](PHINode &PN) { +  auto HasUndefInPHI = [](PHINode &PN) {      for (Value *Opd : PN.incoming_values()) {        if (isa<UndefValue>(Opd))          return true; @@ -610,10 +608,10 @@ static bool EqualityPropUnSafe(Value &LoopCond) {    };    PHINode *LPHI = dyn_cast<PHINode>(LHS);    PHINode *RPHI = dyn_cast<PHINode>(RHS); -  if ((LPHI && hasUndefInPHI(*LPHI)) || (RPHI && hasUndefInPHI(*RPHI))) +  if ((LPHI && HasUndefInPHI(*LPHI)) || (RPHI && HasUndefInPHI(*RPHI)))      return true; -  auto hasUndefInSelect = [](SelectInst &SI) { +  auto HasUndefInSelect = [](SelectInst &SI) {      if (isa<UndefValue>(SI.getTrueValue()) ||          isa<UndefValue>(SI.getFalseValue()))        return true; @@ -621,7 +619,7 @@ static bool EqualityPropUnSafe(Value &LoopCond) {    };    SelectInst *LSI = dyn_cast<SelectInst>(LHS);    SelectInst *RSI = dyn_cast<SelectInst>(RHS); -  if ((LSI && hasUndefInSelect(*LSI)) || (RSI && hasUndefInSelect(*RSI))) +  if ((LSI && HasUndefInSelect(*LSI)) || (RSI && HasUndefInSelect(*RSI)))      return true;    return false;  } @@ -633,35 +631,36 @@ bool LoopUnswitch::processCurrentLoop() {    initLoopData();    // If LoopSimplify was unable to form a preheader, don't do any unswitching. -  if (!loopPreheader) +  if (!LoopPreheader)      return false;    // Loops with indirectbr cannot be cloned. -  if (!currentLoop->isSafeToClone()) +  if (!CurrentLoop->isSafeToClone())      return false;    // Without dedicated exits, splitting the exit edge may fail. -  if (!currentLoop->hasDedicatedExits()) +  if (!CurrentLoop->hasDedicatedExits())      return false; -  LLVMContext &Context = loopHeader->getContext(); +  LLVMContext &Context = LoopHeader->getContext();    // Analyze loop cost, and stop unswitching if loop content can not be duplicated.    if (!BranchesInfo.countLoop( -          currentLoop, getAnalysis<TargetTransformInfoWrapperPass>().getTTI( -                           *currentLoop->getHeader()->getParent()), +          CurrentLoop, +          getAnalysis<TargetTransformInfoWrapperPass>().getTTI( +              *CurrentLoop->getHeader()->getParent()),            AC))      return false;    // Try trivial unswitch first before loop over other basic blocks in the loop. -  if (TryTrivialLoopUnswitch(Changed)) { +  if (tryTrivialLoopUnswitch(Changed)) {      return true;    }    // Do not do non-trivial unswitch while optimizing for size.    // FIXME: Use Function::hasOptSize().    
if (OptimizeForSize || -      loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) +      LoopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize))      return false;    // Run through the instructions in the loop, keeping track of three things: @@ -680,11 +679,12 @@ bool LoopUnswitch::processCurrentLoop() {    SmallVector<IntrinsicInst *, 4> Guards; -  for (const auto BB : currentLoop->blocks()) { +  for (const auto BB : CurrentLoop->blocks()) {      for (auto &I : *BB) { -      auto CS = CallSite(&I); -      if (!CS) continue; -      if (CS.isConvergent()) +      auto *CB = dyn_cast<CallBase>(&I); +      if (!CB) +        continue; +      if (CB->isConvergent())          return false;        if (auto *II = dyn_cast<InvokeInst>(&I))          if (!II->getUnwindDest()->canSplitPredecessors()) @@ -696,11 +696,11 @@ bool LoopUnswitch::processCurrentLoop() {    }    for (IntrinsicInst *Guard : Guards) { -    Value *LoopCond = FindLIVLoopCondition(Guard->getOperand(0), currentLoop, +    Value *LoopCond = findLIVLoopCondition(Guard->getOperand(0), CurrentLoop,                                             Changed, MSSAU.get())                            .first;      if (LoopCond && -        UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { +        unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {        // NB! Unswitching (if successful) could have erased some of the        // instructions in Guards leaving dangling pointers there.  This is fine        // because we're returning now, and won't look at Guards again. @@ -712,8 +712,9 @@ bool LoopUnswitch::processCurrentLoop() {    // Loop over all of the basic blocks in the loop.  If we find an interior    // block that is branching on a loop-invariant condition, we can unswitch this    // loop. -  for (Loop::block_iterator I = currentLoop->block_begin(), -         E = currentLoop->block_end(); I != E; ++I) { +  for (Loop::block_iterator I = CurrentLoop->block_begin(), +                            E = CurrentLoop->block_end(); +       I != E; ++I) {      Instruction *TI = (*I)->getTerminator();      // Unswitching on a potentially uninitialized predicate is not @@ -723,7 +724,7 @@ bool LoopUnswitch::processCurrentLoop() {      // This is a workaround for the discrepancy between LLVM IR and MSan      // semantics. See PR28054 for more details.      if (SanitizeMemory && -        !SafetyInfo.isGuaranteedToExecute(*TI, DT, currentLoop)) +        !SafetyInfo.isGuaranteedToExecute(*TI, DT, CurrentLoop))        continue;      if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { @@ -738,11 +739,11 @@ bool LoopUnswitch::processCurrentLoop() {        if (BI->isConditional()) {          // See if this, or some part of it, is loop invariant.  If so, we can          // unswitch on it if we desire. 
-        Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, +        Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop,                                                 Changed, MSSAU.get())                                .first; -        if (LoopCond && !EqualityPropUnSafe(*LoopCond) && -            UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { +        if (LoopCond && !equalityPropUnSafe(*LoopCond) && +            unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {            ++NumBranches;            return true;          } @@ -752,7 +753,7 @@ bool LoopUnswitch::processCurrentLoop() {        Value *LoopCond;        OperatorChain OpChain;        std::tie(LoopCond, OpChain) = -          FindLIVLoopCondition(SC, currentLoop, Changed, MSSAU.get()); +          findLIVLoopCondition(SC, CurrentLoop, Changed, MSSAU.get());        unsigned NumCases = SI->getNumCases();        if (LoopCond && NumCases) { @@ -796,7 +797,7 @@ bool LoopUnswitch::processCurrentLoop() {          if (!UnswitchVal)            continue; -        if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { +        if (unswitchIfProfitable(LoopCond, UnswitchVal)) {            ++NumSwitches;            // In case of a full LIV, UnswitchVal is the value we unswitched out.            // In case of a partial LIV, we only unswitch when its an AND-chain @@ -812,11 +813,11 @@ bool LoopUnswitch::processCurrentLoop() {      for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end();           BBI != E; ++BBI)        if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { -        Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop, +        Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop,                                                 Changed, MSSAU.get())                                .first; -        if (LoopCond && UnswitchIfProfitable(LoopCond, -                                             ConstantInt::getTrue(Context))) { +        if (LoopCond && +            unswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {            ++NumSelects;            return true;          } @@ -875,62 +876,38 @@ static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) {    return nullptr;  } -/// We have found that we can unswitch currentLoop when LoopCond == Val to +/// We have found that we can unswitch CurrentLoop when LoopCond == Val to  /// simplify the loop.  If we decide that this is profitable,  /// unswitch the loop, reprocess the pieces, then return true. -bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, +bool LoopUnswitch::unswitchIfProfitable(Value *LoopCond, Constant *Val,                                          Instruction *TI) {    // Check to see if it would be profitable to unswitch current loop. -  if (!BranchesInfo.CostAllowsUnswitching()) { +  if (!BranchesInfo.costAllowsUnswitching()) {      LLVM_DEBUG(dbgs() << "NOT unswitching loop %" -                      << currentLoop->getHeader()->getName() +                      << CurrentLoop->getHeader()->getName()                        << " at non-trivial condition '" << *Val                        << "' == " << *LoopCond << "\n"                        << ". 
Cost too high.\n");      return false;    } -  if (hasBranchDivergence && +  if (HasBranchDivergence &&        getAnalysis<LegacyDivergenceAnalysis>().isDivergent(LoopCond)) {      LLVM_DEBUG(dbgs() << "NOT unswitching loop %" -                      << currentLoop->getHeader()->getName() +                      << CurrentLoop->getHeader()->getName()                        << " at non-trivial condition '" << *Val                        << "' == " << *LoopCond << "\n"                        << ". Condition is divergent.\n");      return false;    } -  UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI); +  unswitchNontrivialCondition(LoopCond, Val, CurrentLoop, TI);    return true;  } -/// Recursively clone the specified loop and all of its children, -/// mapping the blocks with the specified map. -static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, -                       LoopInfo *LI, LPPassManager *LPM) { -  Loop &New = *LI->AllocateLoop(); -  if (PL) -    PL->addChildLoop(&New); -  else -    LI->addTopLevelLoop(&New); -  LPM->addLoop(New); - -  // Add all of the blocks in L to the new loop. -  for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); -       I != E; ++I) -    if (LI->getLoopFor(*I) == L) -      New.addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), *LI); - -  // Add all of the subloops to the new loop. -  for (Loop *I : *L) -    CloneLoop(I, &New, VM, LI, LPM); - -  return &New; -} -  /// Emit a conditional branch on two values if LIC == Val, branch to TrueDst,  /// otherwise branch to FalseDest. Insert the code immediately before OldBranch  /// and remove (but not erase!) it from the function. -void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, +void LoopUnswitch::emitPreheaderBranchOnCondition(Value *LIC, Constant *Val,                                                    BasicBlock *TrueDest,                                                    BasicBlock *FalseDest,                                                    BranchInst *OldBranch, @@ -997,11 +974,11 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,  /// that doesn't execute its body has no side-effects), unswitch it. This  /// doesn't involve any code duplication, just moving the conditional branch  /// outside of the loop and updating loop info. -void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, +void LoopUnswitch::unswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,                                              BasicBlock *ExitBlock,                                              Instruction *TI) {    LLVM_DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" -                    << loopHeader->getName() << " [" << L->getBlocks().size() +                    << LoopHeader->getName() << " [" << L->getBlocks().size()                      << " blocks] in Function "                      << L->getHeader()->getParent()->getName()                      << " on cond: " << *Val << " == " << *Cond << "\n"); @@ -1011,9 +988,9 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,      SEWP->getSE().forgetTopmostLoop(L);    // First step, split the preheader, so that we know that there is a safe place -  // to insert the conditional branch.  We will change loopPreheader to have a +  // to insert the conditional branch.  We will change LoopPreheader to have a    // conditional branch on Cond. 
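Sketched control flow for this rewrite (names follow the code below; which arm exits depends on Val):

// Before:                      After:
//
//   LoopPreheader                LoopPreheader
//        |                       (Cond == Val)?
//        v                        /          \
//   LoopHeader               NewExit        NewPH
//        |                 (leaves loop)      |
//       ...                              LoopHeader -> ...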
-  BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get()); +  BasicBlock *NewPH = SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get());    // Now that we have a place to insert the conditional branch, create a place    // to branch to: this is the exit block out of the loop that we should @@ -1029,22 +1006,21 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,    // Okay, now we have a position to branch from and a position to branch to,    // insert the new conditional branch. -  auto *OldBranch = dyn_cast<BranchInst>(loopPreheader->getTerminator()); +  auto *OldBranch = dyn_cast<BranchInst>(LoopPreheader->getTerminator());    assert(OldBranch && "Failed to split the preheader"); -  EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); -  LPM->deleteSimpleAnalysisValue(OldBranch, L); +  emitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); -  // EmitPreheaderBranchOnCondition removed the OldBranch from the function. +  // emitPreheaderBranchOnCondition removed the OldBranch from the function.    // Delete it, as it is no longer needed.    delete OldBranch;    // We need to reprocess this loop, it could be unswitched again. -  redoLoop = true; +  RedoLoop = true;    // Now that we know that the loop is never entered when this condition is a    // particular value, rewrite the loop with this info.  We know that this will    // at least eliminate the old branch. -  RewriteLoopBodyWithConditionConstant(L, Cond, Val, false); +  rewriteLoopBodyWithConditionConstant(L, Cond, Val, /*IsEqual=*/false);    ++NumTrivial;  } @@ -1055,8 +1031,8 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,  /// produces no code duplications (equivalently, it produces a simpler loop and  /// a new empty loop, which gets deleted). Therefore always unswitch trivial  /// condition. -bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { -  BasicBlock *CurrentBB = currentLoop->getHeader(); +bool LoopUnswitch::tryTrivialLoopUnswitch(bool &Changed) { +  BasicBlock *CurrentBB = CurrentLoop->getHeader();    Instruction *CurrentTerm = CurrentBB->getTerminator();    LLVMContext &Context = CurrentBB->getContext(); @@ -1081,7 +1057,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      // we can not reach any trivial condition candidates (unfoldable      // branch instructions or switch instructions) and no unswitch      // can happen. Exit and return false. -    if (!currentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second) +    if (!CurrentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second)        return false;      // Check if this loop will execute any side-effecting instructions (e.g. @@ -1128,7 +1104,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      if (!BI->isConditional())        return false; -    Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, +    Value *LoopCond = findLIVLoopCondition(BI->getCondition(), CurrentLoop,                                             Changed, MSSAU.get())                            .first; @@ -1141,11 +1117,11 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      // exit through a unique exit block without having any      // side-effects.  If so, determine the value of Cond that causes      // it to do this. 
-    if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, -                                             BI->getSuccessor(0)))) { +    if ((LoopExitBB = +             isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(0)))) {        CondVal = ConstantInt::getTrue(Context); -    } else if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop, -                                                    BI->getSuccessor(1)))) { +    } else if ((LoopExitBB = +                    isTrivialLoopExitBlock(CurrentLoop, BI->getSuccessor(1)))) {        CondVal = ConstantInt::getFalse(Context);      } @@ -1154,16 +1130,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))        return false;   // Can't handle this. -    if (EqualityPropUnSafe(*LoopCond)) +    if (equalityPropUnSafe(*LoopCond))        return false; -    UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, +    unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB,                               CurrentTerm);      ++NumBranches;      return true;    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {      // If this isn't switching on an invariant condition, we can't unswitch it. -    Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), currentLoop, +    Value *LoopCond = findLIVLoopCondition(SI->getCondition(), CurrentLoop,                                             Changed, MSSAU.get())                            .first; @@ -1181,7 +1157,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      for (auto Case : SI->cases()) {        BasicBlock *LoopExitCandidate;        if ((LoopExitCandidate = -               isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) { +               isTrivialLoopExitBlock(CurrentLoop, Case.getCaseSuccessor()))) {          // Okay, we found a trivial case, remember the value that is trivial.          ConstantInt *CaseVal = Case.getCaseValue(); @@ -1200,7 +1176,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {      if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))        return false;   // Can't handle this. -    UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, +    unswitchTrivialCondition(CurrentLoop, LoopCond, CondVal, LoopExitBB,                               nullptr);      // We are only unswitching full LIV. @@ -1213,11 +1189,11 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) {  /// Split all of the edges from inside the loop to their exit blocks.  /// Update the appropriate Phi nodes as we do so. -void LoopUnswitch::SplitExitEdges(Loop *L, -                               const SmallVectorImpl<BasicBlock *> &ExitBlocks){ +void LoopUnswitch::splitExitEdges( +    Loop *L, const SmallVectorImpl<BasicBlock *> &ExitBlocks) { -  for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { -    BasicBlock *ExitBlock = ExitBlocks[i]; +  for (unsigned I = 0, E = ExitBlocks.size(); I != E; ++I) { +    BasicBlock *ExitBlock = ExitBlocks[I];      SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBlock),                                         pred_end(ExitBlock)); @@ -1231,11 +1207,11 @@ void LoopUnswitch::SplitExitEdges(Loop *L,  /// We determined that the loop is profitable to unswitch when LIC equal Val.  /// Split it into loop versions and test the condition outside of either loop.  /// Return the loops created as Out1/Out2. 
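Condensed, the versioning below clones every loop block under a value map, registers the clones, then remaps operands so the copies refer to each other. A standalone sketch of just the cloning and remapping steps; the helper name is illustrative and all analysis updates are omitted:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

// Clone each block, remember the old->new mapping, then rewrite operands in
// the clones so they reference cloned values instead of the originals.
static void cloneLoopBlocks(ArrayRef<BasicBlock *> LoopBlocks, Function *F,
                            SmallVectorImpl<BasicBlock *> &NewBlocks,
                            ValueToValueMapTy &VMap) {
  for (BasicBlock *BB : LoopBlocks) {
    BasicBlock *NewBB = CloneBasicBlock(BB, VMap, ".us", F);
    NewBlocks.push_back(NewBB);
    VMap[BB] = NewBB;
  }
  for (BasicBlock *NewBB : NewBlocks)
    for (Instruction &I : *NewBB)
      RemapInstruction(&I, VMap,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
}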
-void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, +void LoopUnswitch::unswitchNontrivialCondition(Value *LIC, Constant *Val,                                                 Loop *L, Instruction *TI) { -  Function *F = loopHeader->getParent(); +  Function *F = LoopHeader->getParent();    LLVM_DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" -                    << loopHeader->getName() << " [" << L->getBlocks().size() +                    << LoopHeader->getName() << " [" << L->getBlocks().size()                      << " blocks] in Function " << F->getName() << " when '"                      << *Val << "' == " << *LIC << "\n"); @@ -1253,7 +1229,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    // First step, split the preheader and exit blocks, and add these blocks to    // the LoopBlocks list.    BasicBlock *NewPreheader = -      SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get()); +      SplitEdge(LoopPreheader, LoopHeader, DT, LI, MSSAU.get());    LoopBlocks.push_back(NewPreheader);    // We want the loop to come after the preheader, but before the exit blocks. @@ -1264,7 +1240,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    // Split all of the edges from inside the loop to their exit blocks.  Update    // the appropriate Phi nodes as we do so. -  SplitExitEdges(L, ExitBlocks); +  splitExitEdges(L, ExitBlocks);    // The exit blocks may have been changed due to edge splitting, recompute.    ExitBlocks.clear(); @@ -1278,12 +1254,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    // the instructions and blocks.    NewBlocks.reserve(LoopBlocks.size());    ValueToValueMapTy VMap; -  for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) { -    BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[i], VMap, ".us", F); +  for (unsigned I = 0, E = LoopBlocks.size(); I != E; ++I) { +    BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[I], VMap, ".us", F);      NewBlocks.push_back(NewBB); -    VMap[LoopBlocks[i]] = NewBB;  // Keep the BB mapping. -    LPM->cloneBasicBlockSimpleAnalysis(LoopBlocks[i], NewBB, L); +    VMap[LoopBlocks[I]] = NewBB; // Keep the BB mapping.    }    // Splice the newly inserted blocks into the function right before the @@ -1293,7 +1268,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,                                  NewBlocks[0]->getIterator(), F->end());    // Now we create the new Loop object for the versioned loop. -  Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM); +  Loop *NewLoop = cloneLoop(L, L->getParentLoop(), VMap, LI, LPM);    // Recalculate unswitching quota, inherit simplified switches info for NewBB,    // Probably clone more loop-unswitch related loop properties. @@ -1306,10 +1281,10 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,      ParentLoop->addBasicBlockToLoop(NewBlocks[0], *LI);    } -  for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) { -    BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[i]]); +  for (unsigned EBI = 0, EBE = ExitBlocks.size(); EBI != EBE; ++EBI) { +    BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[EBI]]);      // The new exit block should be in the same loop as the old one. 
-    if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i])) +    if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[EBI]))        ExitBBLoop->addBasicBlockToLoop(NewExit, *LI);      assert(NewExit->getTerminator()->getNumSuccessors() == 1 && @@ -1319,7 +1294,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,      // If the successor of the exit block had PHI nodes, add an entry for      // NewExit.      for (PHINode &PN : ExitSucc->phis()) { -      Value *V = PN.getIncomingValueForBlock(ExitBlocks[i]); +      Value *V = PN.getIncomingValueForBlock(ExitBlocks[EBI]);        ValueToValueMapTy::iterator It = VMap.find(V);        if (It != VMap.end()) V = It->second;        PN.addIncoming(V, NewExit); @@ -1340,8 +1315,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    }    // Rewrite the code to refer to itself. -  for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) { -    for (Instruction &I : *NewBlocks[i]) { +  for (unsigned NBI = 0, NBE = NewBlocks.size(); NBI != NBE; ++NBI) { +    for (Instruction &I : *NewBlocks[NBI]) {        RemapInstruction(&I, VMap,                         RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);        if (auto *II = dyn_cast<IntrinsicInst>(&I)) @@ -1351,7 +1326,7 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    }    // Rewrite the original preheader to select between versions of the loop. -  BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator()); +  BranchInst *OldBR = cast<BranchInst>(LoopPreheader->getTerminator());    assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] &&           "Preheader splitting did not work correctly!"); @@ -1364,9 +1339,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    }    // Emit the new branch that selects between the two versions of this loop. -  EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR, +  emitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR,                                   TI); -  LPM->deleteSimpleAnalysisValue(OldBR, L);    if (MSSAU) {      // Update MemoryPhis in Exit blocks.      MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMap, *DT); @@ -1375,11 +1349,11 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    }    // The OldBr was replaced by a new one and removed (but not erased) by -  // EmitPreheaderBranchOnCondition. It is no longer needed, so delete it. +  // emitPreheaderBranchOnCondition. It is no longer needed, so delete it.    delete OldBR;    LoopProcessWorklist.push_back(NewLoop); -  redoLoop = true; +  RedoLoop = true;    // Keep a WeakTrackingVH holding onto LIC.  If the first call to    // RewriteLoopBody @@ -1390,22 +1364,23 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,    // Now we rewrite the original code to know that the condition is true and the    // new code to know that the condition is false. -  RewriteLoopBodyWithConditionConstant(L, LIC, Val, false); +  rewriteLoopBodyWithConditionConstant(L, LIC, Val, /*IsEqual=*/false);    // It's possible that simplifying one loop could cause the other to be    // changed to another value or a constant.  If its a constant, don't simplify    // it.    
if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop &&        LICHandle && !isa<Constant>(LICHandle)) -    RewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, true); +    rewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, +                                         /*IsEqual=*/true);    if (MSSA && VerifyMemorySSA)      MSSA->verifyMemorySSA();  }  /// Remove all instances of I from the worklist vector specified. -static void RemoveFromWorklist(Instruction *I, -                               std::vector<Instruction*> &Worklist) { +static void removeFromWorklist(Instruction *I, +                               std::vector<Instruction *> &Worklist) {    Worklist.erase(std::remove(Worklist.begin(), Worklist.end(), I),                   Worklist.end()); @@ -1413,7 +1388,7 @@ static void RemoveFromWorklist(Instruction *I,  /// When we find that I really equals V, remove I from the  /// program, replacing all uses with V and update the worklist. -static void ReplaceUsesOfWith(Instruction *I, Value *V, +static void replaceUsesOfWith(Instruction *I, Value *V,                                std::vector<Instruction *> &Worklist, Loop *L,                                LPPassManager *LPM, MemorySSAUpdater *MSSAU) {    LLVM_DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n"); @@ -1426,8 +1401,7 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V,    // Add users to the worklist which may be simplified now.    for (User *U : I->users())      Worklist.push_back(cast<Instruction>(U)); -  LPM->deleteSimpleAnalysisValue(I, L); -  RemoveFromWorklist(I, Worklist); +  removeFromWorklist(I, Worklist);    I->replaceAllUsesWith(V);    if (!I->mayHaveSideEffects()) {      if (MSSAU) @@ -1440,7 +1414,7 @@ static void ReplaceUsesOfWith(Instruction *I, Value *V,  /// We know either that the value LIC has the value specified by Val in the  /// specified loop, or we know it does NOT have that value.  /// Rewrite any uses of LIC or of properties correlated to it. -void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, +void LoopUnswitch::rewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,                                                          Constant *Val,                                                          bool IsEqual) {    assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); @@ -1478,7 +1452,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,      for (Instruction *UI : Worklist)        UI->replaceUsesOfWith(LIC, Replacement); -    SimplifyCode(Worklist, L); +    simplifyCode(Worklist, L);      return;    } @@ -1492,7 +1466,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,      // At this point, we know LIC is definitely not Val. Try to use some simple      // logic to simplify the user w.r.t. to the context. -    if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) { +    if (Value *Replacement = simplifyInstructionWithNotEqual(UI, LIC, Val)) {        if (LI->replacementPreservesLCSSAForm(UI, Replacement)) {          // This in-loop instruction has been simplified w.r.t. its context,          // i.e. 
LIC != Val, make sure we propagate its replacement value to @@ -1506,7 +1480,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,        }      } -    // This is a LIC user, push it into the worklist so that SimplifyCode can +    // This is a LIC user, push it into the worklist so that simplifyCode can      // attempt to simplify it.      Worklist.push_back(UI); @@ -1568,7 +1542,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,      DT->addNewBlock(Abort, NewSISucc);    } -  SimplifyCode(Worklist, L); +  simplifyCode(Worklist, L);  }  /// Now that we have simplified some instructions in the loop, walk over it and @@ -1579,7 +1553,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,  /// FIXME: When the loop optimizer is more mature, separate this out to a new  /// pass.  /// -void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { +void LoopUnswitch::simplifyCode(std::vector<Instruction *> &Worklist, Loop *L) {    const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();    while (!Worklist.empty()) {      Instruction *I = Worklist.back(); @@ -1593,8 +1567,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)          if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i)))            Worklist.push_back(Use); -      LPM->deleteSimpleAnalysisValue(I, L); -      RemoveFromWorklist(I, Worklist); +      removeFromWorklist(I, Worklist);        if (MSSAU)          MSSAU->removeMemoryAccess(I);        I->eraseFromParent(); @@ -1607,7 +1580,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {      // 'false'.  TODO: update the domtree properly so we can pass it here.      if (Value *V = SimplifyInstruction(I, DL))        if (LI->replacementPreservesLCSSAForm(I, V)) { -        ReplaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get()); +        replaceUsesOfWith(I, V, Worklist, L, LPM, MSSAU.get());          continue;        } @@ -1624,9 +1597,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {          assert(SinglePred == Pred && "CFG broken");          // Make the LPM and Worklist updates specific to LoopUnswitch. -        LPM->deleteSimpleAnalysisValue(BI, L); -        RemoveFromWorklist(BI, Worklist); -        LPM->deleteSimpleAnalysisValue(Succ, L); +        removeFromWorklist(BI, Worklist);          auto SuccIt = Succ->begin();          while (PHINode *PN = dyn_cast<PHINode>(SuccIt++)) {            for (unsigned It = 0, E = PN->getNumOperands(); It != E; ++It) @@ -1634,8 +1605,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {                Worklist.push_back(Use);            for (User *U : PN->users())              Worklist.push_back(cast<Instruction>(U)); -          LPM->deleteSimpleAnalysisValue(PN, L); -          RemoveFromWorklist(PN, Worklist); +          removeFromWorklist(PN, Worklist);            ++NumSimplify;          }          // Merge the block and make the remaining analyses updates. @@ -1652,7 +1622,7 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {  /// Simple simplifications we can do given the information that Cond is  /// definitely not equal to Val. 
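A standalone sketch of the fold implemented next, assuming the compare has the invariant as operand 0 and Val as operand 1 (the helper name is illustrative):

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Knowing Invariant != Val, an equality compare of the two is a constant.
static Value *foldGivenNotEqual(Instruction *Inst, Value *Invariant,
                                Constant *Val) {
  auto *Cmp = dyn_cast<ICmpInst>(Inst);
  if (!Cmp || Cmp->getOperand(0) != Invariant || Cmp->getOperand(1) != Val)
    return nullptr;
  if (Cmp->getPredicate() == CmpInst::ICMP_EQ)
    return ConstantInt::getFalse(Inst->getContext()); // eq is now false
  if (Cmp->getPredicate() == CmpInst::ICMP_NE)
    return ConstantInt::getTrue(Inst->getContext()); // ne is now true
  return nullptr;
}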
-Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst, +Value *LoopUnswitch::simplifyInstructionWithNotEqual(Instruction *Inst,                                                       Value *Invariant,                                                       Constant *Val) {    // icmp eq cond, val -> false diff --git a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 7b9af527d444..06b684ef1e70 100644 --- a/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -69,7 +69,6 @@  #include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/IR/CallSite.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/Dominators.h"  #include "llvm/IR/Instruction.h" diff --git a/llvm/lib/Transforms/Scalar/LowerAtomic.cpp b/llvm/lib/Transforms/Scalar/LowerAtomic.cpp index ab7b85e89e7b..d1f67b355b19 100644 --- a/llvm/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerAtomic.cpp @@ -117,18 +117,17 @@ static bool LowerStoreInst(StoreInst *SI) {  static bool runOnBasicBlock(BasicBlock &BB) {    bool Changed = false; -  for (BasicBlock::iterator DI = BB.begin(), DE = BB.end(); DI != DE;) { -    Instruction *Inst = &*DI++; -    if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) +  for (Instruction &Inst : make_early_inc_range(BB)) { +    if (FenceInst *FI = dyn_cast<FenceInst>(&Inst))        Changed |= LowerFenceInst(FI); -    else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(Inst)) +    else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(&Inst))        Changed |= LowerAtomicCmpXchgInst(CXI); -    else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst)) +    else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(&Inst))        Changed |= LowerAtomicRMWInst(RMWI); -    else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { +    else if (LoadInst *LI = dyn_cast<LoadInst>(&Inst)) {        if (LI->isAtomic())          LowerLoadInst(LI); -    } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { +    } else if (StoreInst *SI = dyn_cast<StoreInst>(&Inst)) {        if (SI->isAtomic())          LowerStoreInst(SI);      } diff --git a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp index 21c6c32e8e02..fddf28c281fc 100644 --- a/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerConstantIntrinsics.cpp @@ -13,7 +13,9 @@  #include "llvm/Transforms/Scalar/LowerConstantIntrinsics.h"  #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/InstructionSimplify.h"  #include "llvm/Analysis/MemoryBuiltins.h"  #include "llvm/Analysis/TargetLibraryInfo.h" @@ -135,8 +137,12 @@ static bool lowerConstantIntrinsics(Function &F, const TargetLibraryInfo *TLI) {  PreservedAnalyses  LowerConstantIntrinsicsPass::run(Function &F, FunctionAnalysisManager &AM) { -  if (lowerConstantIntrinsics(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) -    return PreservedAnalyses::none(); +  if (lowerConstantIntrinsics(F, +                              AM.getCachedResult<TargetLibraryAnalysis>(F))) { +    PreservedAnalyses PA; +    PA.preserve<GlobalsAA>(); +    return PA; +  }    return PreservedAnalyses::all();  } @@ -145,7 +151,7 @@ namespace {  /// Legacy pass for lowering 
is.constant intrinsics out of the IR.  ///  /// When this pass is run over a function it converts is.constant intrinsics -/// into 'true' or 'false'. This is completements the normal constand folding +/// into 'true' or 'false'. This complements the normal constant folding  /// to 'true' as part of Instruction Simplify passes.  class LowerConstantIntrinsics : public FunctionPass {  public: @@ -159,6 +165,10 @@ public:      const TargetLibraryInfo *TLI = TLIP ? &TLIP->getTLI(F) : nullptr;      return lowerConstantIntrinsics(F, TLI);    } + +  void getAnalysisUsage(AnalysisUsage &AU) const override { +    AU.addPreserved<GlobalsAAWrapperPass>(); +  }  };  } // namespace diff --git a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp index 53671c7bc3d1..0fe7dd9cfb39 100644 --- a/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp +++ b/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp @@ -55,13 +55,35 @@ static cl::opt<uint32_t> UnlikelyBranchWeight(      "unlikely-branch-weight", cl::Hidden, cl::init(1),      cl::desc("Weight of the branch unlikely to be taken (default = 1)")); +static std::tuple<uint32_t, uint32_t> +getBranchWeight(Intrinsic::ID IntrinsicID, CallInst *CI, int BranchCount) { +  if (IntrinsicID == Intrinsic::expect) { +    // __builtin_expect +    return std::make_tuple(LikelyBranchWeight.getValue(), +                           UnlikelyBranchWeight.getValue()); +  } else { +    // __builtin_expect_with_probability +    assert(CI->getNumOperands() >= 3 && +           "expect with probability must have 3 arguments"); +    ConstantFP *Confidence = dyn_cast<ConstantFP>(CI->getArgOperand(2)); +    double TrueProb = Confidence->getValueAPF().convertToDouble(); +    assert((TrueProb >= 0.0 && TrueProb <= 1.0) && +           "probability value must be in the range [0.0, 1.0]"); +    double FalseProb = (1.0 - TrueProb) / (BranchCount - 1); +    uint32_t LikelyBW = ceil((TrueProb * (double)(INT32_MAX - 1)) + 1.0); +    uint32_t UnlikelyBW = ceil((FalseProb * (double)(INT32_MAX - 1)) + 1.0); +    return std::make_tuple(LikelyBW, UnlikelyBW); +  } +} +  static bool handleSwitchExpect(SwitchInst &SI) {    CallInst *CI = dyn_cast<CallInst>(SI.getCondition());    if (!CI)      return false;    Function *Fn = CI->getCalledFunction(); -  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect) +  if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect && +              Fn->getIntrinsicID() != Intrinsic::expect_with_probability))      return false;    Value *ArgValue = CI->getArgOperand(0); @@ -71,15 +93,19 @@ static bool handleSwitchExpect(SwitchInst &SI) {    SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);    unsigned n = SI.getNumCases(); // +1 for default case. -  SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight); +  uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; +  std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = +      getBranchWeight(Fn->getIntrinsicID(), CI, n + 1); + +  SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeightVal);    uint64_t Index = (Case == *SI.case_default()) ? 
0 : Case.getCaseIndex() + 1; -  Weights[Index] = LikelyBranchWeight; +  Weights[Index] = LikelyBranchWeightVal; -  SI.setMetadata( -      LLVMContext::MD_misexpect, -      MDBuilder(CI->getContext()) -          .createMisExpect(Index, LikelyBranchWeight, UnlikelyBranchWeight)); +  SI.setMetadata(LLVMContext::MD_misexpect, +                 MDBuilder(CI->getContext()) +                     .createMisExpect(Index, LikelyBranchWeightVal, +                                      UnlikelyBranchWeightVal));    SI.setCondition(ArgValue);    misexpect::checkFrontendInstrumentation(SI); @@ -223,15 +249,18 @@ static void handlePhiDef(CallInst *Expect) {          return true;        return false;      }; +    uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; +    std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = getBranchWeight( +        Expect->getCalledFunction()->getIntrinsicID(), Expect, 2);      if (IsOpndComingFromSuccessor(BI->getSuccessor(1))) -      BI->setMetadata( -          LLVMContext::MD_prof, -          MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight)); +      BI->setMetadata(LLVMContext::MD_prof, +                      MDB.createBranchWeights(LikelyBranchWeightVal, +                                              UnlikelyBranchWeightVal));      else if (IsOpndComingFromSuccessor(BI->getSuccessor(0))) -      BI->setMetadata( -          LLVMContext::MD_prof, -          MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight)); +      BI->setMetadata(LLVMContext::MD_prof, +                      MDB.createBranchWeights(UnlikelyBranchWeightVal, +                                              LikelyBranchWeightVal));    }  } @@ -277,7 +306,8 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {    }    Function *Fn = CI->getCalledFunction(); -  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect) +  if (!Fn || (Fn->getIntrinsicID() != Intrinsic::expect && +              Fn->getIntrinsicID() != Intrinsic::expect_with_probability))      return false;    Value *ArgValue = CI->getArgOperand(0); @@ -289,13 +319,21 @@ template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {    MDNode *Node;    MDNode *ExpNode; +  uint32_t LikelyBranchWeightVal, UnlikelyBranchWeightVal; +  std::tie(LikelyBranchWeightVal, UnlikelyBranchWeightVal) = +      getBranchWeight(Fn->getIntrinsicID(), CI, 2); +    if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==        (Predicate == CmpInst::ICMP_EQ)) { -    Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight); -    ExpNode = MDB.createMisExpect(0, LikelyBranchWeight, UnlikelyBranchWeight); +    Node = +        MDB.createBranchWeights(LikelyBranchWeightVal, UnlikelyBranchWeightVal); +    ExpNode = +        MDB.createMisExpect(0, LikelyBranchWeightVal, UnlikelyBranchWeightVal);    } else { -    Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight); -    ExpNode = MDB.createMisExpect(1, LikelyBranchWeight, UnlikelyBranchWeight); +    Node = +        MDB.createBranchWeights(UnlikelyBranchWeightVal, LikelyBranchWeightVal); +    ExpNode = +        MDB.createMisExpect(1, LikelyBranchWeightVal, UnlikelyBranchWeightVal);    }    BSI.setMetadata(LLVMContext::MD_misexpect, ExpNode); @@ -347,7 +385,8 @@ static bool lowerExpectIntrinsic(Function &F) {        }        Function *Fn = CI->getCalledFunction(); -      if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) { +      if (Fn && (Fn->getIntrinsicID() == Intrinsic::expect || +                 
Fn->getIntrinsicID() == Intrinsic::expect_with_probability)) {          // Before erasing the llvm.expect, walk backward to find          // phi that define llvm.expect's first arg, and          // infer branch probability: diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp index 0ff6ee8bcfcc..90314b17b5e2 100644 --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -9,8 +9,11 @@  // Lower matrix intrinsics to vector operations.  //  // TODO: -//  * Implement multiply & add fusion -//  * Add remark, summarizing the available matrix optimization opportunities. +//  * Improve fusion: +//   * Support more cases, e.g. multiply-add, multiply-sub, operands/results +//     transposed. +//   * Improve cost-modeling, e.g. choose different number of rows/columns +//     columns for tiles, consider cost of copies on alias.  //  //===----------------------------------------------------------------------===// @@ -18,10 +21,15 @@  #include "llvm/ADT/GraphTraits.h"  #include "llvm/ADT/PostOrderIterator.h"  #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h"  #include "llvm/Analysis/VectorUtils.h"  #include "llvm/IR/CFG.h"  #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Instructions.h" @@ -29,30 +37,69 @@  #include "llvm/IR/PatternMatch.h"  #include "llvm/InitializePasses.h"  #include "llvm/Pass.h" +#include "llvm/Support/Alignment.h" +#include "llvm/Support/CommandLine.h"  #include "llvm/Support/Debug.h"  #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h"  using namespace llvm;  using namespace PatternMatch;  #define DEBUG_TYPE "lower-matrix-intrinsics" -static cl::opt<bool> EnableShapePropagation("matrix-propagate-shape", -                                            cl::init(true)); - +static cl::opt<bool> EnableShapePropagation( +    "matrix-propagate-shape", cl::init(true), cl::Hidden, +    cl::desc("Enable/disable shape propagation from matrix intrinsics to other " +             "instructions.")); + +static cl::opt<bool> +    FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden, +               cl::desc("Enable/disable fusing matrix instructions.")); +// TODO: Allow and use non-square tiles. +static cl::opt<unsigned> TileSize( +    "fuse-matrix-tile-size", cl::init(4), cl::Hidden, +    cl::desc( +        "Tile size for matrix instruction fusion using square-shaped tiles.")); +static cl::opt<bool> ForceFusion( +    "force-fuse-matrix", cl::init(false), cl::Hidden, +    cl::desc("Force matrix instruction fusion even if not profitable."));  static cl::opt<bool> AllowContractEnabled(      "matrix-allow-contract", cl::init(false), cl::Hidden,      cl::desc("Allow the use of FMAs if available and profitable. 
This may "               "produce different results, due to less rounding error.")); +enum class MatrixLayoutTy { ColumnMajor, RowMajor }; + +static cl::opt<MatrixLayoutTy> MatrixLayout( +    "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor), +    cl::desc("Sets the default matrix layout"), +    cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major", +                          "Use column-major layout"), +               clEnumValN(MatrixLayoutTy::RowMajor, "row-major", +                          "Use row-major layout"))); + +/// Helper function to either return Scope, if it is a subprogram, or the +/// attached subprogram for a local scope. +static DISubprogram *getSubprogram(DIScope *Scope) { +  if (auto *Subprogram = dyn_cast<DISubprogram>(Scope)) +    return Subprogram; +  return cast<DILocalScope>(Scope)->getSubprogram(); +} +  namespace { -// Given an element poitner \p BasePtr to the start of a (sub) matrix, compute -// the start address of column \p Col with type (\p EltType x \p NumRows) -// assuming \p Stride elements between start two consecutive columns. -// \p Stride must be >= \p NumRows. +// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute +// the start address of vector \p VecIdx with type (\p EltType x \p NumElements) +// assuming \p Stride elements between the starts of two consecutive vectors. +// \p Stride must be >= \p NumElements. +// For column-major matrixes, the function computes the address of a column +// vector and \p NumElements must be set to the number of elements in a column +// (= number of rows of the matrix). For row-major matrixes, the function +// computes the address of a row vector and \p NumElements must be set to the +// number of elements in a row (= number of columns of the matrix).  // -// Consider a 4x4 matrix like below +// Consider a 4x4 matrix in column-major layout like below  //  //      0       1      2      3  // 0   v_0_0  v_0_1  v_0_2  v_0_3 @@ -62,14 +109,14 @@ namespace {  // To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,  // we need a pointer to the first element of the submatrix as base pointer. -// Then we can use computeColumnAddr to compute the addresses for the columns +// Then we can use computeVectorAddr to compute the addresses for the columns  // of the sub-matrix.  // -// Column 0: computeColumnAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..) +// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)  //           -> just returns Base -// Column 1: computeColumnAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..) +// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)  //           -> returns Base + (1 * 4) -// Column 2: computeColumnAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..) +// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)  
//           -> returns Base + (2 * 4)  //  // The graphic below illustrates the number of elements in a column (marked @@ -82,30 +129,30 @@ namespace {  //         v_2_0 |v_2_1 |v_2_2 |v_2_3  //         v_3_0 {v_3_1 {v_3_2  v_3_3  // -Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride, -                         unsigned NumRows, Type *EltType, +Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride, +                         unsigned NumElements, Type *EltType,                           IRBuilder<> &Builder) {    assert((!isa<ConstantInt>(Stride) || -          cast<ConstantInt>(Stride)->getZExtValue() >= NumRows) && -         "Stride must be >= the number of rows."); +          cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) && +         "Stride must be >= the number of elements in the result vector.");    unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); -  // Compute the start of the column with index Col as Col * Stride. -  Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start"); +  // Compute the start of the vector with index VecIdx as VecIdx * Stride. +  Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start"); -  // Get pointer to the start of the selected column. Skip GEP creation, -  // if we select column 0. -  if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero()) -    ColumnStart = BasePtr; +  // Get pointer to the start of the selected vector. Skip GEP creation, +  // if we select vector 0. +  if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero()) +    VecStart = BasePtr;    else -    ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep"); +    VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep"); -  // Cast elementwise column start pointer to a pointer to a column -  // (EltType x NumRows)*. -  Type *ColumnType = VectorType::get(EltType, NumRows); -  Type *ColumnPtrType = PointerType::get(ColumnType, AS); -  return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast"); +  // Cast elementwise vector start pointer to a pointer to a vector +  // (EltType x NumElements)*. +  auto *VecType = FixedVectorType::get(EltType, NumElements); +  Type *VecPtrType = PointerType::get(VecType, AS); +  return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");  }  /// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics. @@ -113,15 +160,16 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,  /// Currently, the lowering for each matrix intrinsic is done as follows:  /// 1. Propagate the shape information from intrinsics to connected  /// instructions. -/// 2. Lower instructions with shape information. +/// 2. Lower instructions with shape information (assuming column-major layout). +///  The lowering works similarly using row-major layout.  ///  2.1. Get column vectors for each argument. If we already lowered the  ///       definition of an argument, use the produced column vectors directly.  ///       If not, split the operand vector containing an embedded matrix into  ///       a set of column vectors, -///  2.2. Lower the instruction in terms of columnwise operations, which yields -///       a set of column vectors containing result matrix. Note that we lower -///       all instructions that have shape information. Besides the intrinsics, -///       this includes stores for example. +///  2.2. 
Lower the instruction in terms of column-major operations, which +///       yields a set of column vectors containing the result matrix. Note that +///       we lower all instructions that have shape information. Besides the +///       intrinsics, this includes stores for example.  ///  2.3. Update uses of the lowered instruction. If we have shape information  ///       for a user, there is nothing to do, as we will look up the result  ///       column matrix when lowering the user. For other uses, we embed the @@ -134,42 +182,157 @@ class LowerMatrixIntrinsics {    Function &Func;    const DataLayout &DL;    const TargetTransformInfo &TTI; +  AliasAnalysis &AA; +  DominatorTree &DT; +  LoopInfo &LI; +  OptimizationRemarkEmitter &ORE; + +  /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. +  struct OpInfoTy { +    /// Number of stores emitted to generate this matrix. +    unsigned NumStores = 0; +    /// Number of loads emitted to generate this matrix. +    unsigned NumLoads = 0; +    /// Number of compute operations emitted to generate this matrix. +    unsigned NumComputeOps = 0; + +    OpInfoTy &operator+=(const OpInfoTy &RHS) { +      NumStores += RHS.NumStores; +      NumLoads += RHS.NumLoads; +      NumComputeOps += RHS.NumComputeOps; +      return *this; +    } +  }; + +  /// Wrapper class representing a matrix as a set of vectors, either in row- or +  /// column-major layout. All vectors must have the same vector type. +  class MatrixTy { +    SmallVector<Value *, 16> Vectors; + +    OpInfoTy OpInfo; -  /// Wrapper class representing a matrix as a set of column vectors. -  /// All column vectors must have the same vector type. -  class ColumnMatrixTy { -    SmallVector<Value *, 16> Columns; +    bool IsColumnMajor = true;    public: -    ColumnMatrixTy() : Columns() {} -    ColumnMatrixTy(ArrayRef<Value *> Cols) -        : Columns(Cols.begin(), Cols.end()) {} +    MatrixTy() +        : Vectors(), +          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} +    MatrixTy(ArrayRef<Value *> Vectors) +        : Vectors(Vectors.begin(), Vectors.end()), +          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {} +    MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy) +        : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) { + +      unsigned D = isColumnMajor() ? NumColumns : NumRows; +      for (unsigned J = 0; J < D; ++J) +        addVector(UndefValue::get(FixedVectorType::get( +            EltTy, isColumnMajor() ? 
NumRows : NumColumns))); +    } + +    Value *getVector(unsigned i) const { return Vectors[i]; } +    Value *getColumn(unsigned i) const { +      assert(isColumnMajor() && "only supported for column-major matrixes"); +      return Vectors[i]; +    } +    Value *getRow(unsigned i) const { +      assert(!isColumnMajor() && "only supported for row-major matrixes"); +      return Vectors[i]; +    } -    Value *getColumn(unsigned i) const { return Columns[i]; } +    void setVector(unsigned i, Value *V) { Vectors[i] = V; } -    void setColumn(unsigned i, Value *V) { Columns[i] = V; } +    Type *getElementType() { return getVectorTy()->getElementType(); } -    size_t getNumColumns() const { return Columns.size(); } -    size_t getNumRows() const { -      assert(Columns.size() > 0 && "Cannot call getNumRows without columns"); -      return cast<VectorType>(Columns[0]->getType())->getNumElements(); +    unsigned getNumVectors() const { +      if (isColumnMajor()) +        return getNumColumns(); +      return getNumRows(); +    } -    const SmallVectorImpl<Value *> &getColumnVectors() const { return Columns; } +    unsigned getNumColumns() const { +      if (isColumnMajor()) +        return Vectors.size(); +      else { +        assert(Vectors.size() > 0 && "Cannot call getNumColumns without columns"); +        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); +      } +    } +    unsigned getNumRows() const { +      if (isColumnMajor()) { +        assert(Vectors.size() > 0 && "Cannot call getNumRows without columns"); +        return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements(); +      } else +        return Vectors.size(); +    } -    SmallVectorImpl<Value *> &getColumnVectors() { return Columns; } +    void addVector(Value *V) { Vectors.push_back(V); } +    VectorType *getColumnTy() { +      assert(isColumnMajor() && "only supported for column-major matrixes"); +      return getVectorTy(); +    } -    void addColumn(Value *V) { Columns.push_back(V); } +    VectorType *getVectorTy() { +      return cast<VectorType>(Vectors[0]->getType()); +    }      iterator_range<SmallVector<Value *, 8>::iterator> columns() { -      return make_range(Columns.begin(), Columns.end()); +      assert(isColumnMajor() && +             "columns() only supported for column-major matrixes"); +      return make_range(Vectors.begin(), Vectors.end());      } -    /// Embed the columns of the matrix into a flat vector by concatenating +    iterator_range<SmallVector<Value *, 8>::iterator> vectors() { +      return make_range(Vectors.begin(), Vectors.end()); +    } + +    /// Embed the vectors of the matrix into a flat vector by concatenating      /// them.      Value *embedInVector(IRBuilder<> &Builder) const { -      return Columns.size() == 1 ? Columns[0] -                                 : concatenateVectors(Builder, Columns); +      return Vectors.size() == 1 ? 
Vectors[0] +                                 : concatenateVectors(Builder, Vectors); +    } + +    MatrixTy &addNumLoads(unsigned N) { +      OpInfo.NumLoads += N; +      return *this; +    } + +    void setNumLoads(unsigned N) { OpInfo.NumLoads = N; } + +    MatrixTy &addNumStores(unsigned N) { +      OpInfo.NumStores += N; +      return *this; +    } + +    MatrixTy &addNumComputeOps(unsigned N) { +      OpInfo.NumComputeOps += N; +      return *this; +    } + +    unsigned getNumStores() const { return OpInfo.NumStores; } +    unsigned getNumLoads() const { return OpInfo.NumLoads; } +    unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; } + +    const OpInfoTy &getOpInfo() const { return OpInfo; } + +    bool isColumnMajor() const { return IsColumnMajor; } + +    unsigned getStride() const { +      if (isColumnMajor()) +        return getNumRows(); +      return getNumColumns(); +    } + +    /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the +    /// matrix is column-major, the result vector is extracted from a column +    /// vector, otherwise from a row vector. +    Value *extractVector(unsigned I, unsigned J, unsigned NumElts, +                         IRBuilder<> &Builder) const { +      Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I); +      Value *Undef = UndefValue::get(Vec->getType()); +      return Builder.CreateShuffleVector( +          Vec, Undef, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0), +          "block");      }    }; @@ -177,12 +340,15 @@ class LowerMatrixIntrinsics {      unsigned NumRows;      unsigned NumColumns; +    bool IsColumnMajor; +      ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0) -        : NumRows(NumRows), NumColumns(NumColumns) {} +        : NumRows(NumRows), NumColumns(NumColumns), +          IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}      ShapeInfo(Value *NumRows, Value *NumColumns) -        : NumRows(cast<ConstantInt>(NumRows)->getZExtValue()), -          NumColumns(cast<ConstantInt>(NumColumns)->getZExtValue()) {} +        : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(), +                    cast<ConstantInt>(NumColumns)->getZExtValue()) {}      bool operator==(const ShapeInfo &other) {        return NumRows == other.NumRows && NumColumns == other.NumColumns; @@ -195,12 +361,24 @@ class LowerMatrixIntrinsics {        assert(NumRows == 0 || NumColumns != 0);        return NumRows != 0;      } + +    unsigned getStride() const { +      if (IsColumnMajor) +        return NumRows; +      return NumColumns; +    } + +    unsigned getNumVectors() const { +      if (IsColumnMajor) +        return NumColumns; +      return NumRows; +    }    };    /// Maps instructions to their shape information. The shape information    /// describes the shape to be used while lowering. This matches the shape of    /// the result value of the instruction, with the only exceptions being store -  /// instructions and the matrix_columnwise_store intrinsics. For those, the +  /// instructions and the matrix_column_major_store intrinsics. For those, the    /// shape information indicates that those instructions should be lowered    /// using shape information as well.    DenseMap<Value *, ShapeInfo> ShapeMap; @@ -211,31 +389,49 @@ class LowerMatrixIntrinsics {    SmallVector<Instruction *, 16> ToRemove;    /// Map from instructions to their produced column matrix. 
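The layout handling above reduces to one invariant: a flat value holding a NumRows x NumColumns matrix is always split into getNumVectors() chunks of getStride() elements each, columns in column-major layout and rows in row-major layout. A minimal standalone sketch of that splitting, using std::vector in place of IR values (the function and type names are illustrative, not part of the pass):

    #include <cassert>
    #include <vector>

    // Split a flat Rows x Cols matrix into its layout vectors: Cols column
    // vectors of Rows elements (column-major) or Rows row vectors of Cols
    // elements (row-major), mirroring ShapeInfo::getStride()/getNumVectors().
    static std::vector<std::vector<double>>
    splitMatrix(const std::vector<double> &Flat, unsigned Rows, unsigned Cols,
                bool ColumnMajor) {
      assert(Flat.size() == size_t(Rows) * Cols && "shape mismatch");
      unsigned Stride = ColumnMajor ? Rows : Cols;     // elements per vector
      unsigned NumVectors = ColumnMajor ? Cols : Rows; // number of vectors
      std::vector<std::vector<double>> Vectors;
      for (unsigned I = 0; I < NumVectors; ++I)
        Vectors.emplace_back(Flat.begin() + I * Stride,
                             Flat.begin() + (I + 1) * Stride);
      return Vectors;
    }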
-  DenseMap<Value *, ColumnMatrixTy> Inst2ColumnMatrix; +  MapVector<Value *, MatrixTy> Inst2ColumnMatrix;  public: -  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI) -      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI) {} +  LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI, +                        AliasAnalysis &AA, DominatorTree &DT, LoopInfo &LI, +                        OptimizationRemarkEmitter &ORE) +      : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT), +        LI(LI), ORE(ORE) {} + +  unsigned getNumOps(Type *VT) { +    assert(isa<VectorType>(VT) && "Expected vector type"); +    return getNumOps(VT->getScalarType(), +                     cast<FixedVectorType>(VT)->getNumElements()); +  } -  /// Return the set of column vectors that a matrix value is lowered to. +  /// Return the estimated number of vector ops required for an operation on +  /// \p ST with \p N elements. +  unsigned getNumOps(Type *ST, unsigned N) { +    return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() / +                     double(TTI.getRegisterBitWidth(true))); +  } + +  /// Return the set of vectors that a matrix value is lowered to.    /// -  /// If we lowered \p MatrixVal, just return the cache result column matrix. -  /// Otherwie split the flat vector \p MatrixVal containing a matrix with -  /// shape \p SI into column vectors. -  ColumnMatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, -                           IRBuilder<> Builder) { +  /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise +  /// split the flat vector \p MatrixVal containing a matrix with shape \p SI +  /// into vectors. +  MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI, +                     IRBuilder<> &Builder) {      VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());      assert(VType && "MatrixVal must be a vector type"); -    assert(VType->getNumElements() == SI.NumRows * SI.NumColumns && +    assert(cast<FixedVectorType>(VType)->getNumElements() == +               SI.NumRows * SI.NumColumns &&             "The vector size must match the number of matrix elements");      // Check if we lowered MatrixVal using shape information. In that case, -    // return the existing column matrix, if it matches the requested shape +    // return the existing matrix, if it matches the requested shape      // information. If there is a mis-match, embed the result in a flat      // vector and split it later.      auto Found = Inst2ColumnMatrix.find(MatrixVal);      if (Found != Inst2ColumnMatrix.end()) { -      ColumnMatrixTy &M = Found->second; +      MatrixTy &M = Found->second;        // Return the found matrix, if its shape matches the requested shape        // information        if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns()) @@ -247,10 +443,12 @@ public:      // Otherwise split MatrixVal.      
SmallVector<Value *, 16> SplitVecs;      Value *Undef = UndefValue::get(VType); -    for (unsigned MaskStart = 0; MaskStart < VType->getNumElements(); -         MaskStart += SI.NumRows) { -      Constant *Mask = createSequentialMask(Builder, MaskStart, SI.NumRows, 0); -      Value *V = Builder.CreateShuffleVector(MatrixVal, Undef, Mask, "split"); +    for (unsigned MaskStart = 0; +         MaskStart < cast<FixedVectorType>(VType)->getNumElements(); +         MaskStart += SI.getStride()) { +      Value *V = Builder.CreateShuffleVector( +          MatrixVal, Undef, createSequentialMask(MaskStart, SI.getStride(), 0), +          "split");        SplitVecs.push_back(V);      } @@ -308,8 +506,8 @@ public:        switch (II->getIntrinsicID()) {        case Intrinsic::matrix_multiply:        case Intrinsic::matrix_transpose: -      case Intrinsic::matrix_columnwise_load: -      case Intrinsic::matrix_columnwise_store: +      case Intrinsic::matrix_column_major_load: +      case Intrinsic::matrix_column_major_store:          return true;        default:          return false; @@ -348,13 +546,13 @@ public:                                   m_Value(MatrixA), m_Value(M), m_Value(N)))) {          // Flip dimensions.          Propagate = setShapeInfo(Inst, {N, M}); -      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_columnwise_store>( +      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(                                   m_Value(MatrixA), m_Value(), m_Value(), -                                 m_Value(M), m_Value(N)))) { +                                 m_Value(), m_Value(M), m_Value(N)))) {          Propagate = setShapeInfo(Inst, {N, M}); -      } else if (match(Inst, -                       m_Intrinsic<Intrinsic::matrix_columnwise_load>( -                           m_Value(), m_Value(), m_Value(M), m_Value(N)))) { +      } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>( +                                 m_Value(), m_Value(), m_Value(), m_Value(M), +                                 m_Value(N)))) {          Propagate = setShapeInfo(Inst, {M, N});        } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {          auto OpShape = ShapeMap.find(MatrixA); @@ -426,14 +624,14 @@ public:          // Flip dimensions.          if (setShapeInfo(MatrixA, {M, N}))            pushInstruction(MatrixA, WorkList); -      } else if (match(V, m_Intrinsic<Intrinsic::matrix_columnwise_store>( -                              m_Value(MatrixA), m_Value(), m_Value(), +      } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>( +                              m_Value(MatrixA), m_Value(), m_Value(), m_Value(),                                m_Value(M), m_Value(N)))) {          if (setShapeInfo(MatrixA, {M, N})) {            pushInstruction(MatrixA, WorkList);          }        } else if (isa<LoadInst>(V) || -                 match(V, m_Intrinsic<Intrinsic::matrix_columnwise_load>())) { +                 match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {          // Nothing to do, no matrix input.        } else if (isa<StoreInst>(V)) {          // Nothing to do.  
We forward-propagated to this so we would just @@ -472,8 +670,8 @@ public:            switch (II->getIntrinsicID()) {            case Intrinsic::matrix_multiply:            case Intrinsic::matrix_transpose: -          case Intrinsic::matrix_columnwise_load: -          case Intrinsic::matrix_columnwise_store: +          case Intrinsic::matrix_column_major_load: +          case Intrinsic::matrix_column_major_store:              WorkList.push_back(&Inst);              break;            default: @@ -487,45 +685,57 @@ public:        }      } -    ReversePostOrderTraversal<Function *> RPOT(&Func);      bool Changed = false; -    for (auto *BB : RPOT) { -      for (Instruction &Inst : make_early_inc_range(*BB)) { -        IRBuilder<> Builder(&Inst); - -        if (CallInst *CInst = dyn_cast<CallInst>(&Inst)) -          Changed |= VisitCallInst(CInst); - -        Value *Op1; -        Value *Op2; -        if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst)) -          Changed |= VisitBinaryOperator(BinOp); -        if (match(&Inst, m_Load(m_Value(Op1)))) -          Changed |= VisitLoad(&Inst, Op1, Builder); -        else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2)))) -          Changed |= VisitStore(&Inst, Op1, Op2, Builder); +    SmallVector<CallInst *, 16> MaybeFusableInsts; +    SmallVector<Instruction *, 16> MatrixInsts; + +    // First, collect all instructions with shape information and candidates for +    // fusion (currently only matrix multiplies). +    ReversePostOrderTraversal<Function *> RPOT(&Func); +    for (auto *BB : RPOT) +      for (Instruction &I : *BB) { +        if (ShapeMap.find(&I) == ShapeMap.end()) +          continue; +        if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>())) +          MaybeFusableInsts.push_back(cast<CallInst>(&I)); +        MatrixInsts.push_back(&I);        } + +    // Second, try to fuse candidates. +    SmallPtrSet<Instruction *, 16> FusedInsts; +    for (CallInst *CI : MaybeFusableInsts) +      LowerMatrixMultiplyFused(CI, FusedInsts); +    Changed = !FusedInsts.empty(); + +    // Third, lower remaining instructions with shape information. 
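The restructured driver above makes the ordering explicit: fusion has to see the original matrix multiplies before per-instruction lowering rewrites them. A schematic of the three phases, with hypothetical stand-in types and helpers (tryFuse and lowerWithShape are not the pass's real API):

    #include <unordered_set>
    #include <vector>

    struct Inst { bool IsMatMul = false; };

    // Hypothetical stand-ins for the fusion and lowering machinery.
    static bool tryFuse(Inst *, std::unordered_set<Inst *> &) { return false; }
    static void lowerWithShape(Inst *) {}

    // Mirrors the three phases: collect shaped instructions in reverse
    // post-order, try to fuse the multiplies, then lower what remains.
    static void visitAll(const std::vector<Inst *> &ShapedInRPO) {
      std::vector<Inst *> Candidates;
      for (Inst *I : ShapedInRPO)
        if (I->IsMatMul)
          Candidates.push_back(I);

      std::unordered_set<Inst *> Fused;
      for (Inst *MatMul : Candidates)
        tryFuse(MatMul, Fused);

      for (Inst *I : ShapedInRPO)
        if (!Fused.count(I))
          lowerWithShape(I);
    }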
+    for (Instruction *Inst : MatrixInsts) { +      if (FusedInsts.count(Inst)) +        continue; + +      IRBuilder<> Builder(Inst); + +      if (CallInst *CInst = dyn_cast<CallInst>(Inst)) +        Changed |= VisitCallInst(CInst); + +      Value *Op1; +      Value *Op2; +      if (auto *BinOp = dyn_cast<BinaryOperator>(Inst)) +        Changed |= VisitBinaryOperator(BinOp); +      if (match(Inst, m_Load(m_Value(Op1)))) +        Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder); +      else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2)))) +        Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);      } +    RemarkGenerator RemarkGen(Inst2ColumnMatrix, ORE, Func); +    RemarkGen.emitRemarks(); +      for (Instruction *Inst : reverse(ToRemove))        Inst->eraseFromParent();      return Changed;    } -  LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType, -                             IRBuilder<> Builder) { -    unsigned Align = DL.getABITypeAlignment(EltType); -    return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load"); -  } - -  StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr, -                               Type *EltType, IRBuilder<> Builder) { -    unsigned Align = DL.getABITypeAlignment(EltType); -    return Builder.CreateAlignedStore(ColumnValue, ColumnPtr, Align); -  } - -    /// Turns \p BasePtr into an elementwise pointer to \p EltType.    Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {      unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace(); @@ -545,11 +755,11 @@ public:      case Intrinsic::matrix_transpose:        LowerTranspose(Inst);        break; -    case Intrinsic::matrix_columnwise_load: -      LowerColumnwiseLoad(Inst); +    case Intrinsic::matrix_column_major_load: +      LowerColumnMajorLoad(Inst);        break; -    case Intrinsic::matrix_columnwise_store: -      LowerColumnwiseStore(Inst); +    case Intrinsic::matrix_column_major_store: +      LowerColumnMajorStore(Inst);        break;      default:        return false; @@ -557,108 +767,200 @@ public:      return true;    } -  void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride, -                 ShapeInfo Shape) { -    IRBuilder<> Builder(Inst); -    auto VType = cast<VectorType>(Inst->getType()); +  /// Compute the alignment for a column/row \p Idx with \p Stride between them. +  /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a +  /// ConstantInt, reduce the initial alignment based on the byte offset. For +  /// non-ConstantInt strides, return the common alignment of the initial +  /// alignment and the element size in bytes. +  Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy, +                         MaybeAlign A) const { +    Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy); +    if (Idx == 0) +      return InitialAlign; + +    TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy); +    if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) { +      uint64_t StrideInBytes = +          ConstStride->getZExtValue() * ElementSizeInBits / 8; +      return commonAlignment(InitialAlign, Idx * StrideInBytes); +    } +    return commonAlignment(InitialAlign, ElementSizeInBits / 8); +  } + +  /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between +  /// vectors. 
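getAlignForIndex above relies on a standard fact: if the base pointer has alignment A and vector Idx begins Idx * Stride * EltBytes bytes later, the alignment guaranteed at that address is the largest power of two dividing both quantities, which is what llvm::commonAlignment computes. A scalar sketch of the same arithmetic (plain integers instead of llvm::Align, assuming A is a power of two):

    #include <cstdint>

    // Largest power-of-two alignment dividing both A and Offset; with
    // Offset == 0 this is A itself (A is assumed to be a power of two).
    static uint64_t commonAlign(uint64_t A, uint64_t Offset) {
      uint64_t V = A | Offset;
      return V & (~V + 1); // lowest set bit bounds the common divisor
    }

    // Alignment of vector Idx when vectors sit Stride elements apart and
    // each element is EltBytes wide; the base pointer has alignment A.
    static uint64_t alignForIndex(unsigned Idx, uint64_t Stride,
                                  uint64_t EltBytes, uint64_t A) {
      return Idx == 0 ? A : commonAlign(A, Idx * Stride * EltBytes);
    }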
+  MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride, +                      bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) { +    auto VType = cast<VectorType>(Ty);      Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); -    ColumnMatrixTy Result; -    // Distance between start of one column and the start of the next -    for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) { -      Value *GEP = -          computeColumnAddr(EltPtr, Builder.getInt32(C), Stride, Shape.NumRows, -                            VType->getElementType(), Builder); -      Value *Column = createColumnLoad(GEP, VType->getElementType(), Builder); -      Result.addColumn(Column); +    MatrixTy Result; +    for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) { +      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride, +                                     Shape.getStride(), VType->getElementType(), +                                     Builder); +      Value *Vector = Builder.CreateAlignedLoad( +          GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign), +          IsVolatile, "col.load"); + +      Result.addVector(Vector);      } +    return Result.addNumLoads(getNumOps(Result.getVectorTy()) * +                              Result.getNumVectors()); +  } -    finalizeLowering(Inst, Result, Builder); +  /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix, +  /// starting at \p MatrixPtr[I][J]. +  MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile, +                      ShapeInfo MatrixShape, Value *I, Value *J, +                      ShapeInfo ResultShape, Type *EltTy, +                      IRBuilder<> &Builder) { + +    Value *Offset = Builder.CreateAdd( +        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + +    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); +    Value *EltPtr = +        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); +    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); +    auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows * +                                                   ResultShape.NumColumns); +    Type *TilePtrTy = PointerType::get(TileTy, AS); +    Value *TilePtr = +        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); + +    return loadMatrix(TileTy, TilePtr, Align, +                      Builder.getInt64(MatrixShape.getStride()), IsVolatile, +                      ResultShape, Builder); +  } + +  /// Lower a load instruction with shape information. +  void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride, +                 bool IsVolatile, ShapeInfo Shape) { +    IRBuilder<> Builder(Inst); +    finalizeLowering(Inst, +                     loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile, +                                Shape, Builder), +                     Builder);    } -  /// Lowers llvm.matrix.columnwise.load. +  /// Lowers llvm.matrix.column.major.load.    ///    /// The intrinsic loads a matrix from memory using a stride between columns. 
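Both loadMatrix overloads above reduce to two element-index formulas: vector I of a matrix starts at Base + I * Stride, and a sub-matrix anchored at row I, column J of a column-major matrix starts at Base + J * Stride + I, with Stride the distance between column starts. A small sketch under those assumptions, with illustrative names:

    #include <cstddef>

    // Start index of vector VecIdx (a column in column-major layout);
    // this is the scalar computation behind computeVectorAddr.
    static size_t vectorStart(size_t VecIdx, size_t Stride) {
      return VecIdx * Stride;
    }

    // Start index of a tile anchored at (Row, Col) in a column-major
    // matrix whose columns are Stride elements apart, as computed by the
    // sub-matrix loadMatrix/storeMatrix overloads.
    static size_t tileStart(size_t Row, size_t Col, size_t Stride) {
      return Col * Stride + Row;
    }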
-  void LowerColumnwiseLoad(CallInst *Inst) { +  void LowerColumnMajorLoad(CallInst *Inst) { +    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && +           "Intrinsic only supports column-major layout!");      Value *Ptr = Inst->getArgOperand(0);      Value *Stride = Inst->getArgOperand(1); -    LowerLoad(Inst, Ptr, Stride, -              {Inst->getArgOperand(2), Inst->getArgOperand(3)}); +    LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride, +              cast<ConstantInt>(Inst->getArgOperand(2))->isOne(), +              {Inst->getArgOperand(3), Inst->getArgOperand(4)});    } -  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride, -                  ShapeInfo Shape) { -    IRBuilder<> Builder(Inst); -    auto VType = cast<VectorType>(Matrix->getType()); +  /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p +  /// MatrixPtr[I][J]. +  void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr, +                   MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape, +                   Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) { +    Value *Offset = Builder.CreateAdd( +        Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I); + +    unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace(); +    Value *EltPtr = +        Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS)); +    Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset); +    auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() * +                                                   StoreVal.getNumColumns()); +    Type *TilePtrTy = PointerType::get(TileTy, AS); +    Value *TilePtr = +        Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast"); + +    storeMatrix(TileTy, StoreVal, TilePtr, MAlign, +                Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder); +  } + +  /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between +  /// vectors. +  MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr, +                       MaybeAlign MAlign, Value *Stride, bool IsVolatile, +                       IRBuilder<> &Builder) { +    auto VType = cast<VectorType>(Ty);      Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder); -    auto LM = getMatrix(Matrix, Shape, Builder); -    for (auto C : enumerate(LM.columns())) { -      Value *GEP = -          computeColumnAddr(EltPtr, Builder.getInt32(C.index()), Stride, -                            Shape.NumRows, VType->getElementType(), Builder); -      createColumnStore(C.value(), GEP, VType->getElementType(), Builder); +    for (auto Vec : enumerate(StoreVal.vectors())) { +      Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()), +                                     Stride, StoreVal.getStride(), +                                     VType->getElementType(), Builder); +      Builder.CreateAlignedStore(Vec.value(), GEP, +                                 getAlignForIndex(Vec.index(), Stride, +                                                  VType->getElementType(), +                                                  MAlign), +                                 IsVolatile);      } +    return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) * +                                   StoreVal.getNumVectors()); +  } -    ToRemove.push_back(Inst); +  /// Lower a store instruction with shape information. 
+  void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A, +                  Value *Stride, bool IsVolatile, ShapeInfo Shape) { +    IRBuilder<> Builder(Inst); +    auto StoreVal = getMatrix(Matrix, Shape, Builder); +    finalizeLowering(Inst, +                     storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride, +                                 IsVolatile, Builder), +                     Builder);    } -  /// Lowers llvm.matrix.columnwise.store. +  /// Lowers llvm.matrix.column.major.store.    ///    /// The intrinsic stores a matrix back to memory using a stride between columns. -  void LowerColumnwiseStore(CallInst *Inst) { +  void LowerColumnMajorStore(CallInst *Inst) { +    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && +           "Intrinsic only supports column-major layout!");      Value *Matrix = Inst->getArgOperand(0);      Value *Ptr = Inst->getArgOperand(1);      Value *Stride = Inst->getArgOperand(2); -    LowerStore(Inst, Matrix, Ptr, Stride, -               {Inst->getArgOperand(3), Inst->getArgOperand(4)}); -  } - -  /// Extract a column vector of \p NumElts starting at index (\p I, \p J) from -  /// the matrix \p LM represented as a vector of column vectors. -  Value *extractVector(const ColumnMatrixTy &LM, unsigned I, unsigned J, -                       unsigned NumElts, IRBuilder<> Builder) { -    Value *Col = LM.getColumn(J); -    Value *Undef = UndefValue::get(Col->getType()); -    Constant *Mask = createSequentialMask(Builder, I, NumElts, 0); -    return Builder.CreateShuffleVector(Col, Undef, Mask, "block"); +    LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride, +               cast<ConstantInt>(Inst->getArgOperand(3))->isOne(), +               {Inst->getArgOperand(4), Inst->getArgOperand(5)});    }    // Set elements I..I+NumElts-1 to Block    Value *insertVector(Value *Col, unsigned I, Value *Block, -                      IRBuilder<> Builder) { +                      IRBuilder<> &Builder) {      // First, bring Block to the same size as Col      unsigned BlockNumElts = -        cast<VectorType>(Block->getType())->getNumElements(); -    unsigned NumElts = cast<VectorType>(Col->getType())->getNumElements(); +        cast<FixedVectorType>(Block->getType())->getNumElements(); +    unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();      assert(NumElts >= BlockNumElts && "Too few elements for current block"); -    Value *ExtendMask = -        createSequentialMask(Builder, 0, BlockNumElts, NumElts - BlockNumElts);      Value *Undef = UndefValue::get(Block->getType()); -    Block = Builder.CreateShuffleVector(Block, Undef, ExtendMask); +    Block = Builder.CreateShuffleVector( +        Block, Undef, +        createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));      // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,      // 8, 4, 5, 6 -    SmallVector<Constant *, 16> Mask; +    SmallVector<int, 16> Mask;      unsigned i;      for (i = 0; i < I; i++) -      Mask.push_back(Builder.getInt32(i)); +      Mask.push_back(i); -    unsigned VecNumElts = cast<VectorType>(Col->getType())->getNumElements(); +    unsigned VecNumElts = +        cast<FixedVectorType>(Col->getType())->getNumElements();      for (; i < I + BlockNumElts; i++) -      Mask.push_back(Builder.getInt32(i - I + VecNumElts)); +      Mask.push_back(i - I + VecNumElts);      for (; i < VecNumElts; i++) -      Mask.push_back(Builder.getInt32(i)); - -    Value *MaskVal = ConstantVector::get(Mask); +      
Mask.push_back(i); -    return Builder.CreateShuffleVector(Col, Block, MaskVal); +    return Builder.CreateShuffleVector(Col, Block, Mask);    }    Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, -                      IRBuilder<> &Builder, bool AllowContraction) { - +                      IRBuilder<> &Builder, bool AllowContraction, +                      unsigned &NumComputeOps) { +    NumComputeOps += getNumOps(A->getType());      if (!Sum)        return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); @@ -666,14 +968,16 @@ public:        if (AllowContraction) {          // Use fmuladd for floating point operations and let the backend decide          // if that's profitable. -        Value *FMulAdd = Intrinsic::getDeclaration( +        Function *FMulAdd = Intrinsic::getDeclaration(              Func.getParent(), Intrinsic::fmuladd, A->getType());          return Builder.CreateCall(FMulAdd, {A, B, Sum});        } +      NumComputeOps += getNumOps(A->getType());        Value *Mul = Builder.CreateFMul(A, B);        return Builder.CreateFAdd(Sum, Mul);      } +    NumComputeOps += getNumOps(A->getType());      Value *Mul = Builder.CreateMul(A, B);      return Builder.CreateAdd(Sum, Mul);    } @@ -683,7 +987,7 @@ public:    /// cached value when they are lowered. For other users, \p Matrix is    /// flattened and the uses are updated to use it. Also marks \p Inst for    /// deletion. -  void finalizeLowering(Instruction *Inst, ColumnMatrixTy Matrix, +  void finalizeLowering(Instruction *Inst, MatrixTy Matrix,                          IRBuilder<> &Builder) {      Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix)); @@ -699,6 +1003,294 @@ public:      }    } +  /// Compute \p Result += \p A * \p B for input matrices with left-associating +  /// addition. +  void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, +                          const MatrixTy &B, bool AllowContraction, +                          IRBuilder<> &Builder, bool isTiled) { +    const unsigned VF = std::max<unsigned>( +        TTI.getRegisterBitWidth(true) / +            Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(), +        1U); +    unsigned R = Result.getNumRows(); +    unsigned C = Result.getNumColumns(); +    unsigned M = A.getNumColumns(); + +    bool IsFP = Result.getElementType()->isFloatingPointTy(); +    assert(A.isColumnMajor() == B.isColumnMajor() && +           Result.isColumnMajor() == A.isColumnMajor() && +           "operands must agree on matrix layout"); +    unsigned NumComputeOps = 0; +    if (A.isColumnMajor()) { +      // Multiply columns from the first operand with scalars from the second +      // operand. Then move along the K axes and accumulate the columns.  With +      // this the adds can be vectorized without reassociation. +      for (unsigned J = 0; J < C; ++J) { +        unsigned BlockSize = VF; +        // If Result is zero, we don't need to accumulate in the K==0 iteration. +        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J)); + +        for (unsigned I = 0; I < R; I += BlockSize) { +          // Gradually lower the vectorization factor to cover the remainder. +          while (I + BlockSize > R) +            BlockSize /= 2; + +          Value *Sum = isTiled ? 
Result.extractVector(I, J, BlockSize, Builder) +                               : nullptr; +          for (unsigned K = 0; K < M; ++K) { +            Value *L = A.extractVector(I, K, BlockSize, Builder); +            Value *RH = Builder.CreateExtractElement(B.getColumn(J), K); +            Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); +            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, +                               Result.getElementType()->isFloatingPointTy(), +                               Builder, AllowContraction, NumComputeOps); +          } +          Result.setVector(J, +                           insertVector(Result.getVector(J), I, Sum, Builder)); +        } +      } +    } else { +      // Multiply rows from the second operand with scalars from the first +      // operand. Then move along the K axes and accumulate the rows.  With this +      // the adds can be vectorized without reassociation. +      for (unsigned I = 0; I < R; ++I) { +        unsigned BlockSize = VF; +        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I)); +        for (unsigned J = 0; J < C; J += BlockSize) { +          // Gradually lower the vectorization factor to cover the remainder. +          while (J + BlockSize > C) +            BlockSize /= 2; + +          Value *Sum = nullptr; +          for (unsigned K = 0; K < M; ++K) { +            Value *R = B.extractVector(K, J, BlockSize, Builder); +            Value *LH = Builder.CreateExtractElement(A.getVector(I), K); +            Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); +            Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R, +                               IsFP, Builder, AllowContraction, NumComputeOps); +          } +          Result.setVector(I, +                           insertVector(Result.getVector(I), J, Sum, Builder)); +        } +      } +    } +    Result.addNumComputeOps(NumComputeOps); +  } + +  /// Ensure that the memory in \p Load does not alias \p Store by potentially +  /// copying it to a new location. Returns the new location, or the original +  /// one if no copy was needed. +  Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store, +                               CallInst *MatMul) { +    MemoryLocation StoreLoc = MemoryLocation::get(Store); +    MemoryLocation LoadLoc = MemoryLocation::get(Load); + +    AliasResult LdAliased = AA.alias(LoadLoc, StoreLoc); + +    // If we can statically determine noalias we're good. +    if (!LdAliased) +      return Load->getPointerOperand(); + +    // Create code to check if the memory locations of the Load and Store +    // overlap and if they do, copy Load's operand to a new buffer. + +    // First, create new blocks for the second part of the check and the copy. +    BasicBlock *Check0 = MatMul->getParent(); +    // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a +    // DT. Manually collect dominator tree updates, to avoid unnecessary work, +    // as we adjust Check0 and Check1's branches. 
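The column-major branch of emitMatrixMultiply above computes column J of the result as a sum over K of column K of A scaled by the scalar B(K, J), blocked by the vector factor so the adds vectorize without reassociation. A scalar reference version with the same loop order (illustrative only; element (i, j) of a column-major R-row matrix lives at j * R + i):

    #include <vector>

    // Reference column-major multiply of an R x M matrix A by an M x C
    // matrix B, with the J/I/K nesting used by emitMatrixMultiply.
    static std::vector<double> matmulColMajor(const std::vector<double> &A,
                                              const std::vector<double> &B,
                                              unsigned R, unsigned M,
                                              unsigned C) {
      std::vector<double> Result(size_t(R) * C, 0.0);
      for (unsigned J = 0; J < C; ++J)
        for (unsigned I = 0; I < R; ++I)   // blocked by VF in the pass
          for (unsigned K = 0; K < M; ++K) // accumulate along K
            Result[J * R + I] += A[K * R + I] * B[J * M + K];
      return Result;
    }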
+    SmallVector<DominatorTree::UpdateType, 4> DTUpdates; +    for (BasicBlock *Succ : successors(Check0)) +      DTUpdates.push_back({DT.Delete, Check0, Succ}); + +    BasicBlock *Check1 = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, +                                    nullptr, "alias_cont"); +    BasicBlock *Copy = +        SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, nullptr, "copy"); +    BasicBlock *Fusion = SplitBlock(MatMul->getParent(), MatMul, nullptr, &LI, +                                    nullptr, "no_alias"); + +    // Check if the loaded memory location begins before the end of the store +    // location. If the condition holds, they might overlap, otherwise they are +    // guaranteed to not overlap. +    IRBuilder<> Builder(MatMul); +    Check0->getTerminator()->eraseFromParent(); +    Builder.SetInsertPoint(Check0); +    Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout()); +    Value *StoreBegin = Builder.CreatePtrToInt( +        const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin"); +    Value *StoreEnd = Builder.CreateAdd( +        StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()), +        "store.end", true, true); +    Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr), +                                              IntPtrTy, "load.begin"); +    Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1, +                         Fusion); + +    // Check if the store begins before the end of the load location. If the +    // condition holds, they alias, otherwise they are guaranteed to not +    // overlap. +    Check1->getTerminator()->eraseFromParent(); +    Builder.SetInsertPoint(Check1, Check1->begin()); +    Value *LoadEnd = Builder.CreateAdd( +        LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()), +        "load.end", true, true); +    Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy, +                         Fusion); + +    // Copy load operand to new alloca. +    Builder.SetInsertPoint(Copy, Copy->begin()); +    AllocaInst *NewLd = +        Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace()); +    Builder.CreateMemCpy(NewLd, NewLd->getAlign(), +                         Load->getPointerOperand(), Load->getAlign(), +                         LoadLoc.Size.getValue()); +    Builder.SetInsertPoint(Fusion, Fusion->begin()); +    PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3); +    PHI->addIncoming(Load->getPointerOperand(), Check0); +    PHI->addIncoming(Load->getPointerOperand(), Check1); +    PHI->addIncoming(NewLd, Copy); + +    // Adjust DT. 
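The two conditional branches built above encode the usual half-open interval test: the ranges [LoadBegin, LoadEnd) and [StoreBegin, StoreEnd) can only overlap if LoadBegin < StoreEnd and StoreBegin < LoadEnd, so the copy block is reached exactly when both comparisons hold. The same predicate in plain C++:

    #include <cstdint>

    // Half-open interval overlap test matching the Check0/Check1 branches;
    // the runtime copy is only taken when this returns true.
    static bool mayOverlap(uint64_t LoadBegin, uint64_t LoadSize,
                           uint64_t StoreBegin, uint64_t StoreSize) {
      uint64_t LoadEnd = LoadBegin + LoadSize;
      uint64_t StoreEnd = StoreBegin + StoreSize;
      return LoadBegin < StoreEnd && StoreBegin < LoadEnd;
    }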
+    DTUpdates.push_back({DT.Insert, Check0, Check1}); +    DTUpdates.push_back({DT.Insert, Check0, Fusion}); +    DTUpdates.push_back({DT.Insert, Check1, Copy}); +    DTUpdates.push_back({DT.Insert, Check1, Fusion}); +    DT.applyUpdates(DTUpdates); +    return PHI; +  } + +  bool isFusionProfitable(CallInst *MatMul) { +    if (ForceFusion) +      return true; + +    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); +    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + +    const unsigned R = LShape.NumRows; +    const unsigned C = RShape.NumColumns; +    const unsigned M = LShape.NumColumns; +    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + +    const unsigned VF = +        std::max<unsigned>(TTI.getRegisterBitWidth(true) / +                               EltType->getPrimitiveSizeInBits().getFixedSize(), +                           1U); + +    // Cost model for tiling +    // +    // For tiling to be beneficial, we need reuse either along the R or +    // the C axis.  We vectorize along the R axis so that means at least +    // 3 elements. +    // TODO: Also consider cost of copying if operands alias. +    if (R <= VF && C == 1) +      return false; +    // Then we need enough elements to exceed the number of vector +    // registers we have.  Note that this is an oversimplification since +    // fusing also takes some extra loads which may exceed the number of +    // reloads necessary. +    unsigned Op0Regs = (R + VF - 1) / VF * M; +    unsigned Op1Regs = (M + VF - 1) / VF * C; +    return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true); +  } + +  MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) { +    MatrixTy Res; +    auto *ColumnType = FixedVectorType::get(EltType, R); +    for (unsigned I = 0; I < C; ++I) +      Res.addVector(ConstantAggregateZero::get(ColumnType)); +    return Res; +  } + +  void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1, +                      StoreInst *Store, +                      SmallPtrSetImpl<Instruction *> &FusedInsts) { +    assert(MatrixLayout == MatrixLayoutTy::ColumnMajor && +           "Tiling only supported for column-major matrixes at the moment!"); +    if (!isFusionProfitable(MatMul)) +      return; + +    ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3)); +    ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); + +    const unsigned R = LShape.NumRows; +    const unsigned C = RShape.NumColumns; +    const unsigned M = LShape.NumColumns; +    auto *EltType = cast<VectorType>(MatMul->getType())->getElementType(); + +    Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul); +    Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); +    Value *CPtr = Store->getPointerOperand(); + +    bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) && +                                                  MatMul->hasAllowContract()); +    IRBuilder<> Builder(Store); +    for (unsigned J = 0; J < C; J += TileSize) +      for (unsigned I = 0; I < R; I += TileSize) { +        const unsigned TileR = std::min(R - I, unsigned(TileSize)); +        const unsigned TileC = std::min(C - J, unsigned(TileSize)); +        MatrixTy Res = getZeroMatrix(EltType, TileR, TileC); + +        for (unsigned K = 0; K < M; K += TileSize) { +          const unsigned TileM = std::min(M - K, unsigned(TileSize)); +          MatrixTy A = +              loadMatrix(APtr, LoadOp0->getAlign(), 
LoadOp0->isVolatile(), +                         LShape, Builder.getInt64(I), Builder.getInt64(K), +                         {TileR, TileM}, EltType, Builder); +          MatrixTy B = +              loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), +                         RShape, Builder.getInt64(K), Builder.getInt64(J), +                         {TileM, TileC}, EltType, Builder); +          emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); +        } +        storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, +                    Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); +      } + +    // Mark eliminated instructions as fused and remove them. +    FusedInsts.insert(Store); +    FusedInsts.insert(MatMul); +    Store->eraseFromParent(); +    MatMul->eraseFromParent(); +    if (LoadOp0->hasNUses(0)) { +      FusedInsts.insert(LoadOp0); +      LoadOp0->eraseFromParent(); +    } +    if (LoadOp1->hasNUses(0)) { +      FusedInsts.insert(LoadOp1); +      LoadOp1->eraseFromParent(); +    } +  } + +  /// Try to lower matrix multiply chains by fusing operations. +  /// +  /// Currently we only lower {ld, ld} -> matmul -> st chains. +  /// +  /// No need to return a MatrixTy object for the result of the operation, since +  /// the single store user will be lowered as part of this. Instructions that +  /// are completely eliminated by fusion are added to \p FusedInsts. +  void LowerMatrixMultiplyFused(CallInst *MatMul, +                                SmallPtrSetImpl<Instruction *> &FusedInsts) { +    if (!FuseMatrix || !MatMul->hasOneUse() || +        MatrixLayout != MatrixLayoutTy::ColumnMajor) +      return; + +    auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0)); +    auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1)); +    auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin()); +    if (LoadOp0 && LoadOp1 && Store) { +      // The store address must dominate the MatMul instruction, otherwise +      // we create invalid IR. +      // FIXME: See if we can hoist the store address computation. +      auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1)); +      if (AddrI && (!DT.dominates(AddrI, MatMul))) +        return; + +      emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts); +      return; +    } +  } +    /// Lowers llvm.matrix.multiply.    
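emitSIMDTiling above walks the result matrix in TileSize x TileSize blocks, reloading the required tiles of the operands for every K step and storing each finished result tile once. Only the loop skeleton matters for following the fusion; here it is in scalar form, with tileMultiplyAdd as a hypothetical stand-in for the load/multiply/accumulate body:

    #include <algorithm>

    // Hypothetical stand-in: accumulate the product of the A tile at (I, K)
    // and the B tile at (K, J) into the result tile at (I, J).
    static void tileMultiplyAdd(unsigned, unsigned, unsigned, unsigned,
                                unsigned, unsigned) {}

    // Loop skeleton of emitSIMDTiling for an (R x M) * (M x C) multiply
    // with square tiles of edge TS (the fuse-matrix-tile-size option).
    static void tiledMultiply(unsigned R, unsigned M, unsigned C, unsigned TS) {
      for (unsigned J = 0; J < C; J += TS)
        for (unsigned I = 0; I < R; I += TS) {
          unsigned TileR = std::min(R - I, TS);
          unsigned TileC = std::min(C - J, TS);
          // The result tile starts as a zero matrix (getZeroMatrix) and is
          // accumulated along K, then stored exactly once.
          for (unsigned K = 0; K < M; K += TS)
            tileMultiplyAdd(I, J, K, TileR, TileC, std::min(M - K, TS));
        }
    }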
void LowerMultiply(CallInst *MatMul) {      IRBuilder<> Builder(MatMul); @@ -706,97 +1298,80 @@ public:      ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));      ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4)); -    const ColumnMatrixTy &Lhs = -        getMatrix(MatMul->getArgOperand(0), LShape, Builder); -    const ColumnMatrixTy &Rhs = -        getMatrix(MatMul->getArgOperand(1), RShape, Builder); +    const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder); +    const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);      const unsigned R = LShape.NumRows; -    const unsigned M = LShape.NumColumns;      const unsigned C = RShape.NumColumns; -    assert(M == RShape.NumRows); +    assert(LShape.NumColumns == RShape.NumRows);      // Initialize the output -    ColumnMatrixTy Result; -    for (unsigned J = 0; J < C; ++J) -      Result.addColumn(UndefValue::get(VectorType::get(EltType, R))); - -    const unsigned VF = std::max(TTI.getRegisterBitWidth(true) / -                                     EltType->getPrimitiveSizeInBits(), -                                 uint64_t(1)); +    MatrixTy Result(R, C, EltType);      bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&                                                    MatMul->hasAllowContract()); -    // Multiply columns from the first operand with scalars from the second -    // operand.  Then move along the K axes and accumulate the columns.  With -    // this the adds can be vectorized without reassociation. -    for (unsigned J = 0; J < C; ++J) { -      unsigned BlockSize = VF; -      for (unsigned I = 0; I < R; I += BlockSize) { -        // Gradually lower the vectorization factor to cover the remainder. -        while (I + BlockSize > R) -          BlockSize /= 2; - -        Value *Sum = nullptr; -        for (unsigned K = 0; K < M; ++K) { -          Value *L = extractVector(Lhs, I, K, BlockSize, Builder); -          Value *RH = Builder.CreateExtractElement(Rhs.getColumn(J), K); -          Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); -          Sum = createMulAdd(Sum, L, Splat, EltType->isFloatingPointTy(), -                             Builder, AllowContract); -        } -        Result.setColumn(J, insertVector(Result.getColumn(J), I, Sum, Builder)); -      } -    } +    emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);      finalizeLowering(MatMul, Result, Builder);    }    /// Lowers llvm.matrix.transpose.    void LowerTranspose(CallInst *Inst) { -    ColumnMatrixTy Result; +    MatrixTy Result;      IRBuilder<> Builder(Inst);      Value *InputVal = Inst->getArgOperand(0);      VectorType *VectorTy = cast<VectorType>(InputVal->getType());      ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2)); -    ColumnMatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); - -    for (unsigned Row = 0; Row < ArgShape.NumRows; ++Row) { -      // Build a single column vector for this row. First initialize it. -      Value *ResultColumn = UndefValue::get( -          VectorType::get(VectorTy->getElementType(), ArgShape.NumColumns)); - -      // Go through the elements of this row and insert it into the resulting -      // column vector. -      for (auto C : enumerate(InputMatrix.columns())) { -        Value *Elt = Builder.CreateExtractElement(C.value(), Row); -        // We insert at index Column since that is the row index after the -        // transpose. 
-        ResultColumn = -            Builder.CreateInsertElement(ResultColumn, Elt, C.index()); +    MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder); + +    const unsigned NewNumVecs = +        InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns; +    const unsigned NewNumElts = +        InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows; + +    for (unsigned I = 0; I < NewNumVecs; ++I) { +      // Build a single result vector. First initialize it. +      Value *ResultVector = UndefValue::get( +          FixedVectorType::get(VectorTy->getElementType(), NewNumElts)); +      // Go through the old elements and insert them into the resulting vector. +      for (auto J : enumerate(InputMatrix.vectors())) { +        Value *Elt = Builder.CreateExtractElement(J.value(), I); +        // Row and column indices are transposed. +        ResultVector = +            Builder.CreateInsertElement(ResultVector, Elt, J.index());        } -      Result.addColumn(ResultColumn); +      Result.addVector(ResultVector);      } -    finalizeLowering(Inst, Result, Builder); +    // TODO: Improve estimate of operations needed for transposes. Currently we +    // just count the insertelement/extractelement instructions, but do not +    // account for later simplifications/combines. +    finalizeLowering( +        Inst, +        Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns), +        Builder);    }    /// Lower load instructions, if shape information is available. -  bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) { +  bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {      auto I = ShapeMap.find(Inst);      if (I == ShapeMap.end())        return false; -    LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second); +    LowerLoad(Inst, Ptr, Inst->getAlign(), +              Builder.getInt64(I->second.getStride()), Inst->isVolatile(), +              I->second);      return true;    } -  bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr, +  bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,                    IRBuilder<> &Builder) {      auto I = ShapeMap.find(StoredVal);      if (I == ShapeMap.end())        return false; -    LowerStore(Inst, StoredVal, Ptr, Builder.getInt32(I->second.NumRows), I->second); +    LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(), +               Builder.getInt64(I->second.getStride()), Inst->isVolatile(), +               I->second);      return true;    } @@ -812,12 +1387,15 @@ public:      IRBuilder<> Builder(Inst);      ShapeInfo &Shape = I->second; -    ColumnMatrixTy LoweredLhs = getMatrix(Lhs, Shape, Builder); -    ColumnMatrixTy LoweredRhs = getMatrix(Rhs, Shape, Builder); +    MatrixTy Result; +    MatrixTy A = getMatrix(Lhs, Shape, Builder); +    MatrixTy B = getMatrix(Rhs, Shape, Builder); +    assert(A.isColumnMajor() == B.isColumnMajor() && +           Result.isColumnMajor() == A.isColumnMajor() && +           "operands must agree on matrix layout"); -    // Add each column and store the result back into the opmapping -    ColumnMatrixTy Result; -    auto BuildColumnOp = [&Builder, Inst](Value *LHS, Value *RHS) { +    // Helper to perform binary op on vectors. 
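LowerTranspose above builds each result vector elementwise: result vector I collects element I of every input vector J, so the (I, J) indices swap roles. The same idea over plain nested vectors (illustrative, assumes a rectangular input):

    #include <vector>

    // Transpose a matrix given as its layout vectors, exactly as
    // LowerTranspose does with extractelement/insertelement pairs.
    static std::vector<std::vector<double>>
    transposeVectors(const std::vector<std::vector<double>> &In) {
      if (In.empty())
        return {};
      std::vector<std::vector<double>> Out(In[0].size(),
                                           std::vector<double>(In.size()));
      for (size_t J = 0; J < In.size(); ++J)
        for (size_t I = 0; I < In[J].size(); ++I)
          Out[I][J] = In[J][I]; // element I of vector J lands in vector I
      return Out;
    }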
+    auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
       switch (Inst->getOpcode()) {
       case Instruction::Add:
         return Builder.CreateAdd(LHS, RHS);
@@ -835,20 +1413,462 @@ public:
         llvm_unreachable("Unsupported binary operator for matrix");
       }
     };
-    for (unsigned C = 0; C < Shape.NumColumns; ++C)
-      Result.addColumn(
-          BuildColumnOp(LoweredLhs.getColumn(C), LoweredRhs.getColumn(C)));
-    finalizeLowering(Inst, Result, Builder);
+    for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+      Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
+
+    finalizeLowering(Inst,
+                     Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+                                             Result.getNumVectors()),
+                     Builder);
     return true;
   }
+
+  /// Helper to linearize a matrix expression tree into a string. Currently
+  /// matrix expressions are linearized by starting at an expression leaf and
+  /// linearizing bottom up.
+  struct ExprLinearizer {
+    unsigned LengthToBreak = 100;
+    std::string Str;
+    raw_string_ostream Stream;
+    unsigned LineLength = 0;
+    const DataLayout &DL;
+
+    /// Mapping from instructions to matrixes. It is used to identify
+    /// matrix instructions.
+    const MapVector<Value *, MatrixTy> &Inst2Matrix;
+
+    /// Mapping from values to the leaves of all expressions that the value is
+    /// part of.
+    const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
+
+    /// Set of matrix expressions in the scope of a given DISubprogram.
+    const SmallSetVector<Value *, 32> &ExprsInSubprogram;
+
+    /// Leaf node of the expression to linearize.
+    Value *Leaf;
+
+    /// Used to keep track of sub-expressions that get reused while linearizing
+    /// the expression. Re-used sub-expressions are marked as (reused).
+    SmallPtrSet<Value *, 8> ReusedExprs;
+
+    ExprLinearizer(const DataLayout &DL,
+                   const MapVector<Value *, MatrixTy> &Inst2Matrix,
+                   const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+                   const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+                   Value *Leaf)
+        : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
+          ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
+
+    void indent(unsigned N) {
+      LineLength += N;
+      for (unsigned i = 0; i < N; i++)
+        Stream << " ";
+    }
+
+    void lineBreak() {
+      Stream << "\n";
+      LineLength = 0;
+    }
+
+    void maybeIndent(unsigned Indent) {
+      if (LineLength >= LengthToBreak)
+        lineBreak();
+
+      if (LineLength == 0)
+        indent(Indent);
+    }
+
+    void write(StringRef S) {
+      LineLength += S.size();
+      Stream << S;
+    }
+
+    Value *getUnderlyingObjectThroughLoads(Value *V) {
+      if (Value *Ptr = getPointerOperand(V))
+        return getUnderlyingObjectThroughLoads(Ptr);
+      else if (V->getType()->isPointerTy())
+        return GetUnderlyingObject(V, DL);
+      return V;
+    }
+
+    /// Returns true if \p V is a matrix value in the given subprogram.
+    bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
+
+    /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
+    /// \p SS.
+    void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
+      auto M = Inst2Matrix.find(V);
+      if (M == Inst2Matrix.end())
+        SS << "unknown";
+      else {
+        SS << M->second.getNumRows();
+        SS << "x";
+        SS << M->second.getNumColumns();
+      }
+    }
+
+    /// Write the called function name. Handles calls to llvm.matrix.*
+    /// specially: we write the name, followed by the dimensions of the input
+    /// matrixes, followed by the scalar type name.
+    void writeFnName(CallInst *CI) {
+      if (!CI->getCalledFunction())
+        write("<no called fn>");
+      else {
+        StringRef Name = CI->getCalledFunction()->getName();
+        if (!Name.startswith("llvm.matrix")) {
+          write(Name);
+          return;
+        }
+        IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+        write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
+                  .drop_front(StringRef("llvm.matrix.").size()));
+        write(".");
+        std::string Tmp = "";
+        raw_string_ostream SS(Tmp);
+
+        switch (II->getIntrinsicID()) {
+        case Intrinsic::matrix_multiply:
+          prettyPrintMatrixType(II->getOperand(0), SS);
+          SS << ".";
+          prettyPrintMatrixType(II->getOperand(1), SS);
+          SS << "." << *II->getType()->getScalarType();
+          break;
+        case Intrinsic::matrix_transpose:
+          prettyPrintMatrixType(II->getOperand(0), SS);
+          SS << "." << *II->getType()->getScalarType();
+          break;
+        case Intrinsic::matrix_column_major_load:
+          prettyPrintMatrixType(II, SS);
+          SS << "." << *II->getType()->getScalarType();
+          break;
+        case Intrinsic::matrix_column_major_store:
+          prettyPrintMatrixType(II->getOperand(0), SS);
+          SS << "." << *II->getOperand(0)->getType()->getScalarType();
+          break;
+        default:
+          llvm_unreachable("Unhandled case");
+        }
+        SS.flush();
+        write(Tmp);
+      }
+    }
+
+    unsigned getNumShapeArgs(CallInst *CI) const {
+      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
+        switch (II->getIntrinsicID()) {
+        case Intrinsic::matrix_multiply:
+          return 3;
+        case Intrinsic::matrix_transpose:
+          return 2;
+        case Intrinsic::matrix_column_major_load:
+        case Intrinsic::matrix_column_major_store:
+          return 3;
+        default:
+          return 0;
+        }
+      }
+      return 0;
+    }
+
+    /// Special printing for values: for pointers, we print whether they refer
+    /// to a (function-)external address or a stack address; for other values
+    /// we either print the constant or "scalar"/"matrix".
+    void write(Value *V) { +      V = getUnderlyingObjectThroughLoads(V); +      if (V->getType()->isPointerTy()) { +        if (isa<AllocaInst>(V)) { +          Stream << "stack addr"; +          LineLength += StringRef("stack addr").size(); +        } else { +          Stream << "addr"; +          LineLength += StringRef("addr").size(); +        } +        if (!V->getName().empty()) { +          Stream << " %" << V->getName() << ""; +          LineLength += V->getName().size() + 2; +        } +        return; +      } + +      std::string Tmp; +      raw_string_ostream TmpStream(Tmp); + +      if (auto *CI = dyn_cast<ConstantInt>(V)) +        TmpStream << CI->getValue(); +      else if (isa<Constant>(V)) +        TmpStream << "constant"; +      else { +        if (isMatrix(V)) +          TmpStream << "matrix"; +        else +          TmpStream << "scalar"; +      } +      TmpStream.flush(); +      Tmp = std::string(StringRef(Tmp).trim()); +      LineLength += Tmp.size(); +      Stream << Tmp; +    } + +    /// Linearize expression \p Expr starting at an indentation of \p Indent. +    /// Expressions that are re-used multiple times are prefixed with (reused) +    /// at the re-used root instruction. +    void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused, +                       bool ParentShared) { +      auto *I = cast<Instruction>(Expr); +      maybeIndent(Indent); +      SmallVector<Value *, 8> Ops; + +      // Is Expr shared with other expression leaves? +      bool ExprShared = false; + +      // Deal with shared subtrees. Mark them as shared, if required. +      if (!ParentShared) { +        auto SI = Shared.find(Expr); +        assert(SI != Shared.end() && SI->second.count(Leaf)); + +        for (Value *S : SI->second) { +          if (S == Leaf) +            continue; +          DebugLoc DL = cast<Instruction>(S)->getDebugLoc(); +          write("shared with remark at line " + std::to_string(DL.getLine()) + +                " column " + std::to_string(DL.getCol()) + " ("); +        } +        ExprShared = SI->second.size() > 1; +      } + +      bool Reused = !ReusedExprs.insert(Expr).second; +      if (Reused && !ParentReused) +        write("(reused) "); + +      if (auto *CI = dyn_cast<CallInst>(I)) { +        writeFnName(CI); + +        Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI)); +      } else if (isa<BitCastInst>(Expr)) { +        // Special case bitcasts, which are used to materialize matrixes from +        // non-matrix ops. +        write("matrix"); +        return; +      } else { +        Ops.append(I->value_op_begin(), I->value_op_end()); +        write(std::string(I->getOpcodeName())); +      } + +      write(std::string("(")); + +      unsigned NumOpsToBreak = 1; +      if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>())) +        NumOpsToBreak = 2; + +      for (Value *Op : Ops) { +        if (Ops.size() > NumOpsToBreak) +          lineBreak(); + +        maybeIndent(Indent + 1); +        if (isMatrix(Op)) +          linearizeExpr(Op, Indent + 1, Reused, ExprShared); +        else +          write(Op); +        if (Op != Ops.back()) +          write(", "); +      } + +      write(")"); +    } + +    const std::string &getResult() { +      Stream.flush(); +      return Str; +    } +  }; + +  /// Generate remarks for matrix operations in a function. To generate remarks +  /// for matrix expressions, the following approach is used: +  /// 1. 
Use the inlined-at debug information to group matrix operations to the
+  ///    DISubprograms they are contained in.
+  /// 2. Collect leaves of matrix expressions (done in
+  ///    RemarkGenerator::getExpressionLeaves) for each subprogram - expression
+  ///    mapping.  Leaves are lowered matrix instructions without other matrix
+  ///    users (like stores) in the current subprogram.
+  /// 3. For each leaf, create a remark containing a linearized version of the
+  ///    matrix expression. The expression is linearized by a recursive
+  ///    bottom-up traversal of the matrix operands, starting at a leaf. Note
+  ///    that multiple leaves can share sub-expressions. Shared subexpressions
+  ///    are explicitly marked as shared().
+  struct RemarkGenerator {
+    const MapVector<Value *, MatrixTy> &Inst2Matrix;
+    OptimizationRemarkEmitter &ORE;
+    Function &Func;
+    const DataLayout &DL;
+
+    RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
+                    OptimizationRemarkEmitter &ORE, Function &Func)
+        : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
+          DL(Func.getParent()->getDataLayout()) {}
+
+    /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
+    /// instructions in Inst2Matrix returning void or without any users in
+    /// \p ExprsInSubprogram. Currently that should only include stores.
+    SmallVector<Value *, 4>
+    getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
+      SmallVector<Value *, 4> Leaves;
+      for (auto *Expr : ExprsInSubprogram)
+        if (Expr->getType()->isVoidTy() ||
+            !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
+              return ExprsInSubprogram.count(U);
+            }))
+          Leaves.push_back(Expr);
+      return Leaves;
+    }
+
+    /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
+    /// to all visited expressions in \p Shared. Limit the matrix operations to
+    /// the ones in \p ExprsInSubprogram.
+    void collectSharedInfo(Value *Leaf, Value *V,
+                           const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+                           DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
+
+      if (!ExprsInSubprogram.count(V))
+        return;
+
+      auto I = Shared.insert({V, {}});
+      I.first->second.insert(Leaf);
+
+      for (Value *Op : cast<Instruction>(V)->operand_values())
+        collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
+      return;
+    }
+
+    /// Calculate the number of exclusive and shared op counts for expression
+    /// starting at \p V. Expressions used multiple times are counted once.
+    /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
+    std::pair<OpInfoTy, OpInfoTy>
+    sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
+               const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+               DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
+      if (!ExprsInSubprogram.count(Root))
+        return {};
+
+      // Already counted this expression. Stop.
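+      // (SmallPtrSet::insert returns false as the second member of the pair
+      // when the value was already present, so each sub-expression only
+      // contributes to the exclusive/shared totals once.)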
+      if (!ReusedExprs.insert(Root).second)
+        return {};
+
+      OpInfoTy SharedCount;
+      OpInfoTy Count;
+
+      auto I = Shared.find(Root);
+      auto CM = Inst2Matrix.find(Root);
+      if (I->second.size() == 1)
+        Count = CM->second.getOpInfo();
+      else
+        SharedCount = CM->second.getOpInfo();
+
+      for (Value *Op : cast<Instruction>(Root)->operand_values()) {
+        auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
+        Count += C.first;
+        SharedCount += C.second;
+      }
+      return {Count, SharedCount};
+    }
+
+    void emitRemarks() {
+      if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
+        return;
+
+      // Map matrix operations to their containing subprograms, by traversing
+      // the inlinedAt chain. If the function does not have a DISubprogram, we
+      // only map them to the containing function.
+      MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
+      for (auto &KV : Inst2Matrix) {
+        if (Func.getSubprogram()) {
+          auto *I = cast<Instruction>(KV.first);
+          DILocation *Context = I->getDebugLoc();
+          while (Context) {
+            auto I =
+                Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
+            I.first->second.push_back(KV.first);
+            Context = DebugLoc(Context).getInlinedAt();
+          }
+        } else {
+          auto I = Subprog2Exprs.insert({nullptr, {}});
+          I.first->second.push_back(KV.first);
+        }
+      }
+      for (auto &KV : Subprog2Exprs) {
+        SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
+                                                      KV.second.end());
+        auto Leaves = getExpressionLeaves(ExprsInSubprogram);
+
+        DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
+        for (Value *Leaf : Leaves)
+          collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
+
+        // Generate remarks for each leaf.
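+        // For each leaf, walk its inlined-at chain to find the debug location
+        // belonging to the current subprogram, sum the exclusive and shared op
+        // counts, and emit a remark with the linearized expression attached.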
+        for (auto *L : Leaves) { + +          DebugLoc Loc = cast<Instruction>(L)->getDebugLoc(); +          DILocation *Context = cast<Instruction>(L)->getDebugLoc(); +          while (Context) { +            if (getSubprogram(Context->getScope()) == KV.first) { +              Loc = Context; +              break; +            } +            Context = DebugLoc(Context).getInlinedAt(); +          } + +          SmallPtrSet<Value *, 8> ReusedExprs; +          OpInfoTy Counts, SharedCounts; +          std::tie(Counts, SharedCounts) = +              sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared); + +          OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc, +                                 cast<Instruction>(L)->getParent()); + +          Rem << "Lowered with "; +          Rem << ore::NV("NumStores", Counts.NumStores) << " stores, " +              << ore::NV("NumLoads", Counts.NumLoads) << " loads, " +              << ore::NV("NumComputeOps", Counts.NumComputeOps) +              << " compute ops"; + +          if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 || +              SharedCounts.NumComputeOps > 0) { +            Rem << ",\nadditionally " +                << ore::NV("NumStores", SharedCounts.NumStores) << " stores, " +                << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, " +                << ore::NV("NumFPOps", SharedCounts.NumComputeOps) +                << " compute ops" +                << " are shared with other expressions"; +          } + +          Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL)); +          ORE.emit(Rem); +        } +      } +    } + +    std::string +    linearize(Value *L, +              const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared, +              const SmallSetVector<Value *, 32> &ExprsInSubprogram, +              const DataLayout &DL) { +      ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L); +      Lin.linearizeExpr(L, 0, false, false); +      return Lin.getResult(); +    } +  };  };  } // namespace  PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,                                                   FunctionAnalysisManager &AM) {    auto &TTI = AM.getResult<TargetIRAnalysis>(F); -  LowerMatrixIntrinsics LMT(F, TTI); +  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); +  auto &AA = AM.getResult<AAManager>(F); +  auto &DT = AM.getResult<DominatorTreeAnalysis>(F); +  auto &LI = AM.getResult<LoopAnalysis>(F); + +  LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);    if (LMT.Visit()) {      PreservedAnalyses PA;      PA.preserveSet<CFGAnalyses>(); @@ -869,15 +1889,24 @@ public:    }    bool runOnFunction(Function &F) override { -    auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); -    LowerMatrixIntrinsics LMT(F, *TTI); +    auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); +    auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); +    auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); +    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); +    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +    LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);      bool C = LMT.Visit();      return C;    }    void getAnalysisUsage(AnalysisUsage &AU) const override {      AU.addRequired<TargetTransformInfoWrapperPass>(); -    AU.setPreservesCFG(); +    AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); +    AU.addRequired<AAResultsWrapperPass>(); +    
AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addPreserved<LoopInfoWrapperPass>();
   }
 };
 } // namespace

@@ -886,6 +1915,10 @@ static const char pass_name[] = "Lower the matrix intrinsics";
 char LowerMatrixIntrinsicsLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                       false, false)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
                     false, false)
diff --git a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
index c24fa40860eb..4b4196edc12b 100644
--- a/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ b/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -27,7 +27,6 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -173,8 +172,8 @@ public:
   void addStore(int64_t OffsetFromFirst, StoreInst *SI) {
     int64_t StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType());

-    addRange(OffsetFromFirst, StoreSize,
-             SI->getPointerOperand(), SI->getAlignment(), SI);
+    addRange(OffsetFromFirst, StoreSize, SI->getPointerOperand(),
+             SI->getAlign().value(), SI);
   }

   void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) {
@@ -387,13 +386,8 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
     // Get the starting pointer of the block.
     StartPtr = Range.StartPtr;

-    // Determine alignment
-    const Align Alignment = DL.getValueOrABITypeAlignment(
-        MaybeAlign(Range.Alignment),
-        cast<PointerType>(StartPtr->getType())->getElementType());
-
     AMemSet = Builder.CreateMemSet(StartPtr, ByteVal, Range.End - Range.Start,
-                                   Alignment);
+                                   MaybeAlign(Range.Alignment));
     LLVM_DEBUG(dbgs() << "Replace stores:\n"; for (Instruction *SI
                                                    : Range.TheStores) dbgs()
                                               << *SI << '\n';
@@ -413,23 +407,6 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst,
   return AMemSet;
 }

-static Align findStoreAlignment(const DataLayout &DL, const StoreInst *SI) {
-  return DL.getValueOrABITypeAlignment(MaybeAlign(SI->getAlignment()),
-                                       SI->getOperand(0)->getType());
-}
-
-static Align findLoadAlignment(const DataLayout &DL, const LoadInst *LI) {
-  return DL.getValueOrABITypeAlignment(MaybeAlign(LI->getAlignment()),
-                                       LI->getType());
-}
-
-static Align findCommonAlignment(const DataLayout &DL, const StoreInst *SI,
-                                 const LoadInst *LI) {
-  Align StoreAlign = findStoreAlignment(DL, SI);
-  Align LoadAlign = findLoadAlignment(DL, LI);
-  return commonAlignment(StoreAlign, LoadAlign);
-}
-
 // This method tries to lift a store instruction before position P.
 // It will lift the store and its argument, plus anything that
 // may alias with these.
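With the removed helpers gone, the hunks below read the alignment directly off the instructions: LoadInst::getAlign() and StoreInst::getAlign() now always return a valid Align, so the DataLayout fallback is no longer needed. A minimal sketch of the equivalent common-alignment computation (helper name hypothetical):

    // commonAlignment(Align, Align), from llvm/Support/Alignment.h, returns
    // the smaller of its two arguments, so this matches the removed
    // findCommonAlignment without consulting the DataLayout.
    static Align commonLoadStoreAlign(const StoreInst *SI, const LoadInst *LI) {
      return commonAlignment(SI->getAlign(), LI->getAlign());
    }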
@@ -585,12 +562,12 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {            Instruction *M;            if (UseMemMove)              M = Builder.CreateMemMove( -                SI->getPointerOperand(), findStoreAlignment(DL, SI), -                LI->getPointerOperand(), findLoadAlignment(DL, LI), Size); +                SI->getPointerOperand(), SI->getAlign(), +                LI->getPointerOperand(), LI->getAlign(), Size);            else              M = Builder.CreateMemCpy( -                SI->getPointerOperand(), findStoreAlignment(DL, SI), -                LI->getPointerOperand(), findLoadAlignment(DL, LI), Size); +                SI->getPointerOperand(), SI->getAlign(), +                LI->getPointerOperand(), LI->getAlign(), Size);            LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => "                              << *M << "\n"); @@ -642,7 +619,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {              LI, SI->getPointerOperand()->stripPointerCasts(),              LI->getPointerOperand()->stripPointerCasts(),              DL.getTypeStoreSize(SI->getOperand(0)->getType()), -            findCommonAlignment(DL, SI, LI).value(), C); +            commonAlignment(SI->getAlign(), LI->getAlign()), C);          if (changed) {            MD->removeInstruction(SI);            SI->eraseFromParent(); @@ -675,11 +652,9 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {      auto *T = V->getType();      if (T->isAggregateType()) {        uint64_t Size = DL.getTypeStoreSize(T); -      const Align MA = -          DL.getValueOrABITypeAlignment(MaybeAlign(SI->getAlignment()), T);        IRBuilder<> Builder(SI); -      auto *M = -          Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, MA); +      auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, +                                     SI->getAlign());        LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); @@ -713,7 +688,7 @@ bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {  /// the call write its result directly into the destination of the memcpy.  bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,                                           Value *cpySrc, uint64_t cpyLen, -                                         unsigned cpyAlign, CallInst *C) { +                                         Align cpyAlign, CallInst *C) {    // The general transformation to keep in mind is    //    //   call @func(..., src, ...) @@ -733,10 +708,6 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,      if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start)        return false; -  // Deliberately get the source and destination with bitcasts stripped away, -  // because we'll need to do type comparisons based on the underlying type. -  CallSite CS(C); -    // Require that src be an alloca.  This simplifies the reasoning considerably.    AllocaInst *srcAlloca = dyn_cast<AllocaInst>(cpySrc);    if (!srcAlloca) @@ -795,9 +766,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,    }    // Check that dest points to memory that is at least as aligned as src. 
-  unsigned srcAlign = srcAlloca->getAlignment(); -  if (!srcAlign) -    srcAlign = DL.getABITypeAlignment(srcAlloca->getAllocatedType()); +  Align srcAlign = srcAlloca->getAlign();    bool isDestSufficientlyAligned = srcAlign <= cpyAlign;    // If dest is not aligned enough and we can't increase its alignment then    // bail out. @@ -836,8 +805,8 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,    // Check that src isn't captured by the called function since the    // transformation can cause aliasing issues in that case. -  for (unsigned i = 0, e = CS.arg_size(); i != e; ++i) -    if (CS.getArgument(i) == cpySrc && !CS.doesNotCapture(i)) +  for (unsigned ArgI = 0, E = C->arg_size(); ArgI != E; ++ArgI) +    if (C->getArgOperand(ArgI) == cpySrc && !C->doesNotCapture(ArgI))        return false;    // Since we're changing the parameter to the callsite, we need to make sure @@ -864,25 +833,26 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,    if (cpySrc->getType()->getPointerAddressSpace() !=        cpyDest->getType()->getPointerAddressSpace())      return false; -  for (unsigned i = 0; i < CS.arg_size(); ++i) -    if (CS.getArgument(i)->stripPointerCasts() == cpySrc && +  for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI) +    if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc &&          cpySrc->getType()->getPointerAddressSpace() != -        CS.getArgument(i)->getType()->getPointerAddressSpace()) +            C->getArgOperand(ArgI)->getType()->getPointerAddressSpace())        return false;    // All the checks have passed, so do the transformation.    bool changedArgument = false; -  for (unsigned i = 0; i < CS.arg_size(); ++i) -    if (CS.getArgument(i)->stripPointerCasts() == cpySrc) { +  for (unsigned ArgI = 0; ArgI < C->arg_size(); ++ArgI) +    if (C->getArgOperand(ArgI)->stripPointerCasts() == cpySrc) {        Value *Dest = cpySrc->getType() == cpyDest->getType() ?  cpyDest          : CastInst::CreatePointerCast(cpyDest, cpySrc->getType(),                                        cpyDest->getName(), C);        changedArgument = true; -      if (CS.getArgument(i)->getType() == Dest->getType()) -        CS.setArgument(i, Dest); +      if (C->getArgOperand(ArgI)->getType() == Dest->getType()) +        C->setArgOperand(ArgI, Dest);        else -        CS.setArgument(i, CastInst::CreatePointerCast(Dest, -                          CS.getArgument(i)->getType(), Dest->getName(), C)); +        C->setArgOperand(ArgI, CastInst::CreatePointerCast( +                                   Dest, C->getArgOperand(ArgI)->getType(), +                                   Dest->getName(), C));      }    if (!changedArgument) @@ -891,7 +861,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest,    // If the destination wasn't sufficiently aligned then increase its alignment.    if (!isDestSufficientlyAligned) {      assert(isa<AllocaInst>(cpyDest) && "Can only increase alloca alignment!"); -    cast<AllocaInst>(cpyDest)->setAlignment(MaybeAlign(srcAlign)); +    cast<AllocaInst>(cpyDest)->setAlignment(srcAlign);    }    // Drop any cached information about the call, because we may have changed @@ -1127,15 +1097,16 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,  /// B to be a memcpy from X to Z (or potentially a memmove, depending on  /// circumstances). This allows later passes to remove the first memcpy  /// altogether. 
-bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { +bool MemCpyOptPass::processMemCpy(MemCpyInst *M, BasicBlock::iterator &BBI) {    // We can only optimize non-volatile memcpy's.    if (M->isVolatile()) return false;    // If the source and destination of the memcpy are the same, then zap it.    if (M->getSource() == M->getDest()) { +    ++BBI;      MD->removeInstruction(M);      M->eraseFromParent(); -    return false; +    return true;    }    // If copying from a constant, try to turn the memcpy into a memset. @@ -1176,10 +1147,10 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) {      if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) {        // FIXME: Can we pass in either of dest/src alignment here instead        // of conservatively taking the minimum? -      unsigned Align = MinAlign(M->getDestAlignment(), M->getSourceAlignment()); +      Align Alignment = std::min(M->getDestAlign().valueOrOne(), +                                 M->getSourceAlign().valueOrOne());        if (performCallSlotOptzn(M, M->getDest(), M->getSource(), -                               CopySize->getZExtValue(), Align, -                               C)) { +                               CopySize->getZExtValue(), Alignment, C)) {          MD->removeInstruction(M);          M->eraseFromParent();          return true; @@ -1247,15 +1218,15 @@ bool MemCpyOptPass::processMemMove(MemMoveInst *M) {  }  /// This is called on every byval argument in call sites. -bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { -  const DataLayout &DL = CS.getCaller()->getParent()->getDataLayout(); +bool MemCpyOptPass::processByValArgument(CallBase &CB, unsigned ArgNo) { +  const DataLayout &DL = CB.getCaller()->getParent()->getDataLayout();    // Find out what feeds this byval argument. -  Value *ByValArg = CS.getArgument(ArgNo); +  Value *ByValArg = CB.getArgOperand(ArgNo);    Type *ByValTy = cast<PointerType>(ByValArg->getType())->getElementType();    uint64_t ByValSize = DL.getTypeAllocSize(ByValTy);    MemDepResult DepInfo = MD->getPointerDependencyFrom(        MemoryLocation(ByValArg, LocationSize::precise(ByValSize)), true, -      CS.getInstruction()->getIterator(), CS.getInstruction()->getParent()); +      CB.getIterator(), CB.getParent());    if (!DepInfo.isClobber())      return false; @@ -1274,16 +1245,17 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {    // Get the alignment of the byval.  If the call doesn't specify the alignment,    // then it is some target specific value that we can't know. -  unsigned ByValAlign = CS.getParamAlignment(ArgNo); -  if (ByValAlign == 0) return false; +  MaybeAlign ByValAlign = CB.getParamAlign(ArgNo); +  if (!ByValAlign) return false;    // If it is greater than the memcpy, then we check to see if we can force the    // source of the memcpy to the alignment we need.  If we fail, we bail out.    
AssumptionCache &AC = LookupAssumptionCache();    DominatorTree &DT = LookupDomTree(); -  if (MDep->getSourceAlignment() < ByValAlign && -      getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, -                                 CS.getInstruction(), &AC, &DT) < ByValAlign) +  MaybeAlign MemDepAlign = MDep->getSourceAlign(); +  if ((!MemDepAlign || *MemDepAlign < *ByValAlign) && +      getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, &CB, &AC, +                                 &DT) < *ByValAlign)      return false;    // The address space of the memcpy source must match the byval argument @@ -1302,21 +1274,25 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) {    // not just the defining memcpy.    MemDepResult SourceDep = MD->getPointerDependencyFrom(        MemoryLocation::getForSource(MDep), false, -      CS.getInstruction()->getIterator(), MDep->getParent()); +      CB.getIterator(), MDep->getParent());    if (!SourceDep.isClobber() || SourceDep.getInst() != MDep)      return false;    Value *TmpCast = MDep->getSource(); -  if (MDep->getSource()->getType() != ByValArg->getType()) -    TmpCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), -                              "tmpcast", CS.getInstruction()); +  if (MDep->getSource()->getType() != ByValArg->getType()) { +    BitCastInst *TmpBitCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), +                                              "tmpcast", &CB); +    // Set the tmpcast's DebugLoc to MDep's +    TmpBitCast->setDebugLoc(MDep->getDebugLoc()); +    TmpCast = TmpBitCast; +  }    LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n"                      << "  " << *MDep << "\n" -                    << "  " << *CS.getInstruction() << "\n"); +                    << "  " << CB << "\n");    // Otherwise we're good!  Update the byval argument. -  CS.setArgument(ArgNo, TmpCast); +  CB.setArgOperand(ArgNo, TmpCast);    ++NumMemCpyInstr;    return true;  } @@ -1347,13 +1323,13 @@ bool MemCpyOptPass::iterateOnFunction(Function &F) {        else if (MemSetInst *M = dyn_cast<MemSetInst>(I))          RepeatInstruction = processMemSet(M, BI);        else if (MemCpyInst *M = dyn_cast<MemCpyInst>(I)) -        RepeatInstruction = processMemCpy(M); +        RepeatInstruction = processMemCpy(M, BI);        else if (MemMoveInst *M = dyn_cast<MemMoveInst>(I))          RepeatInstruction = processMemMove(M); -      else if (auto CS = CallSite(I)) { -        for (unsigned i = 0, e = CS.arg_size(); i != e; ++i) -          if (CS.isByValArgument(i)) -            MadeChange |= processByValArgument(CS, i); +      else if (auto *CB = dyn_cast<CallBase>(I)) { +        for (unsigned i = 0, e = CB->arg_size(); i != e; ++i) +          if (CB->isByValArgument(i)) +            MadeChange |= processByValArgument(*CB, i);        }        // Reprocess the instruction if desired. diff --git a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6b0d0202d9bb..69aa0cebe170 100644 --- a/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -354,15 +354,11 @@ bool MergedLoadStoreMotion::run(Function &F, AliasAnalysis &AA) {    // optimization opportunities.    // This loop doesn't care about newly inserted/split blocks     // since they never will be diamond heads. 
-  for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE;) { -    BasicBlock *BB = &*FI++; - +  for (BasicBlock &BB : make_early_inc_range(F))      // Hoist equivalent loads and sink stores      // outside diamonds when possible -    if (isDiamondHead(BB)) { -      Changed |= mergeStores(BB); -    } -  } +    if (isDiamondHead(&BB)) +      Changed |= mergeStores(&BB);    return Changed;  } diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index bba9082e31b2..4e010f8704d0 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -213,7 +213,7 @@ bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_,    return Changed;  } -// Whitelist the instruction types NaryReassociate handles for now. +// Explicitly list the instruction types NaryReassociate handles for now.  static bool isPotentiallyNaryReassociable(Instruction *I) {    switch (I->getOpcode()) {    case Instruction::Add: diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp index 6a643480f312..0ed1773373a7 100644 --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -106,6 +106,7 @@  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Scalar/GVNExpression.h" +#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"  #include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/PredicateInfo.h"  #include "llvm/Transforms/Utils/VNCoercion.h" @@ -495,6 +496,7 @@ class NewGVN {    AliasAnalysis *AA = nullptr;    MemorySSA *MSSA = nullptr;    MemorySSAWalker *MSSAWalker = nullptr; +  AssumptionCache *AC = nullptr;    const DataLayout &DL;    std::unique_ptr<PredicateInfo> PredInfo; @@ -658,7 +660,7 @@ public:    NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC,           TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA,           const DataLayout &DL) -      : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), +      : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), AC(AC), DL(DL),          PredInfo(std::make_unique<PredicateInfo>(F, *DT, *AC)),          SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {} @@ -898,7 +900,7 @@ bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const {  #ifndef NDEBUG  static std::string getBlockName(const BasicBlock *B) { -  return DOTGraphTraits<const Function *>::getSimpleNodeLabel(B, nullptr); +  return DOTGraphTraits<DOTFuncInfo *>::getSimpleNodeLabel(B, nullptr);  }  #endif @@ -1334,8 +1336,6 @@ LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp,    // Give store and loads same opcode so they value number together.    E->setOpcode(0);    E->op_push_back(PointerOp); -  if (LI) -    E->setAlignment(MaybeAlign(LI->getAlignment()));    // TODO: Value number heap versions. We may be able to discover    // things alias analysis can't on it's own (IE that a store and a @@ -1470,7 +1470,8 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr,    // undef value.  This can happen when loading for a fresh allocation with no    // intervening stores, for example.  Note that this is only true in the case    // that the result of the allocation is pointer equal to the load ptr. 
-  if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) { +  if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || +      isAlignedAllocLikeFn(DepInst, TLI)) {      return createConstantExpression(UndefValue::get(LoadType));    }    // If this load occurs either right after a lifetime begin, @@ -2030,10 +2031,12 @@ NewGVN::performSymbolicEvaluation(Value *V,      case Instruction::Select:      case Instruction::ExtractElement:      case Instruction::InsertElement: -    case Instruction::ShuffleVector:      case Instruction::GetElementPtr:        E = createExpression(I);        break; +    case Instruction::ShuffleVector: +      // FIXME: Add support for shufflevector to createExpression. +      return nullptr;      default:        return nullptr;      } @@ -3433,7 +3436,7 @@ bool NewGVN::runGVN() {    // Sort dominator tree children arrays into RPO.    for (auto &B : RPOT) {      auto *Node = DT->getNode(B); -    if (Node->getChildren().size() > 1) +    if (Node->getNumChildren() > 1)        llvm::sort(Node->begin(), Node->end(),                   [&](const DomTreeNode *A, const DomTreeNode *B) {                     return RPOOrdering[A] < RPOOrdering[B]; @@ -3693,6 +3696,7 @@ void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) {        Inst.replaceAllUsesWith(UndefValue::get(Inst.getType()));      if (isa<LandingPadInst>(Inst))        continue; +    salvageKnowledge(&Inst, AC);      Inst.eraseFromParent();      ++NumGVNInstrDeleted; diff --git a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp index 5c4a89977c38..4553b23532f2 100644 --- a/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -189,7 +189,8 @@ static bool needsStatepoint(CallBase *Call, const TargetLibraryInfo &TLI) {        return false;    } -  return !(isStatepoint(Call) || isGCRelocate(Call) || isGCResult(Call)); +  return !(isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) || +           isa<GCResultInst>(Call));  }  /// Returns true if this loop is known to contain a call safepoint which @@ -650,7 +651,7 @@ InsertSafepointPoll(Instruction *InsertBefore,    // Do the actual inlining    InlineFunctionInfo IFI; -  bool InlineStatus = InlineFunction(PollCall, IFI); +  bool InlineStatus = InlineFunction(*PollCall, IFI).isSuccess();    assert(InlineStatus && "inline must succeed");    (void)InlineStatus; // suppress warning in release-asserts diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp index 41940e980faa..ba7f367267fe 100644 --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -29,6 +29,7 @@  #include "llvm/ADT/SmallSet.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BasicAliasAnalysis.h"  #include "llvm/Analysis/GlobalsModRef.h"  #include "llvm/Analysis/ValueTracking.h"  #include "llvm/IR/Argument.h" @@ -254,15 +255,15 @@ static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name,    }  } -static BinaryOperator *CreateNeg(Value *S1, const Twine &Name, -                                 Instruction *InsertBefore, Value *FlagsOp) { +static Instruction *CreateNeg(Value *S1, const Twine &Name, +                              Instruction *InsertBefore, Value *FlagsOp) {    if (S1->getType()->isIntOrIntVectorTy())      return BinaryOperator::CreateNeg(S1, Name, InsertBefore); -  else { -    BinaryOperator *Res = BinaryOperator::CreateFNeg(S1, 
Name, InsertBefore); -    Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); -    return Res; -  } + +  if (auto *FMFSource = dyn_cast<Instruction>(FlagsOp)) +    return UnaryOperator::CreateFNegFMF(S1, FMFSource, Name, InsertBefore); + +  return UnaryOperator::CreateFNeg(S1, Name, InsertBefore);  }  /// Replace 0-X with X*-1. @@ -914,7 +915,7 @@ static Value *NegateValue(Value *V, Instruction *BI,    // Insert a 'neg' instruction that subtracts the value from zero to get the    // negation. -  BinaryOperator *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI); +  Instruction *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI);    ToRedo.insert(NewNeg);    return NewNeg;  } @@ -975,7 +976,8 @@ static BinaryOperator *BreakUpSubtract(Instruction *Sub,  /// this into a multiply by a constant to assist with further reassociation.  static BinaryOperator *ConvertShiftToMul(Instruction *Shl) {    Constant *MulCst = ConstantInt::get(Shl->getType(), 1); -  MulCst = ConstantExpr::getShl(MulCst, cast<Constant>(Shl->getOperand(1))); +  auto *SA = cast<ConstantInt>(Shl->getOperand(1)); +  MulCst = ConstantExpr::getShl(MulCst, SA);    BinaryOperator *Mul =      BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl); @@ -988,10 +990,12 @@ static BinaryOperator *ConvertShiftToMul(Instruction *Shl) {    // We can safely preserve the nuw flag in all cases.  It's also safe to turn a    // nuw nsw shl into a nuw nsw mul.  However, nsw in isolation requires special -  // handling. +  // handling.  It can be preserved as long as we're not left shifting by +  // bitwidth - 1.    bool NSW = cast<BinaryOperator>(Shl)->hasNoSignedWrap();    bool NUW = cast<BinaryOperator>(Shl)->hasNoUnsignedWrap(); -  if (NSW && NUW) +  unsigned BitWidth = Shl->getType()->getIntegerBitWidth(); +  if (NSW && (NUW || SA->getValue().ult(BitWidth - 1)))      Mul->setHasNoSignedWrap(true);    Mul->setHasNoUnsignedWrap(NUW);    return Mul; @@ -1076,7 +1080,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {          const APFloat &F1 = FC1->getValueAPF();          APFloat F2(FC2->getValueAPF());          F2.changeSign(); -        if (F1.compare(F2) == APFloat::cmpEqual) { +        if (F1 == F2) {            FoundFactor = NeedsNegate = true;            Factors.erase(Factors.begin() + i);            break; @@ -1721,7 +1725,7 @@ static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,  }  /// Build a tree of multiplies, computing the product of Ops. -static Value *buildMultiplyTree(IRBuilder<> &Builder, +static Value *buildMultiplyTree(IRBuilderBase &Builder,                                  SmallVectorImpl<Value*> &Ops) {    if (Ops.size() == 1)      return Ops.back(); @@ -1744,7 +1748,7 @@ static Value *buildMultiplyTree(IRBuilder<> &Builder,  /// DAG of multiplies to compute the final product, and return that product  /// value.  
Value *
-ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+ReassociatePass::buildMinimalMultiplyDAG(IRBuilderBase &Builder,
                                          SmallVectorImpl<Factor> &Factors) {
   assert(Factors[0].Power);
   SmallVector<Value *, 4> OuterProduct;
@@ -1899,7 +1903,7 @@ void ReassociatePass::RecursivelyEraseDeadInsts(Instruction *I,
   ValueRankMap.erase(I);
   Insts.remove(I);
   RedoInsts.remove(I);
-  llvm::salvageDebugInfoOrMarkUndef(*I);
+  llvm::salvageDebugInfo(*I);
   I->eraseFromParent();
   for (auto Op : Ops)
     if (Instruction *OpInst = dyn_cast<Instruction>(Op))
@@ -1916,7 +1920,7 @@ void ReassociatePass::EraseInst(Instruction *I) {
   // Erase the dead instruction.
   ValueRankMap.erase(I);
   RedoInsts.remove(I);
-  llvm::salvageDebugInfoOrMarkUndef(*I);
+  llvm::salvageDebugInfo(*I);
   I->eraseFromParent();
   // Optimize its operands.
   SmallPtrSet<Instruction *, 8> Visited; // Detect self-referential nodes.
@@ -2457,6 +2461,8 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) {
   if (MadeChange) {
     PreservedAnalyses PA;
     PA.preserveSet<CFGAnalyses>();
+    PA.preserve<AAManager>();
+    PA.preserve<BasicAA>();
     PA.preserve<GlobalsAA>();
     return PA;
   }
@@ -2487,6 +2493,8 @@ namespace {
     void getAnalysisUsage(AnalysisUsage &AU) const override {
       AU.setPreservesCFG();
+      AU.addPreserved<AAResultsWrapperPass>();
+      AU.addPreserved<BasicAAWrapperPass>();
       AU.addPreserved<GlobalsAAWrapperPass>();
     }
   };
diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
index b242f100faff..dc2ad14ae61e 100644
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -271,7 +271,7 @@ struct PartiallyConstructedSafepointRecord {
   /// The *new* gc.statepoint instruction itself.  This produces the token
   /// that normal path gc.relocates and the gc.result are tied to.
-  Instruction *StatepointToken;
+  GCStatepointInst *StatepointToken;

   /// Instruction to which exceptional gc relocates are attached
   /// Makes it easier to iterate through them during relocationViaAlloca.
@@ -381,14 +381,19 @@ static void analyzeParsePointLiveness(
       dbgs() << " " << V->getName() << " " << *V << "\n";
   }
   if (PrintLiveSetSize) {
-    dbgs() << "Safepoint For: " << Call->getCalledValue()->getName() << "\n";
+    dbgs() << "Safepoint For: " << Call->getCalledOperand()->getName() << "\n";
     dbgs() << "Number live values: " << LiveSet.size() << "\n";
   }
   Result.LiveSet = LiveSet;
 }

+// Returns true if V is a known BaseResult.
 static bool isKnownBaseResult(Value *V);

+// Returns true if V is a BaseResult that already exists in the IR, i.e. it is
+// not created by the findBasePointers algorithm.
+static bool isOriginalBaseResult(Value *V);
+
 namespace {

 /// A single base defining value - An immediate base defining value for an
@@ -633,15 +638,20 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
   return Def;
 }

+/// This value is a base pointer that is not generated by RS4GC, i.e. it already
+/// exists in the code.
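+/// (Phis, selects, extractelement, insertelement and shufflevector are the
+/// only instruction kinds the findBasePointer algorithm introduces as base
+/// values, so any other kind must already be original.)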
+static bool isOriginalBaseResult(Value *V) { +  // no recursion possible +  return !isa<PHINode>(V) && !isa<SelectInst>(V) && +         !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) && +         !isa<ShuffleVectorInst>(V); +} +  /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,  /// is it known to be a base pointer?  Or do we need to continue searching.  static bool isKnownBaseResult(Value *V) { -  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && -      !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) && -      !isa<ShuffleVectorInst>(V)) { -    // no recursion possible +  if (isOriginalBaseResult(V))      return true; -  }    if (isa<Instruction>(V) &&        cast<Instruction>(V)->getMetadata("is_base_value")) {      // This is a previously inserted base phi or select.  We know @@ -653,6 +663,12 @@ static bool isKnownBaseResult(Value *V) {    return false;  } +// Returns true if First and Second values are both scalar or both vector. +static bool areBothVectorOrScalar(Value *First, Value *Second) { +  return isa<VectorType>(First->getType()) == +         isa<VectorType>(Second->getType()); +} +  namespace {  /// Models the state of a single base defining value in the findBasePointer @@ -762,7 +778,7 @@ static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {  static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {    Value *Def = findBaseOrBDV(I, Cache); -  if (isKnownBaseResult(Def)) +  if (isKnownBaseResult(Def) && areBothVectorOrScalar(Def, I))      return Def;    // Here's the rough algorithm: @@ -810,13 +826,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {      States.insert({Def, BDVState()});      while (!Worklist.empty()) {        Value *Current = Worklist.pop_back_val(); -      assert(!isKnownBaseResult(Current) && "why did it get added?"); +      assert(!isOriginalBaseResult(Current) && "why did it get added?");        auto visitIncomingValue = [&](Value *InVal) {          Value *Base = findBaseOrBDV(InVal, Cache); -        if (isKnownBaseResult(Base)) +        if (isKnownBaseResult(Base) && areBothVectorOrScalar(Base, InVal))            // Known bases won't need new instructions introduced and can be -          // ignored safely +          // ignored safely. However, this can only be done when InVal and Base +          // are both scalar or both vector. Otherwise, we need to find a +          // correct BDV for InVal, by creating an entry in the lattice +          // (States).            return;          assert(isExpectedBDVType(Base) && "the only non-base values "                 "we see should be base defining values"); @@ -853,10 +872,10 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {    // Return a phi state for a base defining value.  We'll generate a new    // base state for known bases and expect to find a cached state otherwise. -  auto getStateForBDV = [&](Value *baseValue) { -    if (isKnownBaseResult(baseValue)) -      return BDVState(baseValue); -    auto I = States.find(baseValue); +  auto GetStateForBDV = [&](Value *BaseValue, Value *Input) { +    if (isKnownBaseResult(BaseValue) && areBothVectorOrScalar(BaseValue, Input)) +      return BDVState(BaseValue); +    auto I = States.find(BaseValue);      assert(I != States.end() && "lookup failed!");      return I->second;    }; @@ -873,13 +892,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {      // much faster.      
for (auto Pair : States) {
       Value *BDV = Pair.first;
-      assert(!isKnownBaseResult(BDV) && "why did it get added?");
+      // Only values that do not have known bases or those that have differing
+      // type (scalar versus vector) from a possible known base should be in the
+      // lattice.
+      assert((!isKnownBaseResult(BDV) ||
+             !areBothVectorOrScalar(BDV, Pair.second.getBaseValue())) &&
+                 "why did it get added?");

       // Given an input value for the current instruction, return a BDVState
       // instance which represents the BDV of that value.
       auto getStateForInput = [&](Value *V) mutable {
         Value *BDV = findBaseOrBDV(V, Cache);
-        return getStateForBDV(BDV);
+        return GetStateForBDV(BDV, V);
       };

       BDVState NewState;
@@ -926,20 +950,26 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
   }
 #endif

-  // Insert Phis for all conflicts
-  // TODO: adjust naming patterns to avoid this order of iteration dependency
+  // Handle all instructions that have a vector BDV, but the instruction itself
+  // is of scalar type.
   for (auto Pair : States) {
     Instruction *I = cast<Instruction>(Pair.first);
     BDVState State = Pair.second;
-    assert(!isKnownBaseResult(I) && "why did it get added?");
+    auto *BaseValue = State.getBaseValue();
+    // Only values that do not have known bases or those that have differing
+    // type (scalar versus vector) from a possible known base should be in the
+    // lattice.
+    assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, BaseValue)) &&
+           "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");

+    if (!State.isBase() || !isa<VectorType>(BaseValue->getType()))
+      continue;
     // extractelement instructions are a bit special in that we may need to
     // insert an extract even when we know an exact base for the instruction.
     // The problem is that we need to convert from a vector base to a scalar
     // base for the particular index we're interested in.
-    if (State.isBase() && isa<ExtractElementInst>(I) &&
-        isa<VectorType>(State.getBaseValue()->getType())) {
+    if (isa<ExtractElementInst>(I)) {
       auto *EE = cast<ExtractElementInst>(I);
       // TODO: In many cases, the new instruction is just EE itself.  We should
       // exploit this, but can't do it here since it would break the invariant
@@ -948,7 +978,27 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
           State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
       BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
       States[I] = BDVState(BDVState::Base, BaseInst);
+    } else if (!isa<VectorType>(I->getType())) {
+      // We need to handle cases that have a vector base but the instruction is
+      // of scalar type (these could be phis or selects or any instructions
+      // that are of scalar type, but the base can be a vector type).  We
+      // conservatively set this as a conflict.  Setting the base value for
+      // these conflicts is handled in the next loop which traverses States.
+      States[I] = BDVState(BDVState::Conflict);      } +  } + +  // Insert Phis for all conflicts +  // TODO: adjust naming patterns to avoid this order of iteration dependency +  for (auto Pair : States) { +    Instruction *I = cast<Instruction>(Pair.first); +    BDVState State = Pair.second; +    // Only values that do not have known bases or those that have differing +    // type (scalar versus vector) from a possible known base should be in the +    // lattice. +    assert((!isKnownBaseResult(I) || !areBothVectorOrScalar(I, State.getBaseValue())) && +           "why did it get added?"); +    assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");      // Since we're joining a vector and scalar base, they can never be the      // same.  As a result, we should always see insert element having reached @@ -987,7 +1037,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {          auto *SV = cast<ShuffleVectorInst>(I);          UndefValue *VecUndef = UndefValue::get(SV->getOperand(0)->getType());          std::string Name = suffixed_name_or(I, ".base", "base_sv"); -        return new ShuffleVectorInst(VecUndef, VecUndef, SV->getOperand(2), +        return new ShuffleVectorInst(VecUndef, VecUndef, SV->getShuffleMask(),                                       Name, SV);        }      }; @@ -1008,7 +1058,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {    auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) {      Value *BDV = findBaseOrBDV(Input, Cache);      Value *Base = nullptr; -    if (isKnownBaseResult(BDV)) { +    if (isKnownBaseResult(BDV) && areBothVectorOrScalar(BDV, Input)) {        Base = BDV;      } else {        // Either conflict or base. @@ -1029,7 +1079,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {      Instruction *BDV = cast<Instruction>(Pair.first);      BDVState State = Pair.second; -    assert(!isKnownBaseResult(BDV) && "why did it get added?"); +    // Only values that do not have known bases or those that have differing +    // type (scalar versus vector) from a possible known base should be in the +    // lattice. +    assert((!isKnownBaseResult(BDV) || +            !areBothVectorOrScalar(BDV, State.getBaseValue())) && +           "why did it get added?");      assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");      if (!State.isConflict())        continue; @@ -1119,7 +1174,11 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {      auto *BDV = Pair.first;      Value *Base = Pair.second.getBaseValue();      assert(BDV && Base); -    assert(!isKnownBaseResult(BDV) && "why did it get added?"); +    // Only values that do not have known bases or those that have differing +    // type (scalar versus vector) from a possible known base should be in the +    // lattice. +    assert((!isKnownBaseResult(BDV) || !areBothVectorOrScalar(BDV, Base)) && +           "why did it get added?");      LLVM_DEBUG(          dbgs() << "Updating base value cache" @@ -1238,7 +1297,8 @@ normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent,  // Create new attribute set containing only attributes which can be transferred  // from original call to the safepoint. 
-static AttributeList legalizeCallAttributes(AttributeList AL) {
+static AttributeList legalizeCallAttributes(LLVMContext &Ctx,
+                                            AttributeList AL) {
   if (AL.isEmpty())
     return AL;
@@ -1252,7 +1312,6 @@ static AttributeList legalizeCallAttributes(AttributeList AL) {
   }

   // Just skip parameter and return attributes for now
-  LLVMContext &Ctx = AL.getContext();
   return AttributeList::get(Ctx, AttributeList::FunctionIndex,
                             AttributeSet::get(Ctx, FnAttrs));
 }
@@ -1261,16 +1320,14 @@ static AttributeList legalizeCallAttributes(AttributeList AL) {
 /// statepoint.
 /// Inputs:
 ///   liveVariables - list of variables to be relocated.
-///   liveStart - index of the first live variable.
 ///   basePtrs - base pointers.
 ///   statepointToken - statepoint instruction to which relocates should be
 ///   bound.
 ///   Builder - LLVM IR builder to be used to construct new calls.
 static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
-                              const int LiveStart,
                               ArrayRef<Value *> BasePtrs,
                               Instruction *StatepointToken,
-                              IRBuilder<> Builder) {
+                              IRBuilder<> &Builder) {
   if (LiveVariables.empty())
     return;
@@ -1295,7 +1352,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
     auto AS = Ty->getScalarType()->getPointerAddressSpace();
     Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS);
     if (auto *VT = dyn_cast<VectorType>(Ty))
-      NewTy = VectorType::get(NewTy, VT->getNumElements());
+      NewTy = FixedVectorType::get(NewTy,
+                                   cast<FixedVectorType>(VT)->getNumElements());
     return Intrinsic::getDeclaration(M, Intrinsic::experimental_gc_relocate,
                                      {NewTy});
   };
@@ -1307,9 +1365,8 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
   for (unsigned i = 0; i < LiveVariables.size(); i++) {
     // Generate the gc.relocate call and save the result
-    Value *BaseIdx =
-      Builder.getInt32(LiveStart + FindIndex(LiveVariables, BasePtrs[i]));
-    Value *LiveIdx = Builder.getInt32(LiveStart + i);
+    Value *BaseIdx = Builder.getInt32(FindIndex(LiveVariables, BasePtrs[i]));
+    Value *LiveIdx = Builder.getInt32(i);

     Type *Ty = LiveVariables[i]->getType();
     if (!TypeToDeclMap.count(Ty))
assert(DeoptLowering.equals("live-through") && "Unsupported value!");
   }

-  Value *CallTarget = Call->getCalledValue();
+  Value *CallTarget = Call->getCalledOperand();
   if (Function *F = dyn_cast<Function>(CallTarget)) {
     if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) {
       // Calls to llvm.experimental.deoptimize are lowered to calls to the
@@ -1485,7 +1544,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
   }

   // Create the statepoint given all the arguments
-  Instruction *Token = nullptr;
+  GCStatepointInst *Token = nullptr;

   if (auto *CI = dyn_cast<CallInst>(Call)) {
     CallInst *SPCall = Builder.CreateGCStatepointCall(
         StatepointID, NumPatchBytes, CallTarget, Flags, CallArgs,
@@ -1498,9 +1557,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
     // function attributes.  In case we can handle this set of attributes -
     // set up function attrs directly on statepoint and return attrs later for
     // gc_result intrinsic.
-    SPCall->setAttributes(legalizeCallAttributes(CI->getAttributes()));
+    SPCall->setAttributes(
+        legalizeCallAttributes(CI->getContext(), CI->getAttributes()));

-    Token = SPCall;
+    Token = cast<GCStatepointInst>(SPCall);

     // Put the following gc_result and gc_relocate calls immediately after
     // the old call (which we're about to delete)
@@ -1524,9 +1584,10 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
     // function attributes.  In case we can handle this set of attributes -
     // set up function attrs directly on statepoint and return attrs later for
     // gc_result intrinsic.
-    SPInvoke->setAttributes(legalizeCallAttributes(II->getAttributes()));
+    SPInvoke->setAttributes(
+        legalizeCallAttributes(II->getContext(), II->getAttributes()));

-    Token = SPInvoke;
+    Token = cast<GCStatepointInst>(SPInvoke);

     // Generate gc relocates in exceptional path
     BasicBlock *UnwindBlock = II->getUnwindDest();
@@ -1541,9 +1602,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
     Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst();
     Result.UnwindToken = ExceptionalToken;

-    const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx();
-    CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, ExceptionalToken,
-                      Builder);
+    CreateGCRelocates(LiveVariables, BasePtrs, ExceptionalToken, Builder);

     // Generate gc relocates and returns for normal block
     BasicBlock *NormalDest = II->getNormalDest();
@@ -1589,8 +1648,7 @@ makeStatepointExplicitImpl(CallBase *Call, /* to replace */
   Result.StatepointToken = Token;

   // Second, create a gc.relocate for every live variable
-  const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx();
-  CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, Token, Builder);
+  CreateGCRelocates(LiveVariables, BasePtrs, Token, Builder);
 }

 // Replace an existing gc.statepoint with a new one and a set of gc.relocates
@@ -1651,8 +1709,8 @@ insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs,
                             cast<AllocaInst>(Alloca)->getAllocatedType(),
                             suffixed_name_or(Relocate, ".casted", ""));

-    StoreInst *Store = new StoreInst(CastedRelocatedValue, Alloca);
-    Store->insertAfter(cast<Instruction>(CastedRelocatedValue));
+    new StoreInst(CastedRelocatedValue, Alloca,
+                  cast<Instruction>(CastedRelocatedValue)->getNextNode());
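The store rewrites above lean on the StoreInst constructor that takes an insert-before position: passing an instruction's getNextNode() inserts the new store immediately after it, replacing the two-step construct-then-insertAfter pattern. A sketch of that idiom, assuming LLVM's C++ API of roughly this vintage; the module and function names here are made up for illustration:

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("demo", Ctx); // hypothetical demo module
  Function *F = Function::Create(
      FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false),
      Function::ExternalLinkage, "f", &M);
  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
  IRBuilder<> B(BB);
  Value *Slot = B.CreateAlloca(B.getInt32Ty());
  Instruction *Load = B.CreateLoad(B.getInt32Ty(), Slot);
  B.CreateRetVoid();

  // Insert a store *after* Load by inserting it *before* Load's successor,
  // the same idiom the diff uses instead of insertAfter().
  new StoreInst(Load, Slot, /*InsertBefore=*/Load->getNextNode());

  M.print(outs(), nullptr);
}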
#ifndef NDEBUG
     VisitedLiveValues.insert(OriginalValue);
@@ -1674,8 +1732,8 @@ static void insertRematerializationStores(
            "Cannot find alloca for rematerialized value");
     Value *Alloca = AllocaMap[OriginalValue];

-    StoreInst *Store = new StoreInst(RematerializedValue, Alloca);
-    Store->insertAfter(RematerializedValue);
+    new StoreInst(RematerializedValue, Alloca,
+                  RematerializedValue->getNextNode());

 #ifndef NDEBUG
     VisitedLiveValues.insert(OriginalValue);
@@ -1780,8 +1838,7 @@ static void relocationViaAlloca(
         for (auto *AI : ToClobber) {
           auto PT = cast<PointerType>(AI->getAllocatedType());
           Constant *CPN = ConstantPointerNull::get(PT);
-          StoreInst *Store = new StoreInst(CPN, AI);
-          Store->insertBefore(IP);
+          new StoreInst(CPN, AI, IP);
         }
       };
@@ -1843,7 +1900,8 @@ static void relocationViaAlloca(
     // Emit store for the initial gc value.  Store must be inserted after load,
     // otherwise store will be in alloca's use list and an extra load will be
     // inserted before it.
-    StoreInst *Store = new StoreInst(Def, Alloca);
+    StoreInst *Store = new StoreInst(Def, Alloca, /*volatile*/ false,
+                                     DL.getABITypeAlign(Def->getType()));
     if (Instruction *Inst = dyn_cast<Instruction>(Def)) {
       if (InvokeInst *Invoke = dyn_cast<InvokeInst>(Inst)) {
         // InvokeInst is a terminator so the store needs to be inserted into its
@@ -1966,7 +2024,9 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
              "non noop cast is found during rematerialization");

       Type *SrcTy = CI->getOperand(0)->getType();
-      Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI);
+      Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy,
+                                   TargetTransformInfo::TCK_SizeAndLatency,
+                                   CI);

     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
       // Cost of the address calculation
@@ -2344,9 +2404,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
     // That Value* no longer exists and we need to use the new gc_result.
     // Thankfully, the live set is embedded in the statepoint (and updated), so
     // we just grab that.
-    Statepoint Statepoint(Info.StatepointToken);
-    Live.insert(Live.end(), Statepoint.gc_args_begin(),
-                Statepoint.gc_args_end());
+    Live.insert(Live.end(), Info.StatepointToken->gc_args_begin(),
+                Info.StatepointToken->gc_args_end());

 #ifndef NDEBUG
     // Do some basic sanity checks on our liveness results before performing
    // relocation.  Relocation can and will turn mistakes in liveness results
@@ -2354,7 +2413,7 @@ static bool insertParsePoints(Function &F, DominatorTree &DT,
     // TODO: It would be nice to test consistency as well
     assert(DT.isReachableFromEntry(Info.StatepointToken->getParent()) &&
            "statepoint must be reachable or liveness is meaningless");
-    for (Value *V : Statepoint.gc_args()) {
+    for (Value *V : Info.StatepointToken->gc_args()) {
       if (!isa<Instruction>(V))
         // Non-instruction values trivially dominate all possible uses
         continue;
@@ -2523,7 +2582,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
   auto NeedsRewrite = [&TLI](Instruction &I) {
     if (const auto *Call = dyn_cast<CallBase>(&I))
-      return !callsGCLeafFunction(Call, TLI) && !isStatepoint(Call);
+      return !callsGCLeafFunction(Call, TLI) && !isa<GCStatepointInst>(Call);
     return false;
   };
@@ -2608,10 +2667,10 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
     unsigned VF = 0;
     for (unsigned i = 0; i < I.getNumOperands(); i++)
-      if (I.getOperand(i)->getType()->isVectorTy()) {
+      if (auto *OpndVTy = dyn_cast<VectorType>(I.getOperand(i)->getType())) {
         assert(VF == 0 ||
-               VF == I.getOperand(i)->getType()->getVectorNumElements());
-        VF = I.getOperand(i)->getType()->getVectorNumElements();
+               VF == cast<FixedVectorType>(OpndVTy)->getNumElements());
+        VF = cast<FixedVectorType>(OpndVTy)->getNumElements();
       }

     // It's the vector to scalar traversal through the pointer operand which
diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp
index e696ea83a300..5ebd3b71fe78 100644
--- a/llvm/lib/Transforms/Scalar/SCCP.cpp
+++ b/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -27,12 +27,13 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueLattice.h"
 #include "llvm/Analysis/ValueLatticeUtils.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -67,123 +68,44 @@ using namespace llvm;

 STATISTIC(NumInstRemoved, "Number of instructions removed");
 STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable");
+STATISTIC(NumInstReplaced,
+          "Number of instructions replaced with (simpler) instruction");

 STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP");
 STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP");
 STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP");
-
+STATISTIC(
+    IPNumInstReplaced,
+    "Number of instructions replaced with (simpler) instruction by IPSCCP");
+
+// The maximum number of range extensions allowed for operations requiring
+// widening.
+static const unsigned MaxNumRangeExtensions = 10;
+
+/// Returns MergeOptions with MaxWidenSteps set to MaxNumRangeExtensions.
+static ValueLatticeElement::MergeOptions getMaxWidenStepsOpts() {
+  return ValueLatticeElement::MergeOptions().setMaxWidenSteps(
+      MaxNumRangeExtensions);
+}

 namespace {

-/// LatticeVal class - This class represents the different lattice values that
-/// an LLVM value may occupy.  
It is a simple class with value semantics. -/// -class LatticeVal { -  enum LatticeValueTy { -    /// unknown - This LLVM Value has no known value yet. -    unknown, - -    /// constant - This LLVM Value has a specific constant value. -    constant, - -    /// forcedconstant - This LLVM Value was thought to be undef until -    /// ResolvedUndefsIn.  This is treated just like 'constant', but if merged -    /// with another (different) constant, it goes to overdefined, instead of -    /// asserting. -    forcedconstant, - -    /// overdefined - This instruction is not known to be constant, and we know -    /// it has a value. -    overdefined -  }; - -  /// Val: This stores the current lattice value along with the Constant* for -  /// the constant if this is a 'constant' or 'forcedconstant' value. -  PointerIntPair<Constant *, 2, LatticeValueTy> Val; - -  LatticeValueTy getLatticeValue() const { -    return Val.getInt(); -  } - -public: -  LatticeVal() : Val(nullptr, unknown) {} - -  bool isUnknown() const { return getLatticeValue() == unknown; } - -  bool isConstant() const { -    return getLatticeValue() == constant || getLatticeValue() == forcedconstant; -  } - -  bool isOverdefined() const { return getLatticeValue() == overdefined; } - -  Constant *getConstant() const { -    assert(isConstant() && "Cannot get the constant of a non-constant!"); -    return Val.getPointer(); -  } - -  /// markOverdefined - Return true if this is a change in status. -  bool markOverdefined() { -    if (isOverdefined()) -      return false; - -    Val.setInt(overdefined); -    return true; -  } - -  /// markConstant - Return true if this is a change in status. -  bool markConstant(Constant *V) { -    if (getLatticeValue() == constant) { // Constant but not forcedconstant. -      assert(getConstant() == V && "Marking constant with different value"); -      return false; -    } - -    if (isUnknown()) { -      Val.setInt(constant); -      assert(V && "Marking constant with NULL"); -      Val.setPointer(V); -    } else { -      assert(getLatticeValue() == forcedconstant && -             "Cannot move from overdefined to constant!"); -      // Stay at forcedconstant if the constant is the same. -      if (V == getConstant()) return false; - -      // Otherwise, we go to overdefined.  Assumptions made based on the -      // forced value are possibly wrong.  Assuming this is another constant -      // could expose a contradiction. -      Val.setInt(overdefined); -    } -    return true; -  } - -  /// getConstantInt - If this is a constant with a ConstantInt value, return it -  /// otherwise return null. -  ConstantInt *getConstantInt() const { -    if (isConstant()) -      return dyn_cast<ConstantInt>(getConstant()); -    return nullptr; -  } - -  /// getBlockAddress - If this is a constant with a BlockAddress value, return -  /// it, otherwise return null. -  BlockAddress *getBlockAddress() const { -    if (isConstant()) -      return dyn_cast<BlockAddress>(getConstant()); -    return nullptr; -  } - -  void markForcedConstant(Constant *V) { -    assert(isUnknown() && "Can't force a defined value!"); -    Val.setInt(forcedconstant); -    Val.setPointer(V); -  } +// Helper to check if \p LV is either a constant or a constant +// range with a single element. This should cover exactly the same cases as the +// old ValueLatticeElement::isConstant() and is intended to be used in the +// transition to ValueLatticeElement. 
+bool isConstant(const ValueLatticeElement &LV) {
+  return LV.isConstant() ||
+         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
+}

-  ValueLatticeElement toValueLattice() const {
-    if (isOverdefined())
-      return ValueLatticeElement::getOverdefined();
-    if (isConstant())
-      return ValueLatticeElement::get(getConstant());
-    return ValueLatticeElement();
-  }
-};
+// Helper to check if \p LV is either overdefined or a constant range with more
+// than a single element. This should cover exactly the same cases as the old
+// ValueLatticeElement::isOverdefined() and is intended to be used in the
+// transition to ValueLatticeElement.
+bool isOverdefined(const ValueLatticeElement &LV) {
+  return LV.isOverdefined() ||
+         (LV.isConstantRange() && !LV.getConstantRange().isSingleElement());
+}

 //===----------------------------------------------------------------------===//
 //
@@ -194,28 +116,28 @@ class SCCPSolver : public InstVisitor<SCCPSolver> {
   const DataLayout &DL;
   std::function<const TargetLibraryInfo &(Function &)> GetTLI;
   SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable.
-  DenseMap<Value *, LatticeVal> ValueState;  // The state each value is in.
-  // The state each parameter is in.
-  DenseMap<Value *, ValueLatticeElement> ParamState;
+  DenseMap<Value *, ValueLatticeElement>
+      ValueState; // The state each value is in.

   /// StructValueState - This maintains ValueState for values that have
   /// StructType, for example for formal arguments, calls, insertelement, etc.
-  DenseMap<std::pair<Value *, unsigned>, LatticeVal> StructValueState;
+  DenseMap<std::pair<Value *, unsigned>, ValueLatticeElement> StructValueState;

   /// GlobalValue - If we are tracking any values for the contents of a global
   /// variable, we keep a mapping from the constant accessor to the element of
   /// the global, to the currently known value.  If the value becomes
   /// overdefined, its entry is simply removed from this map.
-  DenseMap<GlobalVariable *, LatticeVal> TrackedGlobals;
+  DenseMap<GlobalVariable *, ValueLatticeElement> TrackedGlobals;

   /// TrackedRetVals - If we are tracking arguments into and the return
   /// value out of a function, it will have an entry in this map, indicating
   /// what the known return value for the function is.
-  MapVector<Function *, LatticeVal> TrackedRetVals;
+  MapVector<Function *, ValueLatticeElement> TrackedRetVals;

   /// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions
   /// that return multiple values.
-  MapVector<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals;
+  MapVector<std::pair<Function *, unsigned>, ValueLatticeElement>
+      TrackedMultipleRetVals;

   /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is
   /// represented here for efficient lookup.
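Why treat single-element ranges as constants? During the LatticeVal-to-ValueLatticeElement transition, a range like [42, 42] carries exactly the information the old `constant` state did. A self-contained toy model of the two helpers above, with hypothetical simplified types rather than the real ValueLatticeElement:

#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Toy stand-in for ValueLatticeElement (hypothetical; not LLVM's class).
struct Lattice {
  bool Overdefined = false;
  std::optional<int64_t> Constant;                  // known single value
  std::optional<std::pair<int64_t, int64_t>> Range; // [Lo, Hi], inclusive

  // Mirrors the isConstant() helper: a range that collapsed to one element
  // is as good as a constant.
  bool isConstant() const {
    return Constant.has_value() || (Range && Range->first == Range->second);
  }

  // Mirrors the isOverdefined() helper: a multi-element range still means
  // "more than one possible value".
  bool isOverdefined() const {
    return Overdefined || (Range && Range->first != Range->second);
  }
};

int main() {
  Lattice L;
  L.Range.emplace(42, 42); // range with a single element
  assert(L.isConstant() && !L.isOverdefined());
  L.Range.emplace(0, 10);  // widened range
  assert(!L.isConstant() && L.isOverdefined());
}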
@@ -251,6 +173,8 @@ class SCCPSolver : public InstVisitor<SCCPSolver> {
   DenseMap<Function *, AnalysisResultsForFn> AnalysisResults;
   DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers;

+  LLVMContext &Ctx;
+
 public:
   void addAnalysis(Function &F, AnalysisResultsForFn A) {
     AnalysisResults.insert({&F, std::move(A)});
@@ -270,8 +194,9 @@ public:
   }

   SCCPSolver(const DataLayout &DL,
-             std::function<const TargetLibraryInfo &(Function &)> GetTLI)
-      : DL(DL), GetTLI(std::move(GetTLI)) {}
+             std::function<const TargetLibraryInfo &(Function &)> GetTLI,
+             LLVMContext &Ctx)
+      : DL(DL), GetTLI(std::move(GetTLI)), Ctx(Ctx) {}

   /// MarkBlockExecutable - This method can be used by clients to mark all of
   /// the blocks that are known to be intrinsically live in the processed unit.
@@ -292,7 +217,7 @@ public:
   void TrackValueOfGlobalVariable(GlobalVariable *GV) {
     // We only track the contents of scalar globals.
     if (GV->getValueType()->isSingleValueType()) {
-      LatticeVal &IV = TrackedGlobals[GV];
+      ValueLatticeElement &IV = TrackedGlobals[GV];
       if (!isa<UndefValue>(GV->getInitializer()))
         IV.markConstant(GV->getInitializer());
     }
@@ -306,10 +231,10 @@ public:
     if (auto *STy = dyn_cast<StructType>(F->getReturnType())) {
       MRVFunctionsTracked.insert(F);
       for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
-        TrackedMultipleRetVals.insert(std::make_pair(std::make_pair(F, i),
-                                                     LatticeVal()));
+        TrackedMultipleRetVals.insert(
+            std::make_pair(std::make_pair(F, i), ValueLatticeElement()));
     } else
-      TrackedRetVals.insert(std::make_pair(F, LatticeVal()));
+      TrackedRetVals.insert(std::make_pair(F, ValueLatticeElement()));
   }

   /// AddMustTailCallee - If the SCCP solver finds that this function is called
@@ -352,8 +277,8 @@ public:
   // block to the 'To' basic block is currently feasible.
   bool isEdgeFeasible(BasicBlock *From, BasicBlock *To);

-  std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const {
-    std::vector<LatticeVal> StructValues;
+  std::vector<ValueLatticeElement> getStructLatticeValueFor(Value *V) const {
+    std::vector<ValueLatticeElement> StructValues;
     auto *STy = dyn_cast<StructType>(V->getType());
     assert(STy && "getStructLatticeValueFor() can be called only on structs");
     for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
@@ -364,23 +289,26 @@ public:
     return StructValues;
   }

-  const LatticeVal &getLatticeValueFor(Value *V) const {
+  void removeLatticeValueFor(Value *V) { ValueState.erase(V); }
+
+  const ValueLatticeElement &getLatticeValueFor(Value *V) const {
     assert(!V->getType()->isStructTy() &&
            "Should use getStructLatticeValueFor");
-    DenseMap<Value *, LatticeVal>::const_iterator I = ValueState.find(V);
+    DenseMap<Value *, ValueLatticeElement>::const_iterator I =
+        ValueState.find(V);
     assert(I != ValueState.end() &&
           "V not found in ValueState nor ParamState map!");
     return I->second;
   }

   /// getTrackedRetVals - Get the inferred return value map.
-  const MapVector<Function*, LatticeVal> &getTrackedRetVals() {
+  const MapVector<Function *, ValueLatticeElement> &getTrackedRetVals() {
     return TrackedRetVals;
   }

   /// getTrackedGlobals - Get and return the set of inferred initializers for
   /// global variables.
-  const DenseMap<GlobalVariable*, LatticeVal> &getTrackedGlobals() { +  const DenseMap<GlobalVariable *, ValueLatticeElement> &getTrackedGlobals() {      return TrackedGlobals;    } @@ -407,32 +335,59 @@ public:    }    // isStructLatticeConstant - Return true if all the lattice values -  // corresponding to elements of the structure are not overdefined, +  // corresponding to elements of the structure are constants,    // false otherwise.    bool isStructLatticeConstant(Function *F, StructType *STy) {      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {        const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i));        assert(It != TrackedMultipleRetVals.end()); -      LatticeVal LV = It->second; -      if (LV.isOverdefined()) +      ValueLatticeElement LV = It->second; +      if (!isConstant(LV))          return false;      }      return true;    } +  /// Helper to return a Constant if \p LV is either a constant or a constant +  /// range with a single element. +  Constant *getConstant(const ValueLatticeElement &LV) const { +    if (LV.isConstant()) +      return LV.getConstant(); + +    if (LV.isConstantRange()) { +      auto &CR = LV.getConstantRange(); +      if (CR.getSingleElement()) +        return ConstantInt::get(Ctx, *CR.getSingleElement()); +    } +    return nullptr; +  } +  private: -  // pushToWorkList - Helper for markConstant/markForcedConstant/markOverdefined -  void pushToWorkList(LatticeVal &IV, Value *V) { +  ConstantInt *getConstantInt(const ValueLatticeElement &IV) const { +    return dyn_cast_or_null<ConstantInt>(getConstant(IV)); +  } + +  // pushToWorkList - Helper for markConstant/markOverdefined +  void pushToWorkList(ValueLatticeElement &IV, Value *V) {      if (IV.isOverdefined())        return OverdefinedInstWorkList.push_back(V);      InstWorkList.push_back(V);    } +  // Helper to push \p V to the worklist, after updating it to \p IV. Also +  // prints a debug message with the updated value. +  void pushToWorkListMsg(ValueLatticeElement &IV, Value *V) { +    LLVM_DEBUG(dbgs() << "updated " << IV << ": " << *V << '\n'); +    pushToWorkList(IV, V); +  } +    // markConstant - Make a value be marked as "constant".  If the value    // is not already a constant, add it to the instruction work list so that    // the users of the instruction are updated later. -  bool markConstant(LatticeVal &IV, Value *V, Constant *C) { -    if (!IV.markConstant(C)) return false; +  bool markConstant(ValueLatticeElement &IV, Value *V, Constant *C, +                    bool MayIncludeUndef = false) { +    if (!IV.markConstant(C, MayIncludeUndef)) +      return false;      LLVM_DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n');      pushToWorkList(IV, V);      return true; @@ -443,18 +398,10 @@ private:      return markConstant(ValueState[V], V, C);    } -  void markForcedConstant(Value *V, Constant *C) { -    assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); -    LatticeVal &IV = ValueState[V]; -    IV.markForcedConstant(C); -    LLVM_DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); -    pushToWorkList(IV, V); -  } -    // markOverdefined - Make a value be marked as "overdefined". If the    // value is not already overdefined, add it to the overdefined instruction    // work list so that the users of the instruction are updated later. 
-  bool markOverdefined(LatticeVal &IV, Value *V) { +  bool markOverdefined(ValueLatticeElement &IV, Value *V) {      if (!IV.markOverdefined()) return false;      LLVM_DEBUG(dbgs() << "markOverdefined: "; @@ -466,71 +413,59 @@ private:      return true;    } -  bool mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { -    if (IV.isOverdefined() || MergeWithV.isUnknown()) -      return false; // Noop. -    if (MergeWithV.isOverdefined()) -      return markOverdefined(IV, V); -    if (IV.isUnknown()) -      return markConstant(IV, V, MergeWithV.getConstant()); -    if (IV.getConstant() != MergeWithV.getConstant()) -      return markOverdefined(IV, V); +  /// Merge \p MergeWithV into \p IV and push \p V to the worklist, if \p IV +  /// changes. +  bool mergeInValue(ValueLatticeElement &IV, Value *V, +                    ValueLatticeElement MergeWithV, +                    ValueLatticeElement::MergeOptions Opts = { +                        /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) { +    if (IV.mergeIn(MergeWithV, Opts)) { +      pushToWorkList(IV, V); +      LLVM_DEBUG(dbgs() << "Merged " << MergeWithV << " into " << *V << " : " +                        << IV << "\n"); +      return true; +    }      return false;    } -  bool mergeInValue(Value *V, LatticeVal MergeWithV) { +  bool mergeInValue(Value *V, ValueLatticeElement MergeWithV, +                    ValueLatticeElement::MergeOptions Opts = { +                        /*MayIncludeUndef=*/false, /*CheckWiden=*/false}) {      assert(!V->getType()->isStructTy() &&             "non-structs should use markConstant"); -    return mergeInValue(ValueState[V], V, MergeWithV); +    return mergeInValue(ValueState[V], V, MergeWithV, Opts);    } -  /// getValueState - Return the LatticeVal object that corresponds to the -  /// value.  This function handles the case when the value hasn't been seen yet -  /// by properly seeding constants etc. -  LatticeVal &getValueState(Value *V) { +  /// getValueState - Return the ValueLatticeElement object that corresponds to +  /// the value.  This function handles the case when the value hasn't been seen +  /// yet by properly seeding constants etc. +  ValueLatticeElement &getValueState(Value *V) {      assert(!V->getType()->isStructTy() && "Should use getStructValueState"); -    std::pair<DenseMap<Value*, LatticeVal>::iterator, bool> I = -      ValueState.insert(std::make_pair(V, LatticeVal())); -    LatticeVal &LV = I.first->second; +    auto I = ValueState.insert(std::make_pair(V, ValueLatticeElement())); +    ValueLatticeElement &LV = I.first->second;      if (!I.second)        return LV;  // Common case, already in the map. -    if (auto *C = dyn_cast<Constant>(V)) { -      // Undef values remain unknown. -      if (!isa<UndefValue>(V)) -        LV.markConstant(C);          // Constants are constant -    } - -    // All others are underdefined by default. -    return LV; -  } - -  ValueLatticeElement &getParamState(Value *V) { -    assert(!V->getType()->isStructTy() && "Should use getStructValueState"); - -    std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool> -        PI = ParamState.insert(std::make_pair(V, ValueLatticeElement())); -    ValueLatticeElement &LV = PI.first->second; -    if (PI.second) -      LV = getValueState(V).toValueLattice(); +    if (auto *C = dyn_cast<Constant>(V)) +      LV.markConstant(C);          // Constants are constant +    // All others are unknown by default.      
return LV;    } -  /// getStructValueState - Return the LatticeVal object that corresponds to the -  /// value/field pair.  This function handles the case when the value hasn't -  /// been seen yet by properly seeding constants etc. -  LatticeVal &getStructValueState(Value *V, unsigned i) { +  /// getStructValueState - Return the ValueLatticeElement object that +  /// corresponds to the value/field pair.  This function handles the case when +  /// the value hasn't been seen yet by properly seeding constants etc. +  ValueLatticeElement &getStructValueState(Value *V, unsigned i) {      assert(V->getType()->isStructTy() && "Should use getValueState");      assert(i < cast<StructType>(V->getType())->getNumElements() &&             "Invalid element #"); -    std::pair<DenseMap<std::pair<Value*, unsigned>, LatticeVal>::iterator, -              bool> I = StructValueState.insert( -                        std::make_pair(std::make_pair(V, i), LatticeVal())); -    LatticeVal &LV = I.first->second; +    auto I = StructValueState.insert( +        std::make_pair(std::make_pair(V, i), ValueLatticeElement())); +    ValueLatticeElement &LV = I.first->second;      if (!I.second)        return LV;  // Common case, already in the map. @@ -589,9 +524,20 @@ private:    // Mark I's users as changed, including AdditionalUsers.    void markUsersAsChanged(Value *I) { -    for (User *U : I->users()) -      if (auto *UI = dyn_cast<Instruction>(U)) -        OperandChangedState(UI); +    // Functions include their arguments in the use-list. Changed function +    // values mean that the result of the function changed. We only need to +    // update the call sites with the new function result and do not have to +    // propagate the call arguments. +    if (isa<Function>(I)) { +      for (User *U : I->users()) { +        if (auto *CB = dyn_cast<CallBase>(U)) +          handleCallResult(*CB); +      } +    } else { +      for (User *U : I->users()) +        if (auto *UI = dyn_cast<Instruction>(U)) +          OperandChangedState(UI); +    }      auto Iter = AdditionalUsers.find(I);      if (Iter != AdditionalUsers.end()) { @@ -600,6 +546,9 @@ private:            OperandChangedState(UI);      }    } +  void handleCallOverdefined(CallBase &CB); +  void handleCallResult(CallBase &CB); +  void handleCallArguments(CallBase &CB);  private:    friend class InstVisitor<SCCPSolver>; @@ -634,20 +583,20 @@ private:    void visitGetElementPtrInst(GetElementPtrInst &I);    void visitCallInst      (CallInst &I) { -    visitCallSite(&I); +    visitCallBase(I);    }    void visitInvokeInst    (InvokeInst &II) { -    visitCallSite(&II); +    visitCallBase(II);      visitTerminator(II);    }    void visitCallBrInst    (CallBrInst &CBI) { -    visitCallSite(&CBI); +    visitCallBase(CBI);      visitTerminator(CBI);    } -  void visitCallSite      (CallSite CS); +  void visitCallBase      (CallBase &CB);    void visitResumeInst    (ResumeInst &I) { /*returns void*/ }    void visitUnreachableInst(UnreachableInst &I) { /*returns void*/ }    void visitFenceInst     (FenceInst &I) { /*returns void*/ } @@ -673,12 +622,12 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI,        return;      } -    LatticeVal BCValue = getValueState(BI->getCondition()); -    ConstantInt *CI = BCValue.getConstantInt(); +    ValueLatticeElement BCValue = getValueState(BI->getCondition()); +    ConstantInt *CI = getConstantInt(BCValue);      if (!CI) {        // Overdefined condition variables, and branches on unfoldable constant        // conditions, mean the 
branch could go either way. -      if (!BCValue.isUnknown()) +      if (!BCValue.isUnknownOrUndef())          Succs[0] = Succs[1] = true;        return;      } @@ -699,12 +648,12 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI,        Succs[0] = true;        return;      } -    LatticeVal SCValue = getValueState(SI->getCondition()); -    ConstantInt *CI = SCValue.getConstantInt(); +    ValueLatticeElement SCValue = getValueState(SI->getCondition()); +    ConstantInt *CI = getConstantInt(SCValue);      if (!CI) {   // Overdefined or unknown condition?        // All destinations are executable! -      if (!SCValue.isUnknown()) +      if (!SCValue.isUnknownOrUndef())          Succs.assign(TI.getNumSuccessors(), true);        return;      } @@ -717,11 +666,11 @@ void SCCPSolver::getFeasibleSuccessors(Instruction &TI,    // the target as executable.    if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {      // Casts are folded by visitCastInst. -    LatticeVal IBRValue = getValueState(IBR->getAddress()); -    BlockAddress *Addr = IBRValue.getBlockAddress(); +    ValueLatticeElement IBRValue = getValueState(IBR->getAddress()); +    BlockAddress *Addr = dyn_cast_or_null<BlockAddress>(getConstant(IBRValue));      if (!Addr) {   // Overdefined or unknown condition?        // All destinations are executable! -      if (!IBRValue.isUnknown()) +      if (!IBRValue.isUnknownOrUndef())          Succs.assign(TI.getNumSuccessors(), true);        return;      } @@ -786,50 +735,43 @@ void SCCPSolver::visitPHINode(PHINode &PN) {      return (void)markOverdefined(&PN);    if (getValueState(&PN).isOverdefined()) -    return;  // Quick exit +    return; // Quick exit    // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant,    // and slow us down a lot.  Just mark them overdefined.    if (PN.getNumIncomingValues() > 64)      return (void)markOverdefined(&PN); +  unsigned NumActiveIncoming = 0; +    // Look at all of the executable operands of the PHI node.  If any of them    // are overdefined, the PHI becomes overdefined as well.  If they are all    // constant, and they agree with each other, the PHI becomes the identical -  // constant.  If they are constant and don't agree, the PHI is overdefined. -  // If there are no executable operands, the PHI remains unknown. -  Constant *OperandVal = nullptr; +  // constant.  If they are constant and don't agree, the PHI is a constant +  // range. If there are no executable operands, the PHI remains unknown. +  ValueLatticeElement PhiState = getValueState(&PN);    for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { -    LatticeVal IV = getValueState(PN.getIncomingValue(i)); -    if (IV.isUnknown()) continue;  // Doesn't influence PHI node. -      if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent()))        continue; -    if (IV.isOverdefined())    // PHI node becomes overdefined! -      return (void)markOverdefined(&PN); - -    if (!OperandVal) {   // Grab the first value. -      OperandVal = IV.getConstant(); -      continue; -    } - -    // There is already a reachable operand.  If we conflict with it, -    // then the PHI node becomes overdefined.  If we agree with it, we -    // can continue on. - -    // Check to see if there are two different constants merging, if so, the PHI -    // node is overdefined. 
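The PHI handling above no longer requires all incoming constants to agree; instead each merge may widen the lattice value a bounded number of times before giving up, roughly one extension per active incoming value plus one, as in the mergeInValue call that follows below. A toy model of capped widening, with hypothetical simplified types (the real code uses ValueLatticeElement's MergeOptions):

#include <algorithm>
#include <cassert>

// Toy sketch of budgeted widening (hypothetical types, not LLVM's):
// each merge that actually grows the range burns one extension, and
// exceeding the budget jumps straight to overdefined.
struct RangeCell {
  int Lo = 0, Hi = 0, Extensions = 0;
  bool Overdefined = false;

  void mergeIn(int V, int MaxWidenSteps) {
    if (Overdefined)
      return;
    if (V >= Lo && V <= Hi)
      return;                     // no change, no extension spent
    if (++Extensions > MaxWidenSteps) {
      Overdefined = true;         // widen all the way to "any value"
      return;
    }
    Lo = std::min(Lo, V);
    Hi = std::max(Hi, V);
  }
};

int main() {
  RangeCell C;
  const int ActiveIncoming = 2;
  for (int V : {1, 2, 3, 4})
    C.mergeIn(V, /*MaxWidenSteps=*/ActiveIncoming + 1);
  assert(C.Overdefined); // the fourth extension exceeds the budget
}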
-    if (IV.getConstant() != OperandVal) -      return (void)markOverdefined(&PN); -  } - -  // If we exited the loop, this means that the PHI node only has constant -  // arguments that agree with each other(and OperandVal is the constant) or -  // OperandVal is null because there are no defined incoming arguments.  If -  // this is the case, the PHI remains unknown. -  if (OperandVal) -    markConstant(&PN, OperandVal);      // Acquire operand value +    ValueLatticeElement IV = getValueState(PN.getIncomingValue(i)); +    PhiState.mergeIn(IV); +    NumActiveIncoming++; +    if (PhiState.isOverdefined()) +      break; +  } + +  // We allow up to 1 range extension per active incoming value and one +  // additional extension. Note that we manually adjust the number of range +  // extensions to match the number of active incoming values. This helps to +  // limit multiple extensions caused by the same incoming value, if other +  // incoming values are equal. +  mergeInValue(&PN, PhiState, +               ValueLatticeElement::MergeOptions().setMaxWidenSteps( +                   NumActiveIncoming + 1)); +  ValueLatticeElement &PhiStateRef = getValueState(&PN); +  PhiStateRef.setNumRangeExtensions( +      std::max(NumActiveIncoming, PhiStateRef.getNumRangeExtensions()));  }  void SCCPSolver::visitReturnInst(ReturnInst &I) { @@ -840,8 +782,7 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) {    // If we are tracking the return value of this function, merge it in.    if (!TrackedRetVals.empty() && !ResultOp->getType()->isStructTy()) { -    MapVector<Function*, LatticeVal>::iterator TFRVI = -      TrackedRetVals.find(F); +    auto TFRVI = TrackedRetVals.find(F);      if (TFRVI != TrackedRetVals.end()) {        mergeInValue(TFRVI->second, F, getValueState(ResultOp));        return; @@ -871,18 +812,28 @@ void SCCPSolver::visitTerminator(Instruction &TI) {  }  void SCCPSolver::visitCastInst(CastInst &I) { -  LatticeVal OpSt = getValueState(I.getOperand(0)); -  if (OpSt.isOverdefined())          // Inherit overdefinedness of operand -    markOverdefined(&I); -  else if (OpSt.isConstant()) { +  // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would +  // discover a concrete value later. +  if (ValueState[&I].isOverdefined()) +    return; + +  ValueLatticeElement OpSt = getValueState(I.getOperand(0)); +  if (Constant *OpC = getConstant(OpSt)) {      // Fold the constant as we build. -    Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpSt.getConstant(), -                                          I.getType(), DL); +    Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpC, I.getType(), DL);      if (isa<UndefValue>(C))        return;      // Propagate constant value      markConstant(&I, C); -  } +  } else if (OpSt.isConstantRange() && I.getDestTy()->isIntegerTy()) { +    auto &LV = getValueState(&I); +    ConstantRange OpRange = OpSt.getConstantRange(); +    Type *DestTy = I.getDestTy(); +    ConstantRange Res = +        OpRange.castOp(I.getOpcode(), DL.getTypeSizeInBits(DestTy)); +    mergeInValue(LV, &I, ValueLatticeElement::getRange(Res)); +  } else if (!OpSt.isUnknownOrUndef()) +    markOverdefined(&I);  }  void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { @@ -891,6 +842,11 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {    if (EVI.getType()->isStructTy())      return (void)markOverdefined(&EVI); +  // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would +  // discover a concrete value later. 
+  if (ValueState[&EVI].isOverdefined()) +    return (void)markOverdefined(&EVI); +    // If this is extracting from more than one level of struct, we don't know.    if (EVI.getNumIndices() != 1)      return (void)markOverdefined(&EVI); @@ -898,7 +854,7 @@ void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {    Value *AggVal = EVI.getAggregateOperand();    if (AggVal->getType()->isStructTy()) {      unsigned i = *EVI.idx_begin(); -    LatticeVal EltVal = getStructValueState(AggVal, i); +    ValueLatticeElement EltVal = getStructValueState(AggVal, i);      mergeInValue(getValueState(&EVI), &EVI, EltVal);    } else {      // Otherwise, must be extracting from an array. @@ -911,6 +867,11 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {    if (!STy)      return (void)markOverdefined(&IVI); +  // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would +  // discover a concrete value later. +  if (isOverdefined(ValueState[&IVI])) +    return (void)markOverdefined(&IVI); +    // If this has more than one index, we can't handle it, drive all results to    // undef.    if (IVI.getNumIndices() != 1) @@ -923,7 +884,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {      // This passes through all values that aren't the inserted element.      if (i != Idx) { -      LatticeVal EltVal = getStructValueState(Aggr, i); +      ValueLatticeElement EltVal = getStructValueState(Aggr, i);        mergeInValue(getStructValueState(&IVI, i), &IVI, EltVal);        continue;      } @@ -933,7 +894,7 @@ void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) {        // We don't track structs in structs.        markOverdefined(getStructValueState(&IVI, i), &IVI);      else { -      LatticeVal InVal = getValueState(Val); +      ValueLatticeElement InVal = getValueState(Val);        mergeInValue(getStructValueState(&IVI, i), &IVI, InVal);      }    } @@ -945,11 +906,16 @@ void SCCPSolver::visitSelectInst(SelectInst &I) {    if (I.getType()->isStructTy())      return (void)markOverdefined(&I); -  LatticeVal CondValue = getValueState(I.getCondition()); -  if (CondValue.isUnknown()) +  // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would +  // discover a concrete value later. +  if (ValueState[&I].isOverdefined()) +    return (void)markOverdefined(&I); + +  ValueLatticeElement CondValue = getValueState(I.getCondition()); +  if (CondValue.isUnknownOrUndef())      return; -  if (ConstantInt *CondCB = CondValue.getConstantInt()) { +  if (ConstantInt *CondCB = getConstantInt(CondValue)) {      Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue();      mergeInValue(&I, getValueState(OpVal));      return; @@ -958,30 +924,27 @@ void SCCPSolver::visitSelectInst(SelectInst &I) {    // Otherwise, the condition is overdefined or a constant we can't evaluate.    // See if we can produce something better than overdefined based on the T/F    // value. -  LatticeVal TVal = getValueState(I.getTrueValue()); -  LatticeVal FVal = getValueState(I.getFalseValue()); - -  // select ?, C, C -> C. -  if (TVal.isConstant() && FVal.isConstant() && -      TVal.getConstant() == FVal.getConstant()) -    return (void)markConstant(&I, FVal.getConstant()); - -  if (TVal.isUnknown())   // select ?, undef, X -> X. -    return (void)mergeInValue(&I, FVal); -  if (FVal.isUnknown())   // select ?, X, undef -> X. 
-    return (void)mergeInValue(&I, TVal); -  markOverdefined(&I); +  ValueLatticeElement TVal = getValueState(I.getTrueValue()); +  ValueLatticeElement FVal = getValueState(I.getFalseValue()); + +  bool Changed = ValueState[&I].mergeIn(TVal); +  Changed |= ValueState[&I].mergeIn(FVal); +  if (Changed) +    pushToWorkListMsg(ValueState[&I], &I);  }  // Handle Unary Operators.  void SCCPSolver::visitUnaryOperator(Instruction &I) { -  LatticeVal V0State = getValueState(I.getOperand(0)); +  ValueLatticeElement V0State = getValueState(I.getOperand(0)); -  LatticeVal &IV = ValueState[&I]; -  if (IV.isOverdefined()) return; +  ValueLatticeElement &IV = ValueState[&I]; +  // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would +  // discover a concrete value later. +  if (isOverdefined(IV)) +    return (void)markOverdefined(&I); -  if (V0State.isConstant()) { -    Constant *C = ConstantExpr::get(I.getOpcode(), V0State.getConstant()); +  if (isConstant(V0State)) { +    Constant *C = ConstantExpr::get(I.getOpcode(), getConstant(V0State));      // op Y -> undef.      if (isa<UndefValue>(C)) @@ -990,7 +953,7 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) {    }    // If something is undef, wait for it to resolve. -  if (!V0State.isOverdefined()) +  if (!isOverdefined(V0State))      return;    markOverdefined(&I); @@ -998,101 +961,90 @@ void SCCPSolver::visitUnaryOperator(Instruction &I) {  // Handle Binary Operators.  void SCCPSolver::visitBinaryOperator(Instruction &I) { -  LatticeVal V1State = getValueState(I.getOperand(0)); -  LatticeVal V2State = getValueState(I.getOperand(1)); +  ValueLatticeElement V1State = getValueState(I.getOperand(0)); +  ValueLatticeElement V2State = getValueState(I.getOperand(1)); -  LatticeVal &IV = ValueState[&I]; -  if (IV.isOverdefined()) return; - -  if (V1State.isConstant() && V2State.isConstant()) { -    Constant *C = ConstantExpr::get(I.getOpcode(), V1State.getConstant(), -                                    V2State.getConstant()); -    // X op Y -> undef. -    if (isa<UndefValue>(C)) -      return; -    return (void)markConstant(IV, &I, C); -  } +  ValueLatticeElement &IV = ValueState[&I]; +  if (IV.isOverdefined()) +    return;    // If something is undef, wait for it to resolve. -  if (!V1State.isOverdefined() && !V2State.isOverdefined()) +  if (V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef())      return; -  // Otherwise, one of our operands is overdefined.  Try to produce something -  // better than overdefined with some tricks. -  // If this is 0 / Y, it doesn't matter that the second operand is -  // overdefined, and we can replace it with zero. -  if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv) -    if (V1State.isConstant() && V1State.getConstant()->isNullValue()) -      return (void)markConstant(IV, &I, V1State.getConstant()); - -  // If this is: -  // -> AND/MUL with 0 -  // -> OR with -1 -  // it doesn't matter that the other operand is overdefined. 
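The hand-written special cases noted above are deleted just below in favor of a single SimplifyBinOp call. Their shared idea, that certain constants absorb an unknown operand entirely, can be shown in a standalone sketch (a simplified model with a hypothetical Op enum, not LLVM's SimplifyBinOp):

#include <cassert>
#include <cstdint>
#include <optional>

// Toy model: `x & 0`, `x * 0`, and `x | -1` fold even when the other
// operand is unknown, so one known-constant operand can be enough.
enum class Op { And, Or, Mul };

std::optional<int64_t> foldWithOneConstant(Op O, int64_t C) {
  switch (O) {
  case Op::And:
  case Op::Mul:
    if (C == 0)
      return 0;        // x & 0 == 0, x * 0 == 0
    break;
  case Op::Or:
    if (C == -1)
      return -1;       // x | -1 == -1
    break;
  }
  return std::nullopt; // result still depends on the unknown operand
}

int main() {
  assert(foldWithOneConstant(Op::And, 0) == 0);
  assert(!foldWithOneConstant(Op::Or, 7)); // 7 absorbs nothing
}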
-  if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul ||
-      I.getOpcode() == Instruction::Or) {
-    LatticeVal *NonOverdefVal = nullptr;
-    if (!V1State.isOverdefined())
-      NonOverdefVal = &V1State;
-    else if (!V2State.isOverdefined())
-      NonOverdefVal = &V2State;
-
-    if (NonOverdefVal) {
-      if (NonOverdefVal->isUnknown())
-        return;
+  if (V1State.isOverdefined() && V2State.isOverdefined())
+    return (void)markOverdefined(&I);

-      if (I.getOpcode() == Instruction::And ||
-          I.getOpcode() == Instruction::Mul) {
-        // X and 0 = 0
-        // X * 0 = 0
-        if (NonOverdefVal->getConstant()->isNullValue())
-          return (void)markConstant(IV, &I, NonOverdefVal->getConstant());
-      } else {
-        // X or -1 = -1
-        if (ConstantInt *CI = NonOverdefVal->getConstantInt())
-          if (CI->isMinusOne())
-            return (void)markConstant(IV, &I, NonOverdefVal->getConstant());
-      }
+  // If either of the operands is a constant, try to fold it to a constant.
+  // TODO: Use information from notconstant better.
+  if ((V1State.isConstant() || V2State.isConstant())) {
+    Value *V1 = isConstant(V1State) ? getConstant(V1State) : I.getOperand(0);
+    Value *V2 = isConstant(V2State) ? getConstant(V2State) : I.getOperand(1);
+    Value *R = SimplifyBinOp(I.getOpcode(), V1, V2, SimplifyQuery(DL));
+    auto *C = dyn_cast_or_null<Constant>(R);
+    if (C) {
+      // X op Y -> undef.
+      if (isa<UndefValue>(C))
+        return;
+      // Conservatively assume that the result may be based on operands that may
+      // be undef. Note that we use mergeInValue to combine the constant with
+      // the existing lattice value for I, as different constants might be found
+      // after one of the operands goes to overdefined, e.g. due to one operand
+      // being a special floating value.
+      ValueLatticeElement NewV;
+      NewV.markConstant(C, /*MayIncludeUndef=*/true);
+      return (void)mergeInValue(&I, NewV);
     }
   }

-  markOverdefined(&I);
+  // Only use ranges for binary operators on integers.
+  if (!I.getType()->isIntegerTy())
+    return markOverdefined(&I);
+
+  // Try to simplify to a constant range.
+  ConstantRange A = ConstantRange::getFull(I.getType()->getScalarSizeInBits());
+  ConstantRange B = ConstantRange::getFull(I.getType()->getScalarSizeInBits());
+  if (V1State.isConstantRange())
+    A = V1State.getConstantRange();
+  if (V2State.isConstantRange())
+    B = V2State.getConstantRange();
+
+  ConstantRange R = A.binaryOp(cast<BinaryOperator>(&I)->getOpcode(), B);
+  mergeInValue(&I, ValueLatticeElement::getRange(R));
+
+  // TODO: Currently we do not exploit special values that produce something
+  // better than overdefined with an overdefined operand for vector or floating
+  // point types, like and <4 x i32> overdefined, zeroinitializer.
 }

 // Handle ICmpInst instruction.
 void SCCPSolver::visitCmpInst(CmpInst &I) {
   // Do not cache this lookup, getValueState calls later in the function might
   // invalidate the reference.
-  if (ValueState[&I].isOverdefined()) return;
+  if (isOverdefined(ValueState[&I]))
+    return (void)markOverdefined(&I);

   Value *Op1 = I.getOperand(0);
   Value *Op2 = I.getOperand(1);

   // For parameters, use ParamState which includes constant range info if
   // available.
-  auto V1Param = ParamState.find(Op1);
-  ValueLatticeElement V1State = (V1Param != ParamState.end())
-                                    ? 
V1Param->second -                                    : getValueState(Op1).toValueLattice(); - -  auto V2Param = ParamState.find(Op2); -  ValueLatticeElement V2State = V2Param != ParamState.end() -                                    ? V2Param->second -                                    : getValueState(Op2).toValueLattice(); +  auto V1State = getValueState(Op1); +  auto V2State = getValueState(Op2);    Constant *C = V1State.getCompare(I.getPredicate(), I.getType(), V2State);    if (C) {      if (isa<UndefValue>(C))        return; -    LatticeVal CV; +    ValueLatticeElement CV;      CV.markConstant(C);      mergeInValue(&I, CV);      return;    }    // If operands are still unknown, wait for it to resolve. -  if (!V1State.isOverdefined() && !V2State.isOverdefined() && -      !ValueState[&I].isConstant()) +  if ((V1State.isUnknownOrUndef() || V2State.isUnknownOrUndef()) && +      !isConstant(ValueState[&I]))      return;    markOverdefined(&I); @@ -1101,21 +1053,26 @@ void SCCPSolver::visitCmpInst(CmpInst &I) {  // Handle getelementptr instructions.  If all operands are constants then we  // can turn this into a getelementptr ConstantExpr.  void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { -  if (ValueState[&I].isOverdefined()) return; +  if (isOverdefined(ValueState[&I])) +    return (void)markOverdefined(&I);    SmallVector<Constant*, 8> Operands;    Operands.reserve(I.getNumOperands());    for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { -    LatticeVal State = getValueState(I.getOperand(i)); -    if (State.isUnknown()) +    ValueLatticeElement State = getValueState(I.getOperand(i)); +    if (State.isUnknownOrUndef())        return;  // Operands are not resolved yet. -    if (State.isOverdefined()) +    if (isOverdefined(State))        return (void)markOverdefined(&I); -    assert(State.isConstant() && "Unknown state!"); -    Operands.push_back(State.getConstant()); +    if (Constant *C = getConstant(State)) { +      Operands.push_back(C); +      continue; +    } + +    return (void)markOverdefined(&I);    }    Constant *Ptr = Operands[0]; @@ -1136,230 +1093,297 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) {      return;    GlobalVariable *GV = cast<GlobalVariable>(SI.getOperand(1)); -  DenseMap<GlobalVariable*, LatticeVal>::iterator I = TrackedGlobals.find(GV); -  if (I == TrackedGlobals.end() || I->second.isOverdefined()) return; +  auto I = TrackedGlobals.find(GV); +  if (I == TrackedGlobals.end()) +    return;    // Get the value we are storing into the global, then merge it. -  mergeInValue(I->second, GV, getValueState(SI.getOperand(0))); +  mergeInValue(I->second, GV, getValueState(SI.getOperand(0)), +               ValueLatticeElement::MergeOptions().setCheckWiden(false));    if (I->second.isOverdefined())      TrackedGlobals.erase(I);      // No need to keep tracking this!  } +static ValueLatticeElement getValueFromMetadata(const Instruction *I) { +  if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) +    if (I->getType()->isIntegerTy()) +      return ValueLatticeElement::getRange( +          getConstantRangeFromMetadata(*Ranges)); +  // TODO: Also handle MD_nonnull. +  return ValueLatticeElement::getOverdefined(); +} +  // Handle load instructions.  If the operand is a constant pointer to a constant  // global, we can replace the load with the loaded constant value!  void SCCPSolver::visitLoadInst(LoadInst &I) { -  // If this load is of a struct, just mark the result overdefined. 
-  if (I.getType()->isStructTy()) +  // If this load is of a struct or the load is volatile, just mark the result +  // as overdefined. +  if (I.getType()->isStructTy() || I.isVolatile())      return (void)markOverdefined(&I); -  LatticeVal PtrVal = getValueState(I.getOperand(0)); -  if (PtrVal.isUnknown()) return;   // The pointer is not resolved yet! - -  LatticeVal &IV = ValueState[&I]; -  if (IV.isOverdefined()) return; +  // ResolvedUndefsIn might mark I as overdefined. Bail out, even if we would +  // discover a concrete value later. +  if (ValueState[&I].isOverdefined()) +    return (void)markOverdefined(&I); -  if (!PtrVal.isConstant() || I.isVolatile()) -    return (void)markOverdefined(IV, &I); +  ValueLatticeElement PtrVal = getValueState(I.getOperand(0)); +  if (PtrVal.isUnknownOrUndef()) +    return; // The pointer is not resolved yet! -  Constant *Ptr = PtrVal.getConstant(); +  ValueLatticeElement &IV = ValueState[&I]; -  // load null is undefined. -  if (isa<ConstantPointerNull>(Ptr)) { -    if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace())) -      return (void)markOverdefined(IV, &I); -    else -      return; -  } +  if (isConstant(PtrVal)) { +    Constant *Ptr = getConstant(PtrVal); -  // Transform load (constant global) into the value loaded. -  if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { -    if (!TrackedGlobals.empty()) { -      // If we are tracking this global, merge in the known value for it. -      DenseMap<GlobalVariable*, LatticeVal>::iterator It = -        TrackedGlobals.find(GV); -      if (It != TrackedGlobals.end()) { -        mergeInValue(IV, &I, It->second); +    // load null is undefined. +    if (isa<ConstantPointerNull>(Ptr)) { +      if (NullPointerIsDefined(I.getFunction(), I.getPointerAddressSpace())) +        return (void)markOverdefined(IV, &I); +      else          return; +    } + +    // Transform load (constant global) into the value loaded. +    if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { +      if (!TrackedGlobals.empty()) { +        // If we are tracking this global, merge in the known value for it. +        auto It = TrackedGlobals.find(GV); +        if (It != TrackedGlobals.end()) { +          mergeInValue(IV, &I, It->second, getMaxWidenStepsOpts()); +          return; +        }        }      } -  } -  // Transform load from a constant into a constant if possible. -  if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { -    if (isa<UndefValue>(C)) -      return; -    return (void)markConstant(IV, &I, C); +    // Transform load from a constant into a constant if possible. +    if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { +      if (isa<UndefValue>(C)) +        return; +      return (void)markConstant(IV, &I, C); +    }    } -  // Otherwise we cannot say for certain what value this load will produce. -  // Bail out. -  markOverdefined(IV, &I); +  // Fall back to metadata. +  mergeInValue(&I, getValueFromMetadata(&I));  } -void SCCPSolver::visitCallSite(CallSite CS) { -  Function *F = CS.getCalledFunction(); -  Instruction *I = CS.getInstruction(); +void SCCPSolver::visitCallBase(CallBase &CB) { +  handleCallResult(CB); +  handleCallArguments(CB); +} -  if (auto *II = dyn_cast<IntrinsicInst>(I)) { -    if (II->getIntrinsicID() == Intrinsic::ssa_copy) { -      if (ValueState[I].isOverdefined()) +void SCCPSolver::handleCallOverdefined(CallBase &CB) { +  Function *F = CB.getCalledFunction(); + +  // Void return and not tracking callee, just bail. 
+  if (CB.getType()->isVoidTy()) +    return; + +  // Always mark struct return as overdefined. +  if (CB.getType()->isStructTy()) +    return (void)markOverdefined(&CB); + +  // Otherwise, if we have a single return value case, and if the function is +  // a declaration, maybe we can constant fold it. +  if (F && F->isDeclaration() && canConstantFoldCallTo(&CB, F)) { +    SmallVector<Constant *, 8> Operands; +    for (auto AI = CB.arg_begin(), E = CB.arg_end(); AI != E; ++AI) { +      if (AI->get()->getType()->isStructTy()) +        return markOverdefined(&CB); // Can't handle struct args. +      ValueLatticeElement State = getValueState(*AI); + +      if (State.isUnknownOrUndef()) +        return; // Operands are not resolved yet. +      if (isOverdefined(State)) +        return (void)markOverdefined(&CB); +      assert(isConstant(State) && "Unknown state!"); +      Operands.push_back(getConstant(State)); +    } + +    if (isOverdefined(getValueState(&CB))) +      return (void)markOverdefined(&CB); + +    // If we can constant fold this, mark the result of the call as a +    // constant. +    if (Constant *C = ConstantFoldCall(&CB, F, Operands, &GetTLI(*F))) { +      // call -> undef. +      if (isa<UndefValue>(C))          return; +      return (void)markConstant(&CB, C); +    } +  } + +  // Fall back to metadata. +  mergeInValue(&CB, getValueFromMetadata(&CB)); +} + +void SCCPSolver::handleCallArguments(CallBase &CB) { +  Function *F = CB.getCalledFunction(); +  // If this is a local function that doesn't have its address taken, mark its +  // entry block executable and merge in the actual arguments to the call into +  // the formal arguments of the function. +  if (!TrackingIncomingArguments.empty() && +      TrackingIncomingArguments.count(F)) { +    MarkBlockExecutable(&F->front()); + +    // Propagate information from this call site into the callee. +    auto CAI = CB.arg_begin(); +    for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); AI != E; +         ++AI, ++CAI) { +      // If this argument is byval, and if the function is not readonly, there +      // will be an implicit copy formed of the input aggregate. 
+      if (AI->hasByValAttr() && !F->onlyReadsMemory()) { +        markOverdefined(&*AI); +        continue; +      } + +      if (auto *STy = dyn_cast<StructType>(AI->getType())) { +        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { +          ValueLatticeElement CallArg = getStructValueState(*CAI, i); +          mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg, +                       getMaxWidenStepsOpts()); +        } +      } else +        mergeInValue(&*AI, getValueState(*CAI), getMaxWidenStepsOpts()); +    } +  } +} + +void SCCPSolver::handleCallResult(CallBase &CB) { +  Function *F = CB.getCalledFunction(); -      auto *PI = getPredicateInfoFor(I); -      if (!PI) +  if (auto *II = dyn_cast<IntrinsicInst>(&CB)) { +    if (II->getIntrinsicID() == Intrinsic::ssa_copy) { +      if (ValueState[&CB].isOverdefined())          return; -      Value *CopyOf = I->getOperand(0); -      auto *PBranch = dyn_cast<PredicateBranch>(PI); -      if (!PBranch) { -        mergeInValue(ValueState[I], I, getValueState(CopyOf)); +      Value *CopyOf = CB.getOperand(0); +      ValueLatticeElement CopyOfVal = getValueState(CopyOf); +      auto *PI = getPredicateInfoFor(&CB); +      assert(PI && "Missing predicate info for ssa.copy"); + +      CmpInst *Cmp; +      bool TrueEdge; +      if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) { +        Cmp = dyn_cast<CmpInst>(PBranch->Condition); +        TrueEdge = PBranch->TrueEdge; +      } else if (auto *PAssume = dyn_cast<PredicateAssume>(PI)) { +        Cmp = dyn_cast<CmpInst>(PAssume->Condition); +        TrueEdge = true; +      } else { +        mergeInValue(ValueState[&CB], &CB, CopyOfVal);          return;        } -      Value *Cond = PBranch->Condition; -        // Everything below relies on the condition being a comparison. -      auto *Cmp = dyn_cast<CmpInst>(Cond);        if (!Cmp) { -        mergeInValue(ValueState[I], I, getValueState(CopyOf)); +        mergeInValue(ValueState[&CB], &CB, CopyOfVal);          return;        } +      Value *RenamedOp = PI->RenamedOp;        Value *CmpOp0 = Cmp->getOperand(0);        Value *CmpOp1 = Cmp->getOperand(1); -      if (CopyOf != CmpOp0 && CopyOf != CmpOp1) { -        mergeInValue(ValueState[I], I, getValueState(CopyOf)); +      // Bail out if neither of the operands matches RenamedOp. +      if (CmpOp0 != RenamedOp && CmpOp1 != RenamedOp) { +        mergeInValue(ValueState[&CB], &CB, getValueState(CopyOf));          return;        } -      if (CmpOp0 != CopyOf) +      auto Pred = Cmp->getPredicate(); +      if (CmpOp1 == RenamedOp) {          std::swap(CmpOp0, CmpOp1); +        Pred = Cmp->getSwappedPredicate(); +      } -      LatticeVal OriginalVal = getValueState(CopyOf); -      LatticeVal EqVal = getValueState(CmpOp1); -      LatticeVal &IV = ValueState[I]; -      if (PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_EQ) { -        addAdditionalUser(CmpOp1, I); -        if (OriginalVal.isConstant()) -          mergeInValue(IV, I, OriginalVal); -        else -          mergeInValue(IV, I, EqVal); +      // Wait until CmpOp1 is resolved. 
+      if (getValueState(CmpOp1).isUnknown()) {
+        addAdditionalUser(CmpOp1, &CB);
         return;
       }
-      if (!PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_NE) {
-        addAdditionalUser(CmpOp1, I);
-        if (OriginalVal.isConstant())
-          mergeInValue(IV, I, OriginalVal);
-        else
-          mergeInValue(IV, I, EqVal);
+
+      // The code below relies on PredicateInfo only inserting copies for the
+      // true branch when the branch condition is an AND and only inserting
+      // copies for the false branch when the branch condition is an OR. This
+      // ensures we can intersect the range from the condition with the range of
+      // CopyOf.
+      if (!TrueEdge)
+        Pred = CmpInst::getInversePredicate(Pred);
+
+      ValueLatticeElement CondVal = getValueState(CmpOp1);
+      ValueLatticeElement &IV = ValueState[&CB];
+      if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) {
+        auto ImposedCR =
+            ConstantRange::getFull(DL.getTypeSizeInBits(CopyOf->getType()));
+
+        // Get the range imposed by the condition.
+        if (CondVal.isConstantRange())
+          ImposedCR = ConstantRange::makeAllowedICmpRegion(
+              Pred, CondVal.getConstantRange());
+
+        // Combine range info for the original value with the new range from the
+        // condition.
+        auto CopyOfCR = CopyOfVal.isConstantRange()
+                            ? CopyOfVal.getConstantRange()
+                            : ConstantRange::getFull(
+                                  DL.getTypeSizeInBits(CopyOf->getType()));
+        auto NewCR = ImposedCR.intersectWith(CopyOfCR);
+        // If the existing information is != x, do not use the information from
+        // a chained predicate, as the != x information is more likely to be
+        // helpful in practice.
+        if (!CopyOfCR.contains(NewCR) && CopyOfCR.getSingleMissingElement())
+          NewCR = CopyOfCR;
+
+        addAdditionalUser(CmpOp1, &CB);
+        // TODO: Actually flip MayIncludeUndef for the created range to false,
+        // once most places in the optimizer respect the rule that branches on
+        // undef/poison are UB. The reason the new range cannot be undef is as
+        // follows:
+        // The new range is based on a branch condition. That guarantees that
+        // neither of the compare operands can be undef in the branch targets,
+        // unless we have conditions that are always true/false (e.g. icmp ule
+        // i32, %a, i32_max). For the latter, an overdefined/empty range will be
+        // inferred, but the branch will get folded accordingly anyway.
+        mergeInValue(
+            IV, &CB,
+            ValueLatticeElement::getRange(NewCR, /*MayIncludeUndef=*/true));
+        return;
+      } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) {
+        // For non-integer values or integer constant expressions, only
+        // propagate equal constants.
+        addAdditionalUser(CmpOp1, &CB);
+        mergeInValue(IV, &CB, CondVal);
+        return;
       }
-      return (void)mergeInValue(IV, I, getValueState(CopyOf));
+
+      return (void)mergeInValue(IV, &CB, CopyOfVal);
     }
   }
 
   // The common case is that we aren't tracking the callee, either because we
   // are not doing interprocedural analysis or the callee is indirect, or is
   // external.  Handle these cases first.
-  if (!F || F->isDeclaration()) {
-CallOverdefined:
-    // Void return and not tracking callee, just bail.
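To make the ssa.copy refinement above concrete: suppose %x has lattice state [5, 100) and the guarding branch is 'br (icmp ult %x, 10)'. On the true edge the copy of %x can be narrowed to [5, 10). A small sketch of exactly that computation, assuming the condition operand resolved to the constant 10 (names are illustrative):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

ConstantRange refineOnTrueEdge() {
  ConstantRange Known(APInt(32, 5), APInt(32, 100)); // state of %x: [5, 100)
  ConstantRange Ten(APInt(32, 10));                  // the constant operand
  // Region of values for which 'icmp ult %x, 10' can be true: [0, 10).
  ConstantRange Imposed =
      ConstantRange::makeAllowedICmpRegion(CmpInst::ICMP_ULT, Ten);
  // The ssa.copy on the true edge gets the intersection: [5, 10).
  return Imposed.intersectWith(Known);
}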
-    if (I->getType()->isVoidTy()) return; - -    // Otherwise, if we have a single return value case, and if the function is -    // a declaration, maybe we can constant fold it. -    if (F && F->isDeclaration() && !I->getType()->isStructTy() && -        canConstantFoldCallTo(cast<CallBase>(CS.getInstruction()), F)) { -      SmallVector<Constant*, 8> Operands; -      for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); -           AI != E; ++AI) { -        if (AI->get()->getType()->isStructTy()) -          return markOverdefined(I); // Can't handle struct args. -        LatticeVal State = getValueState(*AI); - -        if (State.isUnknown()) -          return;  // Operands are not resolved yet. -        if (State.isOverdefined()) -          return (void)markOverdefined(I); -        assert(State.isConstant() && "Unknown state!"); -        Operands.push_back(State.getConstant()); -      } - -      if (getValueState(I).isOverdefined()) -        return; - -      // If we can constant fold this, mark the result of the call as a -      // constant. -      if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F, -                                         Operands, &GetTLI(*F))) { -        // call -> undef. -        if (isa<UndefValue>(C)) -          return; -        return (void)markConstant(I, C); -      } -    } - -    // Otherwise, we don't know anything about this call, mark it overdefined. -    return (void)markOverdefined(I); -  } - -  // If this is a local function that doesn't have its address taken, mark its -  // entry block executable and merge in the actual arguments to the call into -  // the formal arguments of the function. -  if (!TrackingIncomingArguments.empty() && TrackingIncomingArguments.count(F)){ -    MarkBlockExecutable(&F->front()); - -    // Propagate information from this call site into the callee. -    CallSite::arg_iterator CAI = CS.arg_begin(); -    for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); -         AI != E; ++AI, ++CAI) { -      // If this argument is byval, and if the function is not readonly, there -      // will be an implicit copy formed of the input aggregate. -      if (AI->hasByValAttr() && !F->onlyReadsMemory()) { -        markOverdefined(&*AI); -        continue; -      } - -      if (auto *STy = dyn_cast<StructType>(AI->getType())) { -        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { -          LatticeVal CallArg = getStructValueState(*CAI, i); -          mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg); -        } -      } else { -        // Most other parts of the Solver still only use the simpler value -        // lattice, so we propagate changes for parameters to both lattices. -        LatticeVal ConcreteArgument = getValueState(*CAI); -        bool ParamChanged = -            getParamState(&*AI).mergeIn(ConcreteArgument.toValueLattice(), DL); -         bool ValueChanged = mergeInValue(&*AI, ConcreteArgument); -        // Add argument to work list, if the state of a parameter changes but -        // ValueState does not change (because it is already overdefined there), -        // We have to take changes in ParamState into account, as it is used -        // when evaluating Cmp instructions. -        if (!ValueChanged && ParamChanged) -          pushToWorkList(ValueState[&*AI], &*AI); -      } -    } -  } +  if (!F || F->isDeclaration()) +    return handleCallOverdefined(CB);    // If this is a single/zero retval case, see if we're tracking the function.    
if (auto *STy = dyn_cast<StructType>(F->getReturnType())) {      if (!MRVFunctionsTracked.count(F)) -      goto CallOverdefined;  // Not tracking this callee. +      return handleCallOverdefined(CB); // Not tracking this callee.      // If we are tracking this callee, propagate the result of the function      // into this call site.      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) -      mergeInValue(getStructValueState(I, i), I, -                   TrackedMultipleRetVals[std::make_pair(F, i)]); +      mergeInValue(getStructValueState(&CB, i), &CB, +                   TrackedMultipleRetVals[std::make_pair(F, i)], +                   getMaxWidenStepsOpts());    } else { -    MapVector<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F); +    auto TFRVI = TrackedRetVals.find(F);      if (TFRVI == TrackedRetVals.end()) -      goto CallOverdefined;  // Not tracking this callee. +      return handleCallOverdefined(CB); // Not tracking this callee.      // If so, propagate the return value of the callee into this call result. -    mergeInValue(I, TFRVI->second); +    mergeInValue(&CB, TFRVI->second, getMaxWidenStepsOpts());    }  } @@ -1429,10 +1453,8 @@ void SCCPSolver::Solve() {  /// constraints on the condition of the branch, as that would impact other users  /// of the value.  /// -/// This scan also checks for values that use undefs, whose results are actually -/// defined.  For example, 'zext i8 undef to i32' should produce all zeros -/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero, -/// even if X isn't defined. +/// This scan also checks for values that use undefs. It conservatively marks +/// them as overdefined.  bool SCCPSolver::ResolvedUndefsIn(Function &F) {    for (BasicBlock &BB : F) {      if (!BBExecutable.count(&BB)) @@ -1446,8 +1468,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {          // Only a few things that can be structs matter for undef.          // Tracked calls must never be marked overdefined in ResolvedUndefsIn. -        if (CallSite CS = CallSite(&I)) -          if (Function *F = CS.getCalledFunction()) +        if (auto *CB = dyn_cast<CallBase>(&I)) +          if (Function *F = CB->getCalledFunction())              if (MRVFunctionsTracked.count(F))                continue; @@ -1455,19 +1477,18 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {          // tracked as precisely as their operands.          if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))            continue; -          // Send the results of everything else to overdefined.  We could be          // more precise than this but it isn't worth bothering.          for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { -          LatticeVal &LV = getStructValueState(&I, i); -          if (LV.isUnknown()) +          ValueLatticeElement &LV = getStructValueState(&I, i); +          if (LV.isUnknownOrUndef())              markOverdefined(LV, &I);          }          continue;        } -      LatticeVal &LV = getValueState(&I); -      if (!LV.isUnknown()) +      ValueLatticeElement &LV = getValueState(&I); +      if (!LV.isUnknownOrUndef())          continue;        // There are two reasons a call can have an undef result @@ -1475,195 +1496,20 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {        // 2. It could be constant-foldable.        // Because of the way we solve return values, tracked calls must        // never be marked overdefined in ResolvedUndefsIn. 
-      if (CallSite CS = CallSite(&I)) { -        if (Function *F = CS.getCalledFunction()) +      if (auto *CB = dyn_cast<CallBase>(&I)) +        if (Function *F = CB->getCalledFunction())            if (TrackedRetVals.count(F))              continue; -        // If the call is constant-foldable, we mark it overdefined because -        // we do not know what return values are valid. -        markOverdefined(&I); -        return true; -      } - -      // extractvalue is safe; check here because the argument is a struct. -      if (isa<ExtractValueInst>(I)) -        continue; - -      // Compute the operand LatticeVals, for convenience below. -      // Anything taking a struct is conservatively assumed to require -      // overdefined markings. -      if (I.getOperand(0)->getType()->isStructTy()) { -        markOverdefined(&I); -        return true; -      } -      LatticeVal Op0LV = getValueState(I.getOperand(0)); -      LatticeVal Op1LV; -      if (I.getNumOperands() == 2) { -        if (I.getOperand(1)->getType()->isStructTy()) { -          markOverdefined(&I); -          return true; -        } - -        Op1LV = getValueState(I.getOperand(1)); -      } -      // If this is an instructions whose result is defined even if the input is -      // not fully defined, propagate the information. -      Type *ITy = I.getType(); -      switch (I.getOpcode()) { -      case Instruction::Add: -      case Instruction::Sub: -      case Instruction::Trunc: -      case Instruction::FPTrunc: -      case Instruction::BitCast: -        break; // Any undef -> undef -      case Instruction::FSub: -      case Instruction::FAdd: -      case Instruction::FMul: -      case Instruction::FDiv: -      case Instruction::FRem: -        // Floating-point binary operation: be conservative. -        if (Op0LV.isUnknown() && Op1LV.isUnknown()) -          markForcedConstant(&I, Constant::getNullValue(ITy)); -        else -          markOverdefined(&I); -        return true; -      case Instruction::FNeg: -        break; // fneg undef -> undef -      case Instruction::ZExt: -      case Instruction::SExt: -      case Instruction::FPToUI: -      case Instruction::FPToSI: -      case Instruction::FPExt: -      case Instruction::PtrToInt: -      case Instruction::IntToPtr: -      case Instruction::SIToFP: -      case Instruction::UIToFP: -        // undef -> 0; some outputs are impossible -        markForcedConstant(&I, Constant::getNullValue(ITy)); -        return true; -      case Instruction::Mul: -      case Instruction::And: -        // Both operands undef -> undef -        if (Op0LV.isUnknown() && Op1LV.isUnknown()) -          break; -        // undef * X -> 0.   X could be zero. -        // undef & X -> 0.   X could be zero. -        markForcedConstant(&I, Constant::getNullValue(ITy)); -        return true; -      case Instruction::Or: -        // Both operands undef -> undef -        if (Op0LV.isUnknown() && Op1LV.isUnknown()) -          break; -        // undef | X -> -1.   X could be -1. 
-        markForcedConstant(&I, Constant::getAllOnesValue(ITy)); -        return true; -      case Instruction::Xor: -        // undef ^ undef -> 0; strictly speaking, this is not strictly -        // necessary, but we try to be nice to people who expect this -        // behavior in simple cases -        if (Op0LV.isUnknown() && Op1LV.isUnknown()) { -          markForcedConstant(&I, Constant::getNullValue(ITy)); -          return true; -        } -        // undef ^ X -> undef -        break; -      case Instruction::SDiv: -      case Instruction::UDiv: -      case Instruction::SRem: -      case Instruction::URem: -        // X / undef -> undef.  No change. -        // X % undef -> undef.  No change. -        if (Op1LV.isUnknown()) break; - -        // X / 0 -> undef.  No change. -        // X % 0 -> undef.  No change. -        if (Op1LV.isConstant() && Op1LV.getConstant()->isZeroValue()) -          break; - -        // undef / X -> 0.   X could be maxint. -        // undef % X -> 0.   X could be 1. -        markForcedConstant(&I, Constant::getNullValue(ITy)); -        return true; -      case Instruction::AShr: -        // X >>a undef -> undef. -        if (Op1LV.isUnknown()) break; - -        // Shifting by the bitwidth or more is undefined. -        if (Op1LV.isConstant()) { -          if (auto *ShiftAmt = Op1LV.getConstantInt()) -            if (ShiftAmt->getLimitedValue() >= -                ShiftAmt->getType()->getScalarSizeInBits()) -              break; -        } - -        // undef >>a X -> 0 -        markForcedConstant(&I, Constant::getNullValue(ITy)); -        return true; -      case Instruction::LShr: -      case Instruction::Shl: -        // X << undef -> undef. -        // X >> undef -> undef. -        if (Op1LV.isUnknown()) break; - -        // Shifting by the bitwidth or more is undefined. -        if (Op1LV.isConstant()) { -          if (auto *ShiftAmt = Op1LV.getConstantInt()) -            if (ShiftAmt->getLimitedValue() >= -                ShiftAmt->getType()->getScalarSizeInBits()) -              break; -        } - -        // undef << X -> 0 -        // undef >> X -> 0 -        markForcedConstant(&I, Constant::getNullValue(ITy)); -        return true; -      case Instruction::Select: -        Op1LV = getValueState(I.getOperand(1)); -        // undef ? X : Y  -> X or Y.  There could be commonality between X/Y. -        if (Op0LV.isUnknown()) { -          if (!Op1LV.isConstant())  // Pick the constant one if there is any. -            Op1LV = getValueState(I.getOperand(2)); -        } else if (Op1LV.isUnknown()) { -          // c ? undef : undef -> undef.  No change. -          Op1LV = getValueState(I.getOperand(2)); -          if (Op1LV.isUnknown()) -            break; -          // Otherwise, c ? undef : x -> x. -        } else { -          // Leave Op1LV as Operand(1)'s LatticeValue. -        } - -        if (Op1LV.isConstant()) -          markForcedConstant(&I, Op1LV.getConstant()); -        else -          markOverdefined(&I); -        return true; -      case Instruction::Load: +      if (isa<LoadInst>(I)) {          // A load here means one of two things: a load of undef from a global,          // a load from an unknown pointer.  Either way, having it return undef          // is okay. -        break; -      case Instruction::ICmp: -        // X == undef -> undef.  Other comparisons get more complicated. 
-        Op0LV = getValueState(I.getOperand(0)); -        Op1LV = getValueState(I.getOperand(1)); - -        if ((Op0LV.isUnknown() || Op1LV.isUnknown()) && -            cast<ICmpInst>(&I)->isEquality()) -          break; -        markOverdefined(&I); -        return true; -      case Instruction::Call: -      case Instruction::Invoke: -      case Instruction::CallBr: -        llvm_unreachable("Call-like instructions should have be handled early"); -      default: -        // If we don't know what should happen here, conservatively mark it -        // overdefined. -        markOverdefined(&I); -        return true; +        continue;        } + +      markOverdefined(&I); +      return true;      }      // Check to see if we have a branch or switch on an undefined value.  If so @@ -1672,7 +1518,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {      Instruction *TI = BB.getTerminator();      if (auto *BI = dyn_cast<BranchInst>(TI)) {        if (!BI->isConditional()) continue; -      if (!getValueState(BI->getCondition()).isUnknown()) +      if (!getValueState(BI->getCondition()).isUnknownOrUndef())          continue;        // If the input to SCCP is actually branch on undef, fix the undef to @@ -1700,7 +1546,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {        if (IBR->getNumSuccessors() < 1)          continue; -      if (!getValueState(IBR->getAddress()).isUnknown()) +      if (!getValueState(IBR->getAddress()).isUnknownOrUndef())          continue;        // If the input to SCCP is actually branch on undef, fix the undef to @@ -1724,7 +1570,8 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {      }      if (auto *SI = dyn_cast<SwitchInst>(TI)) { -      if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown()) +      if (!SI->getNumCases() || +          !getValueState(SI->getCondition()).isUnknownOrUndef())          continue;        // If the input to SCCP is actually switch on undef, fix the undef to @@ -1753,25 +1600,26 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) {  static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {    Constant *Const = nullptr;    if (V->getType()->isStructTy()) { -    std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V); -    if (llvm::any_of(IVs, -                     [](const LatticeVal &LV) { return LV.isOverdefined(); })) +    std::vector<ValueLatticeElement> IVs = Solver.getStructLatticeValueFor(V); +    if (any_of(IVs, +               [](const ValueLatticeElement &LV) { return isOverdefined(LV); }))        return false;      std::vector<Constant *> ConstVals;      auto *ST = cast<StructType>(V->getType());      for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { -      LatticeVal V = IVs[i]; -      ConstVals.push_back(V.isConstant() -                              ? V.getConstant() +      ValueLatticeElement V = IVs[i]; +      ConstVals.push_back(isConstant(V) +                              ? Solver.getConstant(V)                                : UndefValue::get(ST->getElementType(i)));      }      Const = ConstantStruct::get(ST, ConstVals);    } else { -    const LatticeVal &IV = Solver.getLatticeValueFor(V); -    if (IV.isOverdefined()) +    const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); +    if (isOverdefined(IV))        return false; -    Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); +    Const = +        isConstant(IV) ? 
Solver.getConstant(IV) : UndefValue::get(V->getType());    }    assert(Const && "Constant is nullptr here!"); @@ -1779,8 +1627,7 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {    // unless the call itself can be removed    CallInst *CI = dyn_cast<CallInst>(V);    if (CI && CI->isMustTailCall() && !CI->isSafeToRemove()) { -    CallSite CS(CI); -    Function *F = CS.getCalledFunction(); +    Function *F = CI->getCalledFunction();      // Don't zap returns of the callee      if (F) @@ -1798,13 +1645,49 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) {    return true;  } +static bool simplifyInstsInBlock(SCCPSolver &Solver, BasicBlock &BB, +                                 SmallPtrSetImpl<Value *> &InsertedValues, +                                 Statistic &InstRemovedStat, +                                 Statistic &InstReplacedStat) { +  bool MadeChanges = false; +  for (Instruction &Inst : make_early_inc_range(BB)) { +    if (Inst.getType()->isVoidTy()) +      continue; +    if (tryToReplaceWithConstant(Solver, &Inst)) { +      if (Inst.isSafeToRemove()) +        Inst.eraseFromParent(); +      // Hey, we just changed something! +      MadeChanges = true; +      ++InstRemovedStat; +    } else if (isa<SExtInst>(&Inst)) { +      Value *ExtOp = Inst.getOperand(0); +      if (isa<Constant>(ExtOp) || InsertedValues.count(ExtOp)) +        continue; +      const ValueLatticeElement &IV = Solver.getLatticeValueFor(ExtOp); +      if (!IV.isConstantRange(/*UndefAllowed=*/false)) +        continue; +      if (IV.getConstantRange().isAllNonNegative()) { +        auto *ZExt = new ZExtInst(ExtOp, Inst.getType(), "", &Inst); +        InsertedValues.insert(ZExt); +        Inst.replaceAllUsesWith(ZExt); +        Solver.removeLatticeValueFor(&Inst); +        Inst.eraseFromParent(); +        InstReplacedStat++; +        MadeChanges = true; +      } +    } +  } +  return MadeChanges; +} +  // runSCCP() - Run the Sparse Conditional Constant Propagation algorithm,  // and return true if the function was modified.  static bool runSCCP(Function &F, const DataLayout &DL,                      const TargetLibraryInfo *TLI) {    LLVM_DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n");    SCCPSolver Solver( -      DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; }); +      DL, [TLI](Function &F) -> const TargetLibraryInfo & { return *TLI; }, +      F.getContext());    // Mark the first block of the function as being executable.    Solver.MarkBlockExecutable(&F.front()); @@ -1827,6 +1710,7 @@ static bool runSCCP(Function &F, const DataLayout &DL,    // delete their contents now.  Note that we cannot actually delete the blocks,    // as we cannot modify the CFG of the function. +  SmallPtrSet<Value *, 32> InsertedValues;    for (BasicBlock &BB : F) {      if (!Solver.isBlockExecutable(&BB)) {        LLVM_DEBUG(dbgs() << "  BasicBlock Dead:" << BB); @@ -1838,21 +1722,8 @@ static bool runSCCP(Function &F, const DataLayout &DL,        continue;      } -    // Iterate over all of the instructions in a function, replacing them with -    // constants if we have found them to be of constant values. 
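The sext-to-zext rewrite in simplifyInstsInBlock above hinges on a single lattice query: if the solver proved the operand's range never sets the sign bit, sign- and zero-extension produce identical bits, and the zext form is generally friendlier to later folds. A minimal sketch of that check (helper name invented):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
using namespace llvm;

bool sextCanBecomeZext() {
  // Solver-proven range of an i8 operand: [0, 100). No value in it is
  // negative, so 'sext i8 %v to i32' == 'zext i8 %v to i32'.
  ConstantRange CR(APInt(8, 0), APInt(8, 100));
  return CR.isAllNonNegative();
}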
-    for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { -      Instruction *Inst = &*BI++; -      if (Inst->getType()->isVoidTy() || Inst->isTerminator()) -        continue; - -      if (tryToReplaceWithConstant(Solver, Inst)) { -        if (isInstructionTriviallyDead(Inst)) -          Inst->eraseFromParent(); -        // Hey, we just changed something! -        MadeChanges = true; -        ++NumInstRemoved; -      } -    } +    MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues, +                                        NumInstRemoved, NumInstReplaced);    }    return MadeChanges; @@ -1942,14 +1813,15 @@ static void findReturnsToZap(Function &F,                 // uses (like blockaddresses) could stuck around, without being                 // used in the underlying IR, meaning we do not have lattice                 // values for them. -               if (!CallSite(U)) +               if (!isa<CallBase>(U))                   return true;                 if (U->getType()->isStructTy()) { -                 return all_of( -                     Solver.getStructLatticeValueFor(U), -                     [](const LatticeVal &LV) { return !LV.isOverdefined(); }); +                 return all_of(Solver.getStructLatticeValueFor(U), +                               [](const ValueLatticeElement &LV) { +                                 return !isOverdefined(LV); +                               });                 } -               return !Solver.getLatticeValueFor(U).isOverdefined(); +               return !isOverdefined(Solver.getLatticeValueFor(U));               }) &&        "We can only zap functions where all live users have a concrete value"); @@ -2006,7 +1878,7 @@ bool llvm::runIPSCCP(      Module &M, const DataLayout &DL,      std::function<const TargetLibraryInfo &(Function &)> GetTLI,      function_ref<AnalysisResultsForFn(Function &)> getAnalysis) { -  SCCPSolver Solver(DL, GetTLI); +  SCCPSolver Solver(DL, GetTLI, M.getContext());    // Loop over all functions, marking arguments to those with their addresses    // taken or that are external as overdefined. @@ -2080,30 +1952,21 @@ bool llvm::runIPSCCP(          }        } -    for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { -      if (!Solver.isBlockExecutable(&*BB)) { -        LLVM_DEBUG(dbgs() << "  BasicBlock Dead:" << *BB); +    SmallPtrSet<Value *, 32> InsertedValues; +    for (BasicBlock &BB : F) { +      if (!Solver.isBlockExecutable(&BB)) { +        LLVM_DEBUG(dbgs() << "  BasicBlock Dead:" << BB);          ++NumDeadBlocks;          MadeChanges = true; -        if (&*BB != &F.front()) -          BlocksToErase.push_back(&*BB); +        if (&BB != &F.front()) +          BlocksToErase.push_back(&BB);          continue;        } -      for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { -        Instruction *Inst = &*BI++; -        if (Inst->getType()->isVoidTy()) -          continue; -        if (tryToReplaceWithConstant(Solver, Inst)) { -          if (Inst->isSafeToRemove()) -            Inst->eraseFromParent(); -          // Hey, we just changed something! -          MadeChanges = true; -          ++IPNumInstRemoved; -        } -      } +      MadeChanges |= simplifyInstsInBlock(Solver, BB, InsertedValues, +                                          IPNumInstRemoved, IPNumInstReplaced);      }      DomTreeUpdater DTU = Solver.getDTU(F); @@ -2189,10 +2052,9 @@ bool llvm::runIPSCCP(    // whether other functions are optimizable.    
SmallVector<ReturnInst*, 8> ReturnsToZap; -  const MapVector<Function*, LatticeVal> &RV = Solver.getTrackedRetVals(); -  for (const auto &I : RV) { +  for (const auto &I : Solver.getTrackedRetVals()) {      Function *F = I.first; -    if (I.second.isOverdefined() || F->getReturnType()->isVoidTy()) +    if (isOverdefined(I.second) || F->getReturnType()->isVoidTy())        continue;      findReturnsToZap(*F, ReturnsToZap, Solver);    } @@ -2213,17 +2075,16 @@ bool llvm::runIPSCCP(    // If we inferred constant or undef values for globals variables, we can    // delete the global and any stores that remain to it. -  const DenseMap<GlobalVariable*, LatticeVal> &TG = Solver.getTrackedGlobals(); -  for (DenseMap<GlobalVariable*, LatticeVal>::const_iterator I = TG.begin(), -         E = TG.end(); I != E; ++I) { -    GlobalVariable *GV = I->first; -    assert(!I->second.isOverdefined() && -           "Overdefined values should have been taken out of the map!"); +  for (auto &I : make_early_inc_range(Solver.getTrackedGlobals())) { +    GlobalVariable *GV = I.first; +    if (isOverdefined(I.second)) +      continue;      LLVM_DEBUG(dbgs() << "Found that GV '" << GV->getName()                        << "' is constant!\n");      while (!GV->use_empty()) {        StoreInst *SI = cast<StoreInst>(GV->user_back());        SI->eraseFromParent(); +      MadeChanges = true;      }      M.getGlobalList().erase(GV);      ++IPNumGlobalConst; diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp index 89916e43fce2..89f324deef9f 100644 --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -94,11 +94,6 @@  #include <utility>  #include <vector> -#ifndef NDEBUG -// We only use this for a debug check. -#include <random> -#endif -  using namespace llvm;  using namespace llvm::sroa; @@ -115,11 +110,6 @@ STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion");  STATISTIC(NumDeleted, "Number of instructions deleted");  STATISTIC(NumVectorized, "Number of vectorized aggregates"); -/// Hidden option to enable randomly shuffling the slices to help uncover -/// instability in their order. -static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices", -                                             cl::init(false), cl::Hidden); -  /// Hidden option to experiment with completely strict handling of inbounds  /// GEPs.  static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), @@ -129,7 +119,7 @@ namespace {  /// A custom IRBuilder inserter which prefixes all names, but only in  /// Assert builds. 
-class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { +class IRBuilderPrefixedInserter final : public IRBuilderDefaultInserter {    std::string Prefix;    const Twine getNameWithPrefix(const Twine &Name) const { @@ -139,9 +129,8 @@ class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter {  public:    void SetNamePrefix(const Twine &P) { Prefix = P.str(); } -protected:    void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, -                    BasicBlock::iterator InsertPt) const { +                    BasicBlock::iterator InsertPt) const override {      IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), BB,                                             InsertPt);    } @@ -663,7 +652,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {  public:    SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS)        : PtrUseVisitor<SliceBuilder>(DL), -        AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {} +        AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize()), +        AS(AS) {}  private:    void markAsDead(Instruction &I) { @@ -752,8 +742,10 @@ private:            // For array or vector indices, scale the index by the size of the            // type.            APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth()); -          GEPOffset += Index * APInt(Offset.getBitWidth(), -                                     DL.getTypeAllocSize(GTI.getIndexedType())); +          GEPOffset += +              Index * +              APInt(Offset.getBitWidth(), +                    DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize());          }          // If this index has computed an intermediate pointer which is not @@ -788,7 +780,7 @@ private:          LI.getPointerAddressSpace() != DL.getAllocaAddrSpace())        return PI.setAborted(&LI); -    uint64_t Size = DL.getTypeStoreSize(LI.getType()); +    uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedSize();      return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile());    } @@ -803,7 +795,7 @@ private:          SI.getPointerAddressSpace() != DL.getAllocaAddrSpace())        return PI.setAborted(&SI); -    uint64_t Size = DL.getTypeStoreSize(ValOp->getType()); +    uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedSize();      // If this memory access can be shown to *statically* extend outside the      // bounds of the allocation, it's behavior is undefined, so simply @@ -1069,17 +1061,9 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)        llvm::remove_if(Slices, [](const Slice &S) { return S.isDead(); }),        Slices.end()); -#ifndef NDEBUG -  if (SROARandomShuffleSlices) { -    std::mt19937 MT(static_cast<unsigned>( -        std::chrono::system_clock::now().time_since_epoch().count())); -    std::shuffle(Slices.begin(), Slices.end(), MT); -  } -#endif -    // Sort the uses. This arranges for the offsets to be in ascending order,    // and the sizes to be in descending order. -  llvm::sort(Slices); +  std::stable_sort(Slices.begin(), Slices.end());  }  #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1200,7 +1184,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) {    // TODO: Allow recursive phi users.    // TODO: Allow stores.    
BasicBlock *BB = PN.getParent(); -  MaybeAlign MaxAlign; +  Align MaxAlign;    uint64_t APWidth = DL.getIndexTypeSizeInBits(PN.getType());    APInt MaxSize(APWidth, 0);    bool HaveLoad = false; @@ -1221,8 +1205,8 @@ static bool isSafePHIToSpeculate(PHINode &PN) {        if (BBI->mayWriteToMemory())          return false; -    uint64_t Size = DL.getTypeStoreSize(LI->getType()); -    MaxAlign = std::max(MaxAlign, MaybeAlign(LI->getAlignment())); +    uint64_t Size = DL.getTypeStoreSize(LI->getType()).getFixedSize(); +    MaxAlign = std::max(MaxAlign, LI->getAlign());      MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize;      HaveLoad = true;    } @@ -1273,7 +1257,7 @@ static void speculatePHINodeLoads(PHINode &PN) {    // matter which one we get and if any differ.    AAMDNodes AATags;    SomeLoad->getAAMetadata(AATags); -  const MaybeAlign Align = MaybeAlign(SomeLoad->getAlignment()); +  Align Alignment = SomeLoad->getAlign();    // Rewrite all loads of the PN to use the new PHI.    while (!PN.use_empty()) { @@ -1300,11 +1284,10 @@ static void speculatePHINodeLoads(PHINode &PN) {      Instruction *TI = Pred->getTerminator();      IRBuilderTy PredBuilder(TI); -    LoadInst *Load = PredBuilder.CreateLoad( -        LoadTy, InVal, +    LoadInst *Load = PredBuilder.CreateAlignedLoad( +        LoadTy, InVal, Alignment,          (PN.getName() + ".sroa.speculate.load." + Pred->getName()));      ++NumLoadsSpeculated; -    Load->setAlignment(Align);      if (AATags)        Load->setAAMetadata(AATags);      NewPN->addIncoming(Load, Pred); @@ -1342,10 +1325,10 @@ static bool isSafeSelectToSpeculate(SelectInst &SI) {      // absolutely (e.g. allocas) or at this point because we can see other      // accesses to it.      if (!isSafeToLoadUnconditionally(TValue, LI->getType(), -                                     MaybeAlign(LI->getAlignment()), DL, LI)) +                                     LI->getAlign(), DL, LI))        return false;      if (!isSafeToLoadUnconditionally(FValue, LI->getType(), -                                     MaybeAlign(LI->getAlignment()), DL, LI)) +                                     LI->getAlign(), DL, LI))        return false;    } @@ -1371,8 +1354,8 @@ static void speculateSelectInstLoads(SelectInst &SI) {      NumLoadsSpeculated += 2;      // Transfer alignment and AA info if present. -    TL->setAlignment(MaybeAlign(LI->getAlignment())); -    FL->setAlignment(MaybeAlign(LI->getAlignment())); +    TL->setAlignment(LI->getAlign()); +    FL->setAlignment(LI->getAlign());      AAMDNodes Tags;      LI->getAAMetadata(Tags); @@ -1479,14 +1462,15 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,    // extremely poorly defined currently. The long-term goal is to remove GEPing    // over a vector from the IR completely.    if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) { -    unsigned ElementSizeInBits = DL.getTypeSizeInBits(VecTy->getScalarType()); +    unsigned ElementSizeInBits = +        DL.getTypeSizeInBits(VecTy->getScalarType()).getFixedSize();      if (ElementSizeInBits % 8 != 0) {        // GEPs over non-multiple of 8 size vector elements are invalid.        
return nullptr;      }      APInt ElementSize(Offset.getBitWidth(), ElementSizeInBits / 8);      APInt NumSkippedElements = Offset.sdiv(ElementSize); -    if (NumSkippedElements.ugt(VecTy->getNumElements())) +    if (NumSkippedElements.ugt(cast<FixedVectorType>(VecTy)->getNumElements()))        return nullptr;      Offset -= NumSkippedElements * ElementSize;      Indices.push_back(IRB.getInt(NumSkippedElements)); @@ -1496,7 +1480,8 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,    if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {      Type *ElementTy = ArrTy->getElementType(); -    APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); +    APInt ElementSize(Offset.getBitWidth(), +                      DL.getTypeAllocSize(ElementTy).getFixedSize());      APInt NumSkippedElements = Offset.sdiv(ElementSize);      if (NumSkippedElements.ugt(ArrTy->getNumElements()))        return nullptr; @@ -1518,7 +1503,7 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,    unsigned Index = SL->getElementContainingOffset(StructOffset);    Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index));    Type *ElementTy = STy->getElementType(Index); -  if (Offset.uge(DL.getTypeAllocSize(ElementTy))) +  if (Offset.uge(DL.getTypeAllocSize(ElementTy).getFixedSize()))      return nullptr; // The offset points into alignment padding.    Indices.push_back(IRB.getInt32(Index)); @@ -1550,7 +1535,8 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL,    Type *ElementTy = Ty->getElementType();    if (!ElementTy->isSized())      return nullptr; // We can't GEP through an unsized element. -  APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); +  APInt ElementSize(Offset.getBitWidth(), +                    DL.getTypeAllocSize(ElementTy).getFixedSize());    if (ElementSize == 0)      return nullptr; // Zero-length arrays can't help us build a natural GEP.    APInt NumSkippedElements = Offset.sdiv(ElementSize); @@ -1681,20 +1667,8 @@ static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,  }  /// Compute the adjusted alignment for a load or store from an offset. -static Align getAdjustedAlignment(Instruction *I, uint64_t Offset, -                                  const DataLayout &DL) { -  MaybeAlign Alignment; -  Type *Ty; -  if (auto *LI = dyn_cast<LoadInst>(I)) { -    Alignment = MaybeAlign(LI->getAlignment()); -    Ty = LI->getType(); -  } else if (auto *SI = dyn_cast<StoreInst>(I)) { -    Alignment = MaybeAlign(SI->getAlignment()); -    Ty = SI->getValueOperand()->getType(); -  } else { -    llvm_unreachable("Only loads and stores are allowed!"); -  } -  return commonAlignment(DL.getValueOrABITypeAlignment(Alignment, Ty), Offset); +static Align getAdjustedAlignment(Instruction *I, uint64_t Offset) { +  return commonAlignment(getLoadStoreAlignment(I), Offset);  }  /// Test whether we can convert a value from the old to the new type. 
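getAdjustedAlignment above now reduces to a single commonAlignment call, which returns the largest power of two dividing both the base alignment and the byte offset. A tiny sketch of what it computes (values are illustrative):

#include "llvm/Support/Alignment.h"
using namespace llvm;

Align sliceAlign() {
  // A 16-byte-aligned access rewritten to start 4 bytes in can only be
  // assumed 4-byte aligned: min(16, largest power of two dividing 4) = 4.
  return commonAlignment(Align(16), /*Offset=*/4);
}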
@@ -1717,7 +1691,8 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {      return false;    } -  if (DL.getTypeSizeInBits(NewTy) != DL.getTypeSizeInBits(OldTy)) +  if (DL.getTypeSizeInBits(NewTy).getFixedSize() != +      DL.getTypeSizeInBits(OldTy).getFixedSize())      return false;    if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType())      return false; @@ -1728,8 +1703,15 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {    NewTy = NewTy->getScalarType();    if (NewTy->isPointerTy() || OldTy->isPointerTy()) {      if (NewTy->isPointerTy() && OldTy->isPointerTy()) { -      return cast<PointerType>(NewTy)->getPointerAddressSpace() == -        cast<PointerType>(OldTy)->getPointerAddressSpace(); +      unsigned OldAS = OldTy->getPointerAddressSpace(); +      unsigned NewAS = NewTy->getPointerAddressSpace(); +      // Convert pointers if they are pointers from the same address space or +      // different integral (not non-integral) address spaces with the same +      // pointer size. +      return OldAS == NewAS || +             (!DL.isNonIntegralAddressSpace(OldAS) && +              !DL.isNonIntegralAddressSpace(NewAS) && +              DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));      }      // We can convert integers to integral pointers, but not to non-integral @@ -1765,36 +1747,40 @@ static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V,    assert(!(isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) &&           "Integer types must be the exact same to convert."); -  // See if we need inttoptr for this type pair. A cast involving both scalars -  // and vectors requires and additional bitcast. +  // See if we need inttoptr for this type pair. May require additional bitcast.    if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) {      // Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8* -    if (OldTy->isVectorTy() && !NewTy->isVectorTy()) -      return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), -                                NewTy); -      // Expand i128 to <2 x i8*> --> i128 to <2 x i64> to <2 x i8*> -    if (!OldTy->isVectorTy() && NewTy->isVectorTy()) -      return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), -                                NewTy); - -    return IRB.CreateIntToPtr(V, NewTy); +    // Expand <4 x i32> to <2 x i8*> --> <4 x i32> to <2 x i64> to <2 x i8*> +    // Directly handle i64 to i8* +    return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)), +                              NewTy);    } -  // See if we need ptrtoint for this type pair. A cast involving both scalars -  // and vectors requires and additional bitcast. +  // See if we need ptrtoint for this type pair. May require additional bitcast.    
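The convertValue hunks here route every pointer/integer conversion through the pointer-sized integer type, so scalar/vector mismatches and (in the next hunk) differing address spaces all reduce to a bitcast plus one no-op cast. The same trick as a standalone helper, sketched under the assumption that both pointer types have the same size (name and assertion are mine, not the patch's):

#include <cassert>
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

Value *castAcrossAddrSpaces(IRBuilder<> &IRB, const DataLayout &DL,
                            Value *Ptr, PointerType *DstTy) {
  assert(DL.getPointerTypeSizeInBits(Ptr->getType()) ==
             DL.getPointerTypeSizeInBits(DstTy) &&
         "only valid when both address spaces use one pointer size");
  // ptrtoint to the matching integer width, then inttoptr into the target
  // address space; both casts are bit-preserving no-ops here.
  Value *AsInt = IRB.CreatePtrToInt(Ptr, DL.getIntPtrType(Ptr->getType()));
  return IRB.CreateIntToPtr(AsInt, DstTy);
}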
 if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) {
     // Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128
-    if (OldTy->isVectorTy() && !NewTy->isVectorTy())
-      return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
-                               NewTy);
-
-    // Expand i8* to <2 x i32> --> i8* to i64 to <2 x i32>
-    if (!OldTy->isVectorTy() && NewTy->isVectorTy())
-      return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
-                               NewTy);
+    // Expand <2 x i8*> to <4 x i32> --> <2 x i8*> to <2 x i64> to <4 x i32>
+    // Expand i8* to i64 --> i8* to i64 to i64
+    return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
+                             NewTy);
+  }
 
-    return IRB.CreatePtrToInt(V, NewTy);
+  if (OldTy->isPtrOrPtrVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
+    unsigned OldAS = OldTy->getPointerAddressSpace();
+    unsigned NewAS = NewTy->getPointerAddressSpace();
+    // To convert pointers with different address spaces (they have already
+    // been checked to be convertible, i.e. they have the same pointer size),
+    // we cannot currently use `bitcast` (which is restricted to the same
+    // address space) or `addrspacecast` (which is not always a no-op cast).
+    // Instead, use a pair of no-op `ptrtoint`/`inttoptr` casts through an
+    // integer with the same bit size.
+    if (OldAS != NewAS) {
+      assert(DL.getPointerSize(OldAS) == DL.getPointerSize(NewAS));
+      return IRB.CreateIntToPtr(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
+                                NewTy);
+    }
   }
 
   return IRB.CreateBitCast(V, NewTy);
@@ -1813,19 +1799,20 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S,
       std::max(S.beginOffset(), P.beginOffset()) - P.beginOffset();
   uint64_t BeginIndex = BeginOffset / ElementSize;
   if (BeginIndex * ElementSize != BeginOffset ||
-      BeginIndex >= Ty->getNumElements())
+      BeginIndex >= cast<FixedVectorType>(Ty)->getNumElements())
     return false;
   uint64_t EndOffset =
       std::min(S.endOffset(), P.endOffset()) - P.beginOffset();
   uint64_t EndIndex = EndOffset / ElementSize;
-  if (EndIndex * ElementSize != EndOffset || EndIndex > Ty->getNumElements())
+  if (EndIndex * ElementSize != EndOffset ||
+      EndIndex > cast<FixedVectorType>(Ty)->getNumElements())
     return false;
 
   assert(EndIndex > BeginIndex && "Empty vector!");
   uint64_t NumElements = EndIndex - BeginIndex;
   Type *SliceTy = (NumElements == 1)
                       ? Ty->getElementType()
-                      : VectorType::get(Ty->getElementType(), NumElements);
+                      : FixedVectorType::get(Ty->getElementType(), NumElements);
 
   Type *SplitIntTy =
       Type::getIntNTy(Ty->getContext(), NumElements * ElementSize * 8);
@@ -1890,7 +1877,8 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
       // Return if bitcast to vectors is different for total size in bits.
       if (!CandidateTys.empty()) {
         VectorType *V = CandidateTys[0];
-        if (DL.getTypeSizeInBits(VTy) != DL.getTypeSizeInBits(V)) {
+        if (DL.getTypeSizeInBits(VTy).getFixedSize() !=
+            DL.getTypeSizeInBits(V).getFixedSize()) {
           CandidateTys.clear();
           return;
         }
@@ -1936,13 +1924,15 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     // they're all integer vectors. 
We sort by ascending number of elements.      auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) {        (void)DL; -      assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) && +      assert(DL.getTypeSizeInBits(RHSTy).getFixedSize() == +                 DL.getTypeSizeInBits(LHSTy).getFixedSize() &&               "Cannot have vector types of different sizes!");        assert(RHSTy->getElementType()->isIntegerTy() &&               "All non-integer types eliminated!");        assert(LHSTy->getElementType()->isIntegerTy() &&               "All non-integer types eliminated!"); -      return RHSTy->getNumElements() < LHSTy->getNumElements(); +      return cast<FixedVectorType>(RHSTy)->getNumElements() < +             cast<FixedVectorType>(LHSTy)->getNumElements();      };      llvm::sort(CandidateTys, RankVectorTypes);      CandidateTys.erase( @@ -1964,13 +1954,14 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {    // Try each vector type, and return the one which works.    auto CheckVectorTypeForPromotion = [&](VectorType *VTy) { -    uint64_t ElementSize = DL.getTypeSizeInBits(VTy->getElementType()); +    uint64_t ElementSize = +        DL.getTypeSizeInBits(VTy->getElementType()).getFixedSize();      // While the definition of LLVM vectors is bitpacked, we don't support sizes      // that aren't byte sized.      if (ElementSize % 8)        return false; -    assert((DL.getTypeSizeInBits(VTy) % 8) == 0 && +    assert((DL.getTypeSizeInBits(VTy).getFixedSize() % 8) == 0 &&             "vector size not a multiple of element size?");      ElementSize /= 8; @@ -2000,7 +1991,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,                                              Type *AllocaTy,                                              const DataLayout &DL,                                              bool &WholeAllocaOp) { -  uint64_t Size = DL.getTypeStoreSize(AllocaTy); +  uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedSize();    uint64_t RelBegin = S.beginOffset() - AllocBeginOffset;    uint64_t RelEnd = S.endOffset() - AllocBeginOffset; @@ -2016,7 +2007,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,      if (LI->isVolatile())        return false;      // We can't handle loads that extend past the allocated memory. -    if (DL.getTypeStoreSize(LI->getType()) > Size) +    if (DL.getTypeStoreSize(LI->getType()).getFixedSize() > Size)        return false;      // So far, AllocaSliceRewriter does not support widening split slice tails      // in rewriteIntegerLoad. @@ -2028,7 +2019,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,      if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size)        WholeAllocaOp = true;      if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) { -      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy)) +      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize())          return false;      } else if (RelBegin != 0 || RelEnd != Size ||                 !canConvertValue(DL, AllocaTy, LI->getType())) { @@ -2041,7 +2032,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,      if (SI->isVolatile())        return false;      // We can't handle stores that extend past the allocated memory. 
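The bitwidth-versus-store-size comparisons in this hunk are what reject bit-padded integer types from widening: if a type occupies more stored bits than its nominal width, treating the whole alloca as one integer would invent padding bits. A small sketch with the default DataLayout (helper name invented):

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
using namespace llvm;

bool rejectBitPadded(LLVMContext &Ctx) {
  DataLayout DL("");
  Type *I17 = Type::getIntNTy(Ctx, 17);
  // i17 is stored in 3 bytes, i.e. 24 bits, so 17 < 24 and the slice is
  // not viable for integer widening.
  return I17->getIntegerBitWidth() <
         DL.getTypeStoreSizeInBits(I17).getFixedSize();
}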
-    if (DL.getTypeStoreSize(ValueTy) > Size) +    if (DL.getTypeStoreSize(ValueTy).getFixedSize() > Size)        return false;      // So far, AllocaSliceRewriter does not support widening split slice tails      // in rewriteIntegerStore. @@ -2053,7 +2044,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,      if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size)        WholeAllocaOp = true;      if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) { -      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy)) +      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize())          return false;      } else if (RelBegin != 0 || RelEnd != Size ||                 !canConvertValue(DL, ValueTy, AllocaTy)) { @@ -2084,13 +2075,13 @@ static bool isIntegerWideningViableForSlice(const Slice &S,  /// promote the resulting alloca.  static bool isIntegerWideningViable(Partition &P, Type *AllocaTy,                                      const DataLayout &DL) { -  uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy); +  uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedSize();    // Don't create integer types larger than the maximum bitwidth.    if (SizeInBits > IntegerType::MAX_INT_BITS)      return false;    // Don't try to handle allocas with bit-padding. -  if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy)) +  if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedSize())      return false;    // We need to ensure that an integer type with the appropriate bitwidth can @@ -2129,11 +2120,13 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V,                               const Twine &Name) {    LLVM_DEBUG(dbgs() << "       start: " << *V << "\n");    IntegerType *IntTy = cast<IntegerType>(V->getType()); -  assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && +  assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <= +             DL.getTypeStoreSize(IntTy).getFixedSize() &&           "Element extends past full value");    uint64_t ShAmt = 8 * Offset;    if (DL.isBigEndian()) -    ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); +    ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() - +                 DL.getTypeStoreSize(Ty).getFixedSize() - Offset);    if (ShAmt) {      V = IRB.CreateLShr(V, ShAmt, Name + ".shift");      LLVM_DEBUG(dbgs() << "     shifted: " << *V << "\n"); @@ -2158,11 +2151,13 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old,      V = IRB.CreateZExt(V, IntTy, Name + ".ext");      LLVM_DEBUG(dbgs() << "    extended: " << *V << "\n");    } -  assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && +  assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <= +             DL.getTypeStoreSize(IntTy).getFixedSize() &&           "Element store outside of alloca store");    uint64_t ShAmt = 8 * Offset;    if (DL.isBigEndian()) -    ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); +    ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() - +                 DL.getTypeStoreSize(Ty).getFixedSize() - Offset);    if (ShAmt) {      V = IRB.CreateShl(V, ShAmt, Name + ".shift");      LLVM_DEBUG(dbgs() << "     shifted: " << *V << "\n"); @@ -2180,7 +2175,7 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old,  static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex,                              unsigned EndIndex, const 
Twine &Name) { -  VectorType *VecTy = cast<VectorType>(V->getType()); +  auto *VecTy = cast<FixedVectorType>(V->getType());    unsigned NumElements = EndIndex - BeginIndex;    assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); @@ -2194,12 +2189,12 @@ static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex,      return V;    } -  SmallVector<Constant *, 8> Mask; +  SmallVector<int, 8> Mask;    Mask.reserve(NumElements);    for (unsigned i = BeginIndex; i != EndIndex; ++i) -    Mask.push_back(IRB.getInt32(i)); -  V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), -                              ConstantVector::get(Mask), Name + ".extract"); +    Mask.push_back(i); +  V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), Mask, +                              Name + ".extract");    LLVM_DEBUG(dbgs() << "     shuffle: " << *V << "\n");    return V;  } @@ -2218,21 +2213,23 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,      return V;    } -  assert(Ty->getNumElements() <= VecTy->getNumElements() && +  assert(cast<FixedVectorType>(Ty)->getNumElements() <= +             cast<FixedVectorType>(VecTy)->getNumElements() &&           "Too many elements!"); -  if (Ty->getNumElements() == VecTy->getNumElements()) { +  if (cast<FixedVectorType>(Ty)->getNumElements() == +      cast<FixedVectorType>(VecTy)->getNumElements()) {      assert(V->getType() == VecTy && "Vector type mismatch");      return V;    } -  unsigned EndIndex = BeginIndex + Ty->getNumElements(); +  unsigned EndIndex = BeginIndex + cast<FixedVectorType>(Ty)->getNumElements();    // When inserting a smaller vector into the larger to store, we first    // use a shuffle vector to widen it with undef elements, and then    // a second shuffle vector to select between the loaded vector and the    // incoming vector.    SmallVector<Constant *, 8> Mask; -  Mask.reserve(VecTy->getNumElements()); -  for (unsigned i = 0; i != VecTy->getNumElements(); ++i) +  Mask.reserve(cast<FixedVectorType>(VecTy)->getNumElements()); +  for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i)      if (i >= BeginIndex && i < EndIndex)        Mask.push_back(IRB.getInt32(i - BeginIndex));      else @@ -2242,7 +2239,7 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V,    LLVM_DEBUG(dbgs() << "    shuffle: " << *V << "\n");    Mask.clear(); -  for (unsigned i = 0; i != VecTy->getNumElements(); ++i) +  for (unsigned i = 0; i != cast<FixedVectorType>(VecTy)->getNumElements(); ++i)      Mask.push_back(IRB.getInt1(i >= BeginIndex && i < EndIndex));    V = IRB.CreateSelect(ConstantVector::get(Mask), V, Old, Name + "blend"); @@ -2325,18 +2322,20 @@ public:          NewAllocaBeginOffset(NewAllocaBeginOffset),          NewAllocaEndOffset(NewAllocaEndOffset),          NewAllocaTy(NewAI.getAllocatedType()), -        IntTy(IsIntegerPromotable -                  ? Type::getIntNTy( -                        NewAI.getContext(), -                        DL.getTypeSizeInBits(NewAI.getAllocatedType())) -                  : nullptr), +        IntTy( +            IsIntegerPromotable +                ? Type::getIntNTy(NewAI.getContext(), +                                  DL.getTypeSizeInBits(NewAI.getAllocatedType()) +                                      .getFixedSize()) +                : nullptr),          VecTy(PromotableVecTy),          ElementTy(VecTy ? VecTy->getElementType() : nullptr), -        ElementSize(VecTy ? 
DL.getTypeSizeInBits(ElementTy) / 8 : 0), +        ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8 +                          : 0),          PHIUsers(PHIUsers), SelectUsers(SelectUsers),          IRB(NewAI.getContext(), ConstantFolder()) {      if (VecTy) { -      assert((DL.getTypeSizeInBits(ElementTy) % 8) == 0 && +      assert((DL.getTypeSizeInBits(ElementTy).getFixedSize() % 8) == 0 &&               "Only multiple-of-8 sized vector elements are viable");        ++NumVectorized;      } @@ -2368,7 +2367,8 @@ public:      Instruction *OldUserI = cast<Instruction>(OldUse->getUser());      IRB.SetInsertPoint(OldUserI);      IRB.SetCurrentDebugLocation(OldUserI->getDebugLoc()); -    IRB.SetNamePrefix(Twine(NewAI.getName()) + "." + Twine(BeginOffset) + "."); +    IRB.getInserter().SetNamePrefix( +        Twine(NewAI.getName()) + "." + Twine(BeginOffset) + ".");      CanSROA &= visit(cast<Instruction>(OldUse->getUser()));      if (VecTy || IntTy) @@ -2429,14 +2429,9 @@ private:    ///    /// You can optionally pass a type to this routine and if that type's ABI    /// alignment is itself suitable, this will return zero. -  MaybeAlign getSliceAlign(Type *Ty = nullptr) { -    const MaybeAlign NewAIAlign = DL.getValueOrABITypeAlignment( -        MaybeAlign(NewAI.getAlignment()), NewAI.getAllocatedType()); -    const MaybeAlign Align = -        commonAlignment(NewAIAlign, NewBeginOffset - NewAllocaBeginOffset); -    return (Ty && Align && Align->value() == DL.getABITypeAlignment(Ty)) -               ? None -               : Align; +  Align getSliceAlign() { +    return commonAlignment(NewAI.getAlign(), +                           NewBeginOffset - NewAllocaBeginOffset);    }    unsigned getIndex(uint64_t Offset) { @@ -2460,7 +2455,7 @@ private:      assert(EndIndex > BeginIndex && "Empty vector!");      Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                     NewAI.getAlignment(), "load"); +                                     NewAI.getAlign(), "load");      return extractVector(IRB, V, BeginIndex, EndIndex, "vec");    } @@ -2468,7 +2463,7 @@ private:      assert(IntTy && "We cannot insert an integer to the alloca");      assert(!LI.isVolatile());      Value *V = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                     NewAI.getAlignment(), "load"); +                                     NewAI.getAlign(), "load");      V = convertValue(DL, IRB, V, IntTy);      assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset");      uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; @@ -2500,7 +2495,8 @@ private:      Type *TargetTy = IsSplit ? 
Type::getIntNTy(LI.getContext(), SliceSize * 8)                               : LI.getType(); -    const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize; +    const bool IsLoadPastEnd = +        DL.getTypeStoreSize(TargetTy).getFixedSize() > SliceSize;      bool IsPtrAdjusted = false;      Value *V;      if (VecTy) { @@ -2513,12 +2509,14 @@ private:                  (IsLoadPastEnd && NewAllocaTy->isIntegerTy() &&                   TargetTy->isIntegerTy()))) {        LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                              NewAI.getAlignment(), -                                              LI.isVolatile(), LI.getName()); +                                              NewAI.getAlign(), LI.isVolatile(), +                                              LI.getName());        if (AATags)          NewLI->setAAMetadata(AATags);        if (LI.isVolatile())          NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); +      if (NewLI->isAtomic()) +        NewLI->setAlignment(LI.getAlign());        // Any !nonnull metadata or !range metadata on the old load is also valid        // on the new load. This is even true in some cases even when the loads @@ -2549,9 +2547,9 @@ private:            }      } else {        Type *LTy = TargetTy->getPointerTo(AS); -      LoadInst *NewLI = IRB.CreateAlignedLoad( -          TargetTy, getNewAllocaSlicePtr(IRB, LTy), getSliceAlign(TargetTy), -          LI.isVolatile(), LI.getName()); +      LoadInst *NewLI = +          IRB.CreateAlignedLoad(TargetTy, getNewAllocaSlicePtr(IRB, LTy), +                                getSliceAlign(), LI.isVolatile(), LI.getName());        if (AATags)          NewLI->setAAMetadata(AATags);        if (LI.isVolatile()) @@ -2566,7 +2564,7 @@ private:        assert(!LI.isVolatile());        assert(LI.getType()->isIntegerTy() &&               "Only integer type loads and stores are split"); -      assert(SliceSize < DL.getTypeStoreSize(LI.getType()) && +      assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedSize() &&               "Split load isn't smaller than original load");        assert(DL.typeSizeEqualsStoreSize(LI.getType()) &&               "Non-byte-multiple bit width"); @@ -2577,7 +2575,8 @@ private:        // the computed value, and then replace the placeholder with LI, leaving        // LI only used for this computation.        Value *Placeholder = new LoadInst( -          LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS))); +          LI.getType(), UndefValue::get(LI.getType()->getPointerTo(AS)), "", +          false, Align(1));        V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset,                          "insert");        LI.replaceAllUsesWith(V); @@ -2600,19 +2599,20 @@ private:        unsigned EndIndex = getIndex(NewEndOffset);        assert(EndIndex > BeginIndex && "Empty vector!");        unsigned NumElements = EndIndex - BeginIndex; -      assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); +      assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() && +             "Too many elements!");        Type *SliceTy = (NumElements == 1)                            ? ElementTy -                          : VectorType::get(ElementTy, NumElements); +                          : FixedVectorType::get(ElementTy, NumElements);        if (V->getType() != SliceTy)          V = convertValue(DL, IRB, V, SliceTy);        // Mix in the existing elements.        
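Two hunks in this region change behavior rather than just spelling: a rewritten atomic load (and, symmetrically, an atomic store below) re-applies the original access alignment instead of the recomputed slice alignment, presumably because the slice alignment may be weaker than what the atomic operation requires. The shape of the pattern, condensed into a sketch (IRBuilderTy and the names come from this file; the wrapper itself is illustrative):

static LoadInst *rewriteLoadKeepingAtomicAlign(IRBuilderTy &IRB,
                                               AllocaInst &NewAI,
                                               LoadInst &LI) {
  LoadInst *NewLI = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
                                          NewAI.getAlign(), LI.isVolatile(),
                                          LI.getName());
  if (LI.isVolatile())
    NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
  if (NewLI->isAtomic())
    NewLI->setAlignment(LI.getAlign()); // keep the original atomic alignment
  return NewLI;
}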
Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                         NewAI.getAlignment(), "load"); +                                         NewAI.getAlign(), "load");        V = insertVector(IRB, Old, V, BeginIndex, "vec");      } -    StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); +    StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign());      if (AATags)        Store->setAAMetadata(AATags);      Pass.DeadInsts.insert(&SI); @@ -2624,16 +2624,17 @@ private:    bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) {      assert(IntTy && "We cannot extract an integer from the alloca");      assert(!SI.isVolatile()); -    if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) { +    if (DL.getTypeSizeInBits(V->getType()).getFixedSize() != +        IntTy->getBitWidth()) {        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                         NewAI.getAlignment(), "oldload"); +                                         NewAI.getAlign(), "oldload");        Old = convertValue(DL, IRB, Old, IntTy);        assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset");        uint64_t Offset = BeginOffset - NewAllocaBeginOffset;        V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, "insert");      }      V = convertValue(DL, IRB, V, NewAllocaTy); -    StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); +    StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign());      Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,                               LLVMContext::MD_access_group});      if (AATags) @@ -2659,7 +2660,7 @@ private:        if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets()))          Pass.PostPromotionWorklist.insert(AI); -    if (SliceSize < DL.getTypeStoreSize(V->getType())) { +    if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedSize()) {        assert(!SI.isVolatile());        assert(V->getType()->isIntegerTy() &&               "Only integer type loads and stores are split"); @@ -2675,7 +2676,8 @@ private:      if (IntTy && V->getType()->isIntegerTy())        return rewriteIntegerStore(V, SI, AATags); -    const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize; +    const bool IsStorePastEnd = +        DL.getTypeStoreSize(V->getType()).getFixedSize() > SliceSize;      StoreInst *NewSI;      if (NewBeginOffset == NewAllocaBeginOffset &&          NewEndOffset == NewAllocaEndOffset && @@ -2695,13 +2697,13 @@ private:            }        V = convertValue(DL, IRB, V, NewAllocaTy); -      NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), -                                     SI.isVolatile()); +      NewSI = +          IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), SI.isVolatile());      } else {        unsigned AS = SI.getPointerAddressSpace();        Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS)); -      NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()), -                                     SI.isVolatile()); +      NewSI = +          IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(), SI.isVolatile());      }      NewSI->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access,                               LLVMContext::MD_access_group}); @@ -2709,6 +2711,8 @@ private:        NewSI->setAAMetadata(AATags);      if (SI.isVolatile())        
NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); +    if (NewSI->isAtomic()) +      NewSI->setAlignment(SI.getAlign());      Pass.DeadInsts.insert(&SI);      deleteIfTriviallyDead(OldOp); @@ -2786,9 +2790,9 @@ private:          return false;        const auto Len = C->getZExtValue();        auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext()); -      auto *SrcTy = VectorType::get(Int8Ty, Len); +      auto *SrcTy = FixedVectorType::get(Int8Ty, Len);        return canConvertValue(DL, SrcTy, AllocaTy) && -        DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy)); +             DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedSize());      }();      // If this doesn't map cleanly onto the alloca type, and that type isn't @@ -2820,16 +2824,17 @@ private:        unsigned EndIndex = getIndex(NewEndOffset);        assert(EndIndex > BeginIndex && "Empty vector!");        unsigned NumElements = EndIndex - BeginIndex; -      assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); +      assert(NumElements <= cast<FixedVectorType>(VecTy)->getNumElements() && +             "Too many elements!"); -      Value *Splat = -          getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ElementTy) / 8); +      Value *Splat = getIntegerSplat( +          II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8);        Splat = convertValue(DL, IRB, Splat, ElementTy);        if (NumElements > 1)          Splat = getVectorSplat(Splat, NumElements);        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                         NewAI.getAlignment(), "oldload"); +                                         NewAI.getAlign(), "oldload");        V = insertVector(IRB, Old, Splat, BeginIndex, "vec");      } else if (IntTy) {        // If this is a memset on an alloca where we can widen stores, insert the @@ -2842,7 +2847,7 @@ private:        if (IntTy && (BeginOffset != NewAllocaBeginOffset ||                      EndOffset != NewAllocaBeginOffset)) {          Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                           NewAI.getAlignment(), "oldload"); +                                           NewAI.getAlign(), "oldload");          Old = convertValue(DL, IRB, Old, IntTy);          uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;          V = insertInteger(DL, IRB, Old, V, Offset, "insert"); @@ -2856,15 +2861,17 @@ private:        assert(NewBeginOffset == NewAllocaBeginOffset);        assert(NewEndOffset == NewAllocaEndOffset); -      V = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ScalarTy) / 8); +      V = getIntegerSplat(II.getValue(), +                          DL.getTypeSizeInBits(ScalarTy).getFixedSize() / 8);        if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy)) -        V = getVectorSplat(V, AllocaVecTy->getNumElements()); +        V = getVectorSplat( +            V, cast<FixedVectorType>(AllocaVecTy)->getNumElements());        V = convertValue(DL, IRB, V, AllocaTy);      } -    StoreInst *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(), -                                            II.isVolatile()); +    StoreInst *New = +        IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlign(), II.isVolatile());      if (AATags)        New->setAAMetadata(AATags);      LLVM_DEBUG(dbgs() << "          to: " << *New << "\n"); @@ -2919,7 +2926,8 @@ private:      bool EmitMemCpy =          !VecTy && !IntTy &&          (BeginOffset > NewAllocaBeginOffset || 
EndOffset < NewAllocaEndOffset || -         SliceSize != DL.getTypeStoreSize(NewAI.getAllocatedType()) || +         SliceSize != +             DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedSize() ||           !NewAI.getAllocatedType()->isSingleValueType());      // If we're just going to emit a memcpy, the alloca hasn't changed, and the @@ -2955,7 +2963,7 @@ private:      unsigned OffsetWidth = DL.getIndexSizeInBits(OtherAS);      APInt OtherOffset(OffsetWidth, NewBeginOffset - BeginOffset);      Align OtherAlign = -        assumeAligned(IsDest ? II.getSourceAlignment() : II.getDestAlignment()); +        (IsDest ? II.getSourceAlign() : II.getDestAlign()).valueOrOne();      OtherAlign =          commonAlignment(OtherAlign, OtherOffset.zextOrTrunc(64).getZExtValue()); @@ -3007,7 +3015,7 @@ private:        if (NumElements == 1)          OtherTy = VecTy->getElementType();        else -        OtherTy = VectorType::get(VecTy->getElementType(), NumElements); +        OtherTy = FixedVectorType::get(VecTy->getElementType(), NumElements);      } else if (IntTy && !IsWholeAlloca) {        OtherTy = SubIntTy;      } else { @@ -3028,11 +3036,11 @@ private:      Value *Src;      if (VecTy && !IsWholeAlloca && !IsDest) {        Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                  NewAI.getAlignment(), "load"); +                                  NewAI.getAlign(), "load");        Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec");      } else if (IntTy && !IsWholeAlloca && !IsDest) {        Src = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                  NewAI.getAlignment(), "load"); +                                  NewAI.getAlign(), "load");        Src = convertValue(DL, IRB, Src, IntTy);        uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;        Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract"); @@ -3046,11 +3054,11 @@ private:      if (VecTy && !IsWholeAlloca && IsDest) {        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                         NewAI.getAlignment(), "oldload"); +                                         NewAI.getAlign(), "oldload");        Src = insertVector(IRB, Old, Src, BeginIndex, "vec");      } else if (IntTy && !IsWholeAlloca && IsDest) {        Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI, -                                         NewAI.getAlignment(), "oldload"); +                                         NewAI.getAlign(), "oldload");        Old = convertValue(DL, IRB, Old, IntTy);        uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;        Src = insertInteger(DL, IRB, Old, Src, Offset, "insert"); @@ -3115,17 +3123,12 @@ private:        Instruction *I = Uses.pop_back_val();        if (LoadInst *LI = dyn_cast<LoadInst>(I)) { -        MaybeAlign LoadAlign = DL.getValueOrABITypeAlignment( -            MaybeAlign(LI->getAlignment()), LI->getType()); -        LI->setAlignment(std::min(LoadAlign, getSliceAlign())); +        LI->setAlignment(std::min(LI->getAlign(), getSliceAlign()));          continue;        }        if (StoreInst *SI = dyn_cast<StoreInst>(I)) { -          Value *Op = SI->getOperand(0); -          MaybeAlign StoreAlign = DL.getValueOrABITypeAlignment( -              MaybeAlign(SI->getAlignment()), Op->getType()); -          SI->setAlignment(std::min(StoreAlign, getSliceAlign())); -          continue; +        SI->setAlignment(std::min(SI->getAlign(), getSliceAlign())); +   
     continue;        }        assert(isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I) || @@ -3146,14 +3149,14 @@ private:      // as local as possible to the PHI. To do that, we re-use the location of      // the old pointer, which necessarily must be in the right position to      // dominate the PHI. -    IRBuilderTy PtrBuilder(IRB); +    IRBuilderBase::InsertPointGuard Guard(IRB);      if (isa<PHINode>(OldPtr)) -      PtrBuilder.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt()); +      IRB.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt());      else -      PtrBuilder.SetInsertPoint(OldPtr); -    PtrBuilder.SetCurrentDebugLocation(OldPtr->getDebugLoc()); +      IRB.SetInsertPoint(OldPtr); +    IRB.SetCurrentDebugLocation(OldPtr->getDebugLoc()); -    Value *NewPtr = getNewAllocaSlicePtr(PtrBuilder, OldPtr->getType()); +    Value *NewPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());      // Replace the operands which were using the old pointer.      std::replace(PN.op_begin(), PN.op_end(), cast<Value>(OldPtr), NewPtr); @@ -3357,7 +3360,7 @@ private:        Value *GEP =            IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep");        LoadInst *Load = -          IRB.CreateAlignedLoad(Ty, GEP, Alignment.value(), Name + ".load"); +          IRB.CreateAlignedLoad(Ty, GEP, Alignment, Name + ".load");        if (AATags)          Load->setAAMetadata(AATags);        Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); @@ -3375,9 +3378,10 @@ private:      AAMDNodes AATags;      LI.getAAMetadata(AATags);      LoadOpSplitter Splitter(&LI, *U, LI.getType(), AATags, -                            getAdjustedAlignment(&LI, 0, DL), DL); +                            getAdjustedAlignment(&LI, 0), DL);      Value *V = UndefValue::get(LI.getType());      Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); +    Visited.erase(&LI);      LI.replaceAllUsesWith(V);      LI.eraseFromParent();      return true; @@ -3403,7 +3407,7 @@ private:        Value *InBoundsGEP =            IRB.CreateInBoundsGEP(BaseTy, Ptr, GEPIndices, Name + ".gep");        StoreInst *Store = -          IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment.value()); +          IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Alignment);        if (AATags)          Store->setAAMetadata(AATags);        LLVM_DEBUG(dbgs() << "          to: " << *Store << "\n"); @@ -3422,8 +3426,9 @@ private:      AAMDNodes AATags;      SI.getAAMetadata(AATags);      StoreOpSplitter Splitter(&SI, *U, V->getType(), AATags, -                             getAdjustedAlignment(&SI, 0, DL), DL); +                             getAdjustedAlignment(&SI, 0), DL);      Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); +    Visited.erase(&SI);      SI.eraseFromParent();      return true;    } @@ -3438,7 +3443,110 @@ private:      return false;    } +  // Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2) +  bool foldGEPSelect(GetElementPtrInst &GEPI) { +    if (!GEPI.hasAllConstantIndices()) +      return false; + +    SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand()); + +    LLVM_DEBUG(dbgs() << "  Rewriting gep(select) -> select(gep):" +                      << "\n    original: " << *Sel +                      << "\n              " << GEPI); + +    IRBuilderTy Builder(&GEPI); +    SmallVector<Value *, 4> Index(GEPI.idx_begin(), GEPI.idx_end()); +    bool IsInBounds = GEPI.isInBounds(); + +    Value *True = Sel->getTrueValue(); +    Value *NTrue = +        
IsInBounds +            ? Builder.CreateInBoundsGEP(True, Index, +                                        True->getName() + ".sroa.gep") +            : Builder.CreateGEP(True, Index, True->getName() + ".sroa.gep"); + +    Value *False = Sel->getFalseValue(); + +    Value *NFalse = +        IsInBounds +            ? Builder.CreateInBoundsGEP(False, Index, +                                        False->getName() + ".sroa.gep") +            : Builder.CreateGEP(False, Index, False->getName() + ".sroa.gep"); + +    Value *NSel = Builder.CreateSelect(Sel->getCondition(), NTrue, NFalse, +                                       Sel->getName() + ".sroa.sel"); +    Visited.erase(&GEPI); +    GEPI.replaceAllUsesWith(NSel); +    GEPI.eraseFromParent(); +    Instruction *NSelI = cast<Instruction>(NSel); +    Visited.insert(NSelI); +    enqueueUsers(*NSelI); + +    LLVM_DEBUG(dbgs() << "\n          to: " << *NTrue +                      << "\n              " << *NFalse +                      << "\n              " << *NSel << '\n'); + +    return true; +  } + +  // Fold gep (phi ptr1, ptr2) => phi gep(ptr1), gep(ptr2) +  bool foldGEPPhi(GetElementPtrInst &GEPI) { +    if (!GEPI.hasAllConstantIndices()) +      return false; + +    PHINode *PHI = cast<PHINode>(GEPI.getPointerOperand()); +    if (GEPI.getParent() != PHI->getParent() || +        llvm::any_of(PHI->incoming_values(), [](Value *In) +          { Instruction *I = dyn_cast<Instruction>(In); +            return !I || isa<GetElementPtrInst>(I) || isa<PHINode>(I) || +                   succ_empty(I->getParent()) || +                   !I->getParent()->isLegalToHoistInto(); +          })) +      return false; + +    LLVM_DEBUG(dbgs() << "  Rewriting gep(phi) -> phi(gep):" +                      << "\n    original: " << *PHI +                      << "\n              " << GEPI +                      << "\n          to: "); + +    SmallVector<Value *, 4> Index(GEPI.idx_begin(), GEPI.idx_end()); +    bool IsInBounds = GEPI.isInBounds(); +    IRBuilderTy PHIBuilder(GEPI.getParent()->getFirstNonPHI()); +    PHINode *NewPN = PHIBuilder.CreatePHI(GEPI.getType(), +                                          PHI->getNumIncomingValues(), +                                          PHI->getName() + ".sroa.phi"); +    for (unsigned I = 0, E = PHI->getNumIncomingValues(); I != E; ++I) { +      Instruction *In = cast<Instruction>(PHI->getIncomingValue(I)); + +      IRBuilderTy B(In->getParent(), std::next(In->getIterator())); +      Value *NewVal = IsInBounds +          ? 
B.CreateInBoundsGEP(In, Index, In->getName() + ".sroa.gep") +          : B.CreateGEP(In, Index, In->getName() + ".sroa.gep"); +      NewPN->addIncoming(NewVal, PHI->getIncomingBlock(I)); +    } + +    Visited.erase(&GEPI); +    GEPI.replaceAllUsesWith(NewPN); +    GEPI.eraseFromParent(); +    Visited.insert(NewPN); +    enqueueUsers(*NewPN); + +    LLVM_DEBUG(for (Value *In : NewPN->incoming_values()) +                 dbgs() << "\n              " << *In; +               dbgs() << "\n              " << *NewPN << '\n'); + +    return true; +  } +    bool visitGetElementPtrInst(GetElementPtrInst &GEPI) { +    if (isa<SelectInst>(GEPI.getPointerOperand()) && +        foldGEPSelect(GEPI)) +      return true; + +    if (isa<PHINode>(GEPI.getPointerOperand()) && +        foldGEPPhi(GEPI)) +      return true; +      enqueueUsers(GEPI);      return false;    } @@ -3465,8 +3573,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {    if (Ty->isSingleValueType())      return Ty; -  uint64_t AllocSize = DL.getTypeAllocSize(Ty); -  uint64_t TypeSize = DL.getTypeSizeInBits(Ty); +  uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedSize(); +  uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedSize();    Type *InnerTy;    if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { @@ -3479,8 +3587,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {      return Ty;    } -  if (AllocSize > DL.getTypeAllocSize(InnerTy) || -      TypeSize > DL.getTypeSizeInBits(InnerTy)) +  if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedSize() || +      TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedSize())      return Ty;    return stripAggregateTypeWrapping(DL, InnerTy); @@ -3501,17 +3609,28 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {  /// return a type if necessary.  static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,                                uint64_t Size) { -  if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size) +  if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedSize() == Size)      return stripAggregateTypeWrapping(DL, Ty); -  if (Offset > DL.getTypeAllocSize(Ty) || -      (DL.getTypeAllocSize(Ty) - Offset) < Size) +  if (Offset > DL.getTypeAllocSize(Ty).getFixedSize() || +      (DL.getTypeAllocSize(Ty).getFixedSize() - Offset) < Size)      return nullptr; -  if (SequentialType *SeqTy = dyn_cast<SequentialType>(Ty)) { -    Type *ElementTy = SeqTy->getElementType(); -    uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); +  if (isa<ArrayType>(Ty) || isa<VectorType>(Ty)) { +     Type *ElementTy; +     uint64_t TyNumElements; +     if (auto *AT = dyn_cast<ArrayType>(Ty)) { +       ElementTy = AT->getElementType(); +       TyNumElements = AT->getNumElements(); +     } else { +       // FIXME: This isn't right for vectors with non-byte-sized or +       // non-power-of-two sized elements. 
+       auto *VT = cast<FixedVectorType>(Ty); +       ElementTy = VT->getElementType(); +       TyNumElements = VT->getNumElements(); +    } +    uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize();      uint64_t NumSkippedElements = Offset / ElementSize; -    if (NumSkippedElements >= SeqTy->getNumElements()) +    if (NumSkippedElements >= TyNumElements)        return nullptr;      Offset -= NumSkippedElements * ElementSize; @@ -3549,7 +3668,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,    Offset -= SL->getElementOffset(Index);    Type *ElementTy = STy->getElementType(Index); -  uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); +  uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize();    if (Offset >= ElementSize)      return nullptr; // The offset points into alignment padding. @@ -3860,7 +3979,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {            getAdjustedPtr(IRB, DL, BasePtr,                           APInt(DL.getIndexSizeInBits(AS), PartOffset),                           PartPtrTy, BasePtr->getName() + "."), -          getAdjustedAlignment(LI, PartOffset, DL).value(), +          getAdjustedAlignment(LI, PartOffset),            /*IsVolatile*/ false, LI->getName());        PLoad->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access,                                  LLVMContext::MD_access_group}); @@ -3918,7 +4037,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {              getAdjustedPtr(IRB, DL, StoreBasePtr,                             APInt(DL.getIndexSizeInBits(AS), PartOffset),                             PartPtrTy, StoreBasePtr->getName() + "."), -            getAdjustedAlignment(SI, PartOffset, DL).value(), +            getAdjustedAlignment(SI, PartOffset),              /*IsVolatile*/ false);          PStore->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access,                                     LLVMContext::MD_access_group}); @@ -4003,7 +4122,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {              getAdjustedPtr(IRB, DL, LoadBasePtr,                             APInt(DL.getIndexSizeInBits(AS), PartOffset),                             LoadPartPtrTy, LoadBasePtr->getName() + "."), -            getAdjustedAlignment(LI, PartOffset, DL).value(), +            getAdjustedAlignment(LI, PartOffset),              /*IsVolatile*/ false, LI->getName());        } @@ -4015,7 +4134,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) {            getAdjustedPtr(IRB, DL, StoreBasePtr,                           APInt(DL.getIndexSizeInBits(AS), PartOffset),                           StorePartPtrTy, StoreBasePtr->getName() + "."), -          getAdjustedAlignment(SI, PartOffset, DL).value(), +          getAdjustedAlignment(SI, PartOffset),            /*IsVolatile*/ false);        // Now build a new slice for the alloca. 
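The new foldGEPSelect/foldGEPPhi rewrites above push constant-index GEPs through selects and phis so that the alloca's users become plain GEP chains the slice builder understands. The core of the select case, condensed (illustrative helper using this file's IRBuilderTy; the patch additionally derives names and handles the non-inbounds form):

static Value *foldGEPOverSelect(IRBuilderTy &B, SelectInst *Sel,
                                ArrayRef<Value *> Idx) {
  // gep(select(c, a, b), idx)  ==>  select(c, gep(a, idx), gep(b, idx))
  Value *T = B.CreateInBoundsGEP(Sel->getTrueValue(), Idx);
  Value *F = B.CreateInBoundsGEP(Sel->getFalseValue(), Idx);
  return B.CreateSelect(Sel->getCondition(), T, F);
}

foldGEPPhi is the same rewrite applied once per incoming value, with the extra legality checks visible above (same block, hoistable incoming blocks) because the new GEPs must be materialized next to each incoming definition.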
@@ -4117,7 +4236,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,    Type *SliceTy = nullptr;    const DataLayout &DL = AI.getModule()->getDataLayout();    if (Type *CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset())) -    if (DL.getTypeAllocSize(CommonUseTy) >= P.size()) +    if (DL.getTypeAllocSize(CommonUseTy).getFixedSize() >= P.size())        SliceTy = CommonUseTy;    if (!SliceTy)      if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(), @@ -4129,7 +4248,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,      SliceTy = Type::getIntNTy(*C, P.size() * 8);    if (!SliceTy)      SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size()); -  assert(DL.getTypeAllocSize(SliceTy) >= P.size()); +  assert(DL.getTypeAllocSize(SliceTy).getFixedSize() >= P.size());    bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL); @@ -4151,19 +4270,14 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,      // FIXME: We might want to defer PHI speculation until after here.      // FIXME: return nullptr;    } else { -    // If alignment is unspecified we fallback on the one required by the ABI -    // for this type. We also make sure the alignment is compatible with -    // P.beginOffset(). -    const Align Alignment = commonAlignment( -        DL.getValueOrABITypeAlignment(MaybeAlign(AI.getAlignment()), -                                      AI.getAllocatedType()), -        P.beginOffset()); +    // Make sure the alignment is compatible with P.beginOffset(). +    const Align Alignment = commonAlignment(AI.getAlign(), P.beginOffset());      // If we will get at least this much alignment from the type alone, leave      // the alloca's alignment unconstrained. -    const bool IsUnconstrained = Alignment <= DL.getABITypeAlignment(SliceTy); +    const bool IsUnconstrained = Alignment <= DL.getABITypeAlign(SliceTy);      NewAI = new AllocaInst(          SliceTy, AI.getType()->getAddressSpace(), nullptr, -        IsUnconstrained ? MaybeAlign() : Alignment, +        IsUnconstrained ? DL.getPrefTypeAlign(SliceTy) : Alignment,          AI.getName() + ".sroa." + Twine(P.begin() - AS.begin()), &AI);      // Copy the old AI debug location over to the new one.      NewAI->setDebugLoc(AI.getDebugLoc()); @@ -4270,7 +4384,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {    // to be rewritten into a partition.    bool IsSorted = true; -  uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType()); +  uint64_t AllocaSize = +      DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize();    const uint64_t MaxBitVectorSize = 1024;    if (AllocaSize <= MaxBitVectorSize) {      // If a byte boundary is included in any load or store, a slice starting or @@ -4334,7 +4449,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {        Changed = true;        if (NewAI != &AI) {          uint64_t SizeOfByte = 8; -        uint64_t AllocaSize = DL.getTypeSizeInBits(NewAI->getAllocatedType()); +        uint64_t AllocaSize = +            DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedSize();          // Don't include any padding.          
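The alignment logic in rewritePartition above is now a two-step decision: clamp the parent alloca's alignment to what the partition's byte offset can support, then fall back to the preferred type alignment when the type alone already guarantees as much. A sketch of the first step (commonAlignment is the real utility from llvm/Support/Alignment.h; the wrapper is illustrative):

static Align sliceAlignment(Align AllocaAlign, uint64_t BeginOffset) {
  // e.g. an align-16 alloca sliced at byte offset 8 yields align 8.
  return commonAlignment(AllocaAlign, BeginOffset);
}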
uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte);          Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); @@ -4354,7 +4470,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {      auto *Expr = DbgDeclares.front()->getExpression();      auto VarSize = Var->getSizeInBits();      DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); -    uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); +    uint64_t AllocaSize = +        DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedSize();      for (auto Fragment : Fragments) {        // Create a fragment expression describing the new partition or reuse AI's        // expression if there is only one partition. @@ -4442,8 +4559,9 @@ bool SROA::runOnAlloca(AllocaInst &AI) {    const DataLayout &DL = AI.getModule()->getDataLayout();    // Skip alloca forms that this analysis can't handle. -  if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() || -      DL.getTypeAllocSize(AI.getAllocatedType()) == 0) +  auto *AT = AI.getAllocatedType(); +  if (AI.isArrayAllocation() || !AT->isSized() || isa<ScalableVectorType>(AT) || +      DL.getTypeAllocSize(AT).getFixedSize() == 0)      return false;    bool Changed = false; @@ -4563,8 +4681,14 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT,    BasicBlock &EntryBB = F.getEntryBlock();    for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());         I != E; ++I) { -    if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) -      Worklist.insert(AI); +    if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { +      if (isa<ScalableVectorType>(AI->getAllocatedType())) { +        if (isAllocaPromotable(AI)) +          PromotableAllocas.push_back(AI); +      } else { +        Worklist.insert(AI); +      } +    }    }    bool Changed = false; diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index c25c6c632b8f..851bd79cd6d8 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -22,8 +22,8 @@  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/DataLayout.h" -#include "llvm/IR/Dominators.h"  #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h"  #include "llvm/IR/Function.h"  #include "llvm/IR/IRBuilder.h"  #include "llvm/IR/InstVisitor.h" @@ -41,6 +41,7 @@  #include "llvm/Support/CommandLine.h"  #include "llvm/Support/MathExtras.h"  #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h"  #include <cassert>  #include <cstdint>  #include <iterator> @@ -51,6 +52,11 @@ using namespace llvm;  #define DEBUG_TYPE "scalarizer" +static cl::opt<bool> ScalarizeVariableInsertExtract( +    "scalarize-variable-insert-extract", cl::init(true), cl::Hidden, +    cl::desc("Allow the scalarizer pass to scalarize " +             "insertelement/extractelement with variable index")); +  // This is disabled by default because having separate loads and stores  // makes it more likely that the -combiner-alias-analysis limits will be  // reached. @@ -156,8 +162,8 @@ struct VectorLayout {    VectorLayout() = default;    // Return the alignment of element I. -  uint64_t getElemAlign(unsigned I) { -    return MinAlign(VecAlign, I * ElemSize); +  Align getElemAlign(unsigned I) { +    return commonAlignment(VecAlign, I * ElemSize);    }    // The type of the vector. @@ -167,7 +173,7 @@ struct VectorLayout {    Type *ElemTy = nullptr;    // The alignment of the vector. 
-  uint64_t VecAlign = 0; +  Align VecAlign;    // The size of each element.    uint64_t ElemSize = 0; @@ -192,6 +198,8 @@ public:    bool visitGetElementPtrInst(GetElementPtrInst &GEPI);    bool visitCastInst(CastInst &CI);    bool visitBitCastInst(BitCastInst &BCI); +  bool visitInsertElementInst(InsertElementInst &IEI); +  bool visitExtractElementInst(ExtractElementInst &EEI);    bool visitShuffleVectorInst(ShuffleVectorInst &SVI);    bool visitPHINode(PHINode &PHI);    bool visitLoadInst(LoadInst &LI); @@ -203,8 +211,8 @@ private:    void gather(Instruction *Op, const ValueVector &CV);    bool canTransferMetadata(unsigned Kind);    void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV); -  bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout, -                       const DataLayout &DL); +  Optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment, +                                         const DataLayout &DL);    bool finish();    template<typename T> bool splitUnary(Instruction &, const T &); @@ -215,6 +223,8 @@ private:    ScatterMap Scattered;    GatherList Gathered; +  SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs; +    unsigned ParallelLoopAccessMDKind;    DominatorTree *DT; @@ -252,7 +262,7 @@ Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,    PtrTy = dyn_cast<PointerType>(Ty);    if (PtrTy)      Ty = PtrTy->getElementType(); -  Size = Ty->getVectorNumElements(); +  Size = cast<FixedVectorType>(Ty)->getNumElements();    if (!CachePtr)      Tmp.resize(Size, nullptr);    else if (CachePtr->empty()) @@ -269,7 +279,7 @@ Value *Scatterer::operator[](unsigned I) {      return CV[I];    IRBuilder<> Builder(BB, BBI);    if (PtrTy) { -    Type *ElTy = PtrTy->getElementType()->getVectorElementType(); +    Type *ElTy = cast<VectorType>(PtrTy->getElementType())->getElementType();      if (!CV[0]) {        Type *NewPtrTy = PointerType::get(ElTy, PtrTy->getAddressSpace());        CV[0] = Builder.CreateBitCast(V, NewPtrTy, V->getName() + ".i0"); @@ -376,11 +386,6 @@ Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) {  // so that we can avoid creating the gathered form if all uses of Op are  // replaced with uses of CV.  void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { -  // Since we're not deleting Op yet, stub out its operands, so that it -  // doesn't make anything live unnecessarily. -  for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) -    Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType())); -    transferMetadataAndIRFlags(Op, CV);    // If we already have a scattered form of Op (created from ExtractElements @@ -389,13 +394,13 @@ void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) {    if (!SV.empty()) {      for (unsigned I = 0, E = SV.size(); I != E; ++I) {        Value *V = SV[I]; -      if (V == nullptr) +      if (V == nullptr || SV[I] == CV[I])          continue;        Instruction *Old = cast<Instruction>(V);        CV[I]->takeName(Old);        Old->replaceAllUsesWith(CV[I]); -      Old->eraseFromParent(); +      PotentiallyDeadInstrs.emplace_back(Old);      }    }    SV = CV; @@ -434,25 +439,22 @@ void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,  }  // Try to fill in Layout from Ty, returning true on success.  Alignment is -// the alignment of the vector, or 0 if the ABI default should be used. 
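As declared above and defined just below, getVectorLayout now returns Optional<VectorLayout> instead of filling an out-parameter and returning bool, and VecAlign is a real Align rather than a raw integer. A worked check of getElemAlign (illustrative; assumes the VectorLayout struct defined in this file):

static void elemAlignExample() {
  VectorLayout L;
  L.VecAlign = Align(16); // e.g. a <4 x i32> access aligned to 16
  L.ElemSize = 4;
  assert(L.getElemAlign(0) == Align(16)); // commonAlignment(16, 0)
  assert(L.getElemAlign(1) == Align(4));  // commonAlignment(16, 4)
  assert(L.getElemAlign(2) == Align(8));  // commonAlignment(16, 8)
  assert(L.getElemAlign(3) == Align(4));  // commonAlignment(16, 12)
}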
-bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment, -                                 VectorLayout &Layout, const DataLayout &DL) { +// the alignment of the vector, or None if the ABI default should be used. +Optional<VectorLayout> +ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment, +                                   const DataLayout &DL) { +  VectorLayout Layout;    // Make sure we're dealing with a vector.    Layout.VecTy = dyn_cast<VectorType>(Ty);    if (!Layout.VecTy) -    return false; - +    return None;    // Check that we're dealing with full-byte elements.    Layout.ElemTy = Layout.VecTy->getElementType();    if (!DL.typeSizeEqualsStoreSize(Layout.ElemTy)) -    return false; - -  if (Alignment) -    Layout.VecAlign = Alignment; -  else -    Layout.VecAlign = DL.getABITypeAlignment(Layout.VecTy); +    return None; +  Layout.VecAlign = Alignment;    Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy); -  return true; +  return Layout;  }  // Scalarize one-operand instruction I, using Split(Builder, X, Name) @@ -463,7 +465,7 @@ bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {    if (!VT)      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    IRBuilder<> Builder(&I);    Scatterer Op = scatter(&I, I.getOperand(0));    assert(Op.size() == NumElems && "Mismatched unary operation"); @@ -483,17 +485,19 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {    if (!VT)      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    IRBuilder<> Builder(&I); -  Scatterer Op0 = scatter(&I, I.getOperand(0)); -  Scatterer Op1 = scatter(&I, I.getOperand(1)); -  assert(Op0.size() == NumElems && "Mismatched binary operation"); -  assert(Op1.size() == NumElems && "Mismatched binary operation"); +  Scatterer VOp0 = scatter(&I, I.getOperand(0)); +  Scatterer VOp1 = scatter(&I, I.getOperand(1)); +  assert(VOp0.size() == NumElems && "Mismatched binary operation"); +  assert(VOp1.size() == NumElems && "Mismatched binary operation");    ValueVector Res;    Res.resize(NumElems); -  for (unsigned Elem = 0; Elem < NumElems; ++Elem) -    Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem], -                      I.getName() + ".i" + Twine(Elem)); +  for (unsigned Elem = 0; Elem < NumElems; ++Elem) { +    Value *Op0 = VOp0[Elem]; +    Value *Op1 = VOp1[Elem]; +    Res[Elem] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Elem)); +  }    gather(&I, Res);    return true;  } @@ -524,7 +528,7 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) {    if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    unsigned NumArgs = CI.getNumArgOperands();    ValueVector ScalarOperands(NumArgs); @@ -574,26 +578,33 @@ bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {    if (!VT)      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    IRBuilder<> Builder(&SI); -  Scatterer Op1 = scatter(&SI, SI.getOperand(1)); -  Scatterer Op2 = scatter(&SI, SI.getOperand(2)); -  assert(Op1.size() == NumElems && "Mismatched select"); -  assert(Op2.size() == NumElems && "Mismatched select"); +  Scatterer VOp1 = scatter(&SI, SI.getOperand(1)); +  Scatterer VOp2 = scatter(&SI, 
SI.getOperand(2)); +  assert(VOp1.size() == NumElems && "Mismatched select"); +  assert(VOp2.size() == NumElems && "Mismatched select");    ValueVector Res;    Res.resize(NumElems);    if (SI.getOperand(0)->getType()->isVectorTy()) { -    Scatterer Op0 = scatter(&SI, SI.getOperand(0)); -    assert(Op0.size() == NumElems && "Mismatched select"); -    for (unsigned I = 0; I < NumElems; ++I) -      Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I], +    Scatterer VOp0 = scatter(&SI, SI.getOperand(0)); +    assert(VOp0.size() == NumElems && "Mismatched select"); +    for (unsigned I = 0; I < NumElems; ++I) { +      Value *Op0 = VOp0[I]; +      Value *Op1 = VOp1[I]; +      Value *Op2 = VOp2[I]; +      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,                                      SI.getName() + ".i" + Twine(I)); +    }    } else {      Value *Op0 = SI.getOperand(0); -    for (unsigned I = 0; I < NumElems; ++I) -      Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I], +    for (unsigned I = 0; I < NumElems; ++I) { +      Value *Op1 = VOp1[I]; +      Value *Op2 = VOp2[I]; +      Res[I] = Builder.CreateSelect(Op0, Op1, Op2,                                      SI.getName() + ".i" + Twine(I)); +    }    }    gather(&SI, Res);    return true; @@ -621,7 +632,7 @@ bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {      return false;    IRBuilder<> Builder(&GEPI); -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    unsigned NumIndices = GEPI.getNumIndices();    // The base pointer might be scalar even if it's a vector GEP. In those cases, @@ -666,7 +677,7 @@ bool ScalarizerVisitor::visitCastInst(CastInst &CI) {    if (!VT)      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    IRBuilder<> Builder(&CI);    Scatterer Op0 = scatter(&CI, CI.getOperand(0));    assert(Op0.size() == NumElems && "Mismatched cast"); @@ -685,8 +696,8 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {    if (!DstVT || !SrcVT)      return false; -  unsigned DstNumElems = DstVT->getNumElements(); -  unsigned SrcNumElems = SrcVT->getNumElements(); +  unsigned DstNumElems = cast<FixedVectorType>(DstVT)->getNumElements(); +  unsigned SrcNumElems = cast<FixedVectorType>(SrcVT)->getNumElements();    IRBuilder<> Builder(&BCI);    Scatterer Op0 = scatter(&BCI, BCI.getOperand(0));    ValueVector Res; @@ -700,7 +711,7 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {      // <M x t1> -> <N*M x t2>.  Convert each t1 to <N x t2> and copy the      // individual elements to the destination.      unsigned FanOut = DstNumElems / SrcNumElems; -    Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut); +    auto *MidTy = FixedVectorType::get(DstVT->getElementType(), FanOut);      unsigned ResI = 0;      for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) {        Value *V = Op0[Op0I]; @@ -718,7 +729,7 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {    } else {      // <N*M x t1> -> <M x t2>.  Convert each group of <N x t1> into a t2.      
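For the two bitcast cases, a worked shape helps: bitcasting <2 x i64> to <4 x i32> (fan-out, FanOut = 2) splits each scalar lane through a <2 x i32> MidTy and copies its lanes into the result, while the reverse direction (fan-in, FanIn = 2, handled just below) packs each group of i32 lanes back through the same MidTy. A condensed sketch of one fan-in group (illustrative helper over already-scattered lanes):

static Value *fanInGroup(IRBuilder<> &B, ArrayRef<Value *> SrcLanes,
                         Type *DstElemTy) {
  // e.g. two i32 lanes -> <2 x i32> -> one i64 result element.
  auto *MidTy =
      FixedVectorType::get(SrcLanes[0]->getType(), SrcLanes.size());
  Value *V = UndefValue::get(MidTy);
  for (unsigned J = 0, E = SrcLanes.size(); J != E; ++J)
    V = B.CreateInsertElement(V, SrcLanes[J], B.getInt32(J));
  return B.CreateBitCast(V, DstElemTy);
}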
unsigned FanIn = SrcNumElems / DstNumElems; -    Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn); +    auto *MidTy = FixedVectorType::get(SrcVT->getElementType(), FanIn);      unsigned Op0I = 0;      for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) {        Value *V = UndefValue::get(MidTy); @@ -734,12 +745,79 @@ bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {    return true;  } +bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) { +  VectorType *VT = dyn_cast<VectorType>(IEI.getType()); +  if (!VT) +    return false; + +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements(); +  IRBuilder<> Builder(&IEI); +  Scatterer Op0 = scatter(&IEI, IEI.getOperand(0)); +  Value *NewElt = IEI.getOperand(1); +  Value *InsIdx = IEI.getOperand(2); + +  ValueVector Res; +  Res.resize(NumElems); + +  if (auto *CI = dyn_cast<ConstantInt>(InsIdx)) { +    for (unsigned I = 0; I < NumElems; ++I) +      Res[I] = CI->getValue().getZExtValue() == I ? NewElt : Op0[I]; +  } else { +    if (!ScalarizeVariableInsertExtract) +      return false; + +    for (unsigned I = 0; I < NumElems; ++I) { +      Value *ShouldReplace = +          Builder.CreateICmpEQ(InsIdx, ConstantInt::get(InsIdx->getType(), I), +                               InsIdx->getName() + ".is." + Twine(I)); +      Value *OldElt = Op0[I]; +      Res[I] = Builder.CreateSelect(ShouldReplace, NewElt, OldElt, +                                    IEI.getName() + ".i" + Twine(I)); +    } +  } + +  gather(&IEI, Res); +  return true; +} + +bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) { +  VectorType *VT = dyn_cast<VectorType>(EEI.getOperand(0)->getType()); +  if (!VT) +    return false; + +  unsigned NumSrcElems = cast<FixedVectorType>(VT)->getNumElements(); +  IRBuilder<> Builder(&EEI); +  Scatterer Op0 = scatter(&EEI, EEI.getOperand(0)); +  Value *ExtIdx = EEI.getOperand(1); + +  if (auto *CI = dyn_cast<ConstantInt>(ExtIdx)) { +    Value *Res = Op0[CI->getValue().getZExtValue()]; +    gather(&EEI, {Res}); +    return true; +  } + +  if (!ScalarizeVariableInsertExtract) +    return false; + +  Value *Res = UndefValue::get(VT->getElementType()); +  for (unsigned I = 0; I < NumSrcElems; ++I) { +    Value *ShouldExtract = +        Builder.CreateICmpEQ(ExtIdx, ConstantInt::get(ExtIdx->getType(), I), +                             ExtIdx->getName() + ".is." 
+ Twine(I)); +    Value *Elt = Op0[I]; +    Res = Builder.CreateSelect(ShouldExtract, Elt, Res, +                               EEI.getName() + ".upto" + Twine(I)); +  } +  gather(&EEI, {Res}); +  return true; +} +  bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {    VectorType *VT = dyn_cast<VectorType>(SVI.getType());    if (!VT)      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    Scatterer Op0 = scatter(&SVI, SVI.getOperand(0));    Scatterer Op1 = scatter(&SVI, SVI.getOperand(1));    ValueVector Res; @@ -763,7 +841,7 @@ bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {    if (!VT)      return false; -  unsigned NumElems = VT->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(VT)->getNumElements();    IRBuilder<> Builder(&PHI);    ValueVector Res;    Res.resize(NumElems); @@ -789,20 +867,20 @@ bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {    if (!LI.isSimple())      return false; -  VectorLayout Layout; -  if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout, -                       LI.getModule()->getDataLayout())) +  Optional<VectorLayout> Layout = getVectorLayout( +      LI.getType(), LI.getAlign(), LI.getModule()->getDataLayout()); +  if (!Layout)      return false; -  unsigned NumElems = Layout.VecTy->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();    IRBuilder<> Builder(&LI);    Scatterer Ptr = scatter(&LI, LI.getPointerOperand());    ValueVector Res;    Res.resize(NumElems);    for (unsigned I = 0; I < NumElems; ++I) -    Res[I] = Builder.CreateAlignedLoad(Layout.VecTy->getElementType(), Ptr[I], -                                       Layout.getElemAlign(I), +    Res[I] = Builder.CreateAlignedLoad(Layout->VecTy->getElementType(), Ptr[I], +                                       Align(Layout->getElemAlign(I)),                                         LI.getName() + ".i" + Twine(I));    gather(&LI, Res);    return true; @@ -814,22 +892,23 @@ bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {    if (!SI.isSimple())      return false; -  VectorLayout Layout;    Value *FullValue = SI.getValueOperand(); -  if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout, -                       SI.getModule()->getDataLayout())) +  Optional<VectorLayout> Layout = getVectorLayout( +      FullValue->getType(), SI.getAlign(), SI.getModule()->getDataLayout()); +  if (!Layout)      return false; -  unsigned NumElems = Layout.VecTy->getNumElements(); +  unsigned NumElems = cast<FixedVectorType>(Layout->VecTy)->getNumElements();    IRBuilder<> Builder(&SI); -  Scatterer Ptr = scatter(&SI, SI.getPointerOperand()); -  Scatterer Val = scatter(&SI, FullValue); +  Scatterer VPtr = scatter(&SI, SI.getPointerOperand()); +  Scatterer VVal = scatter(&SI, FullValue);    ValueVector Stores;    Stores.resize(NumElems);    for (unsigned I = 0; I < NumElems; ++I) { -    unsigned Align = Layout.getElemAlign(I); -    Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align); +    Value *Val = VVal[I]; +    Value *Ptr = VPtr[I]; +    Stores[I] = Builder.CreateAlignedStore(Val, Ptr, Layout->getElemAlign(I));    }    transferMetadataAndIRFlags(&SI, Stores);    return true; @@ -852,23 +931,32 @@ bool ScalarizerVisitor::finish() {      if (!Op->use_empty()) {        // The value is still needed, so recreate it using a series of        // InsertElements. 
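The new insert/extract visitors above are the point of the ScalarizeVariableInsertExtract option: a constant index picks a lane directly, while a variable index lowers to a compare-and-select cascade. Condensed shape of the extract case (illustrative helper; Elts are the scattered lanes of the source vector):

static Value *extractAtVariableIndex(IRBuilder<> &B,
                                     ArrayRef<Value *> Elts, Value *Idx) {
  Value *Res = UndefValue::get(Elts[0]->getType());
  for (unsigned I = 0, E = Elts.size(); I != E; ++I) {
    Value *IsI = B.CreateICmpEQ(Idx, ConstantInt::get(Idx->getType(), I));
    // At most one compare is true, so the last matching select wins.
    Res = B.CreateSelect(IsI, Elts[I], Res);
  }
  return Res;
}

The insert case is symmetric: every lane becomes select(Idx == I, NewElt, OldElt).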
-      Type *Ty = Op->getType(); -      Value *Res = UndefValue::get(Ty); -      BasicBlock *BB = Op->getParent(); -      unsigned Count = Ty->getVectorNumElements(); -      IRBuilder<> Builder(Op); -      if (isa<PHINode>(Op)) -        Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); -      for (unsigned I = 0; I < Count; ++I) -        Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), -                                          Op->getName() + ".upto" + Twine(I)); +      Value *Res = UndefValue::get(Op->getType()); +      if (auto *Ty = dyn_cast<VectorType>(Op->getType())) { +        BasicBlock *BB = Op->getParent(); +        unsigned Count = cast<FixedVectorType>(Ty)->getNumElements(); +        IRBuilder<> Builder(Op); +        if (isa<PHINode>(Op)) +          Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); +        for (unsigned I = 0; I < Count; ++I) +          Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), +                                            Op->getName() + ".upto" + Twine(I)); +      } else { +        assert(CV.size() == 1 && Op->getType() == CV[0]->getType()); +        Res = CV[0]; +        if (Op == Res) +          continue; +      }        Res->takeName(Op);        Op->replaceAllUsesWith(Res);      } -    Op->eraseFromParent(); +    PotentiallyDeadInstrs.emplace_back(Op);    }    Gathered.clear();    Scattered.clear(); + +  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); +    return true;  } diff --git a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 2a1a040bf83e..f1d2e3c1ecfa 100644 --- a/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -431,8 +431,10 @@ private:    bool reuniteExts(Instruction *I);    /// Find the closest dominator of <Dominatee> that is equivalent to <Key>. -  Instruction *findClosestMatchingDominator(const SCEV *Key, -                                            Instruction *Dominatee); +  Instruction *findClosestMatchingDominator( +      const SCEV *Key, Instruction *Dominatee, +      DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs); +    /// Verify F is free of dead code.    void verifyNoDeadCode(Function &F); @@ -456,7 +458,8 @@ private:    /// multiple GEPs with a single index.    bool LowerGEP; -  DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingExprs; +  DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingAdds; +  DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingSubs;  };  } // end anonymous namespace @@ -519,7 +522,7 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,      //   sext(a + b) = sext(a) + sext(b)      // even if the addition is not marked nsw.      // -    // Leveraging this invarient, we can trace into an sext'ed inbound GEP +    // Leveraging this invariant, we can trace into an sext'ed inbound GEP      // index if the constant offset is non-negative.      //      // Verified in @sext_add in split-gep.ll. @@ -549,6 +552,9 @@ bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,  APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO,                                                     bool SignExtended,                                                     bool ZeroExtended) { +  // Save off the current height of the chain, in case we need to restore it. 
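The Scalarizer's finish() and gather() above also switch from eager eraseFromParent() to a deferred worklist: replaced instructions go into PotentiallyDeadInstrs as WeakTrackingVH handles (so a later RAUW or deletion cannot leave a dangling entry) and are reaped once at the end. The shape of the pattern (the deletion helper is the real utility from llvm/Transforms/Utils/Local.h; the driver is a sketch):

SmallVector<WeakTrackingVH, 32> PotentiallyDead;

void retire(Instruction *Old, Value *Replacement) {
  Old->replaceAllUsesWith(Replacement);
  PotentiallyDead.emplace_back(Old); // defer instead of eraseFromParent()
}

void reap() {
  RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDead);
}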
+  size_t ChainLength = UserChain.size(); +    // BO being non-negative does not shed light on whether its operands are    // non-negative. Clear the NonNegative flag here.    APInt ConstantOffset = find(BO->getOperand(0), SignExtended, ZeroExtended, @@ -559,12 +565,22 @@ APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO,    // However, such cases are probably already handled by -instcombine,    // given this pass runs after the standard optimizations.    if (ConstantOffset != 0) return ConstantOffset; + +  // Reset the chain back to where it was when we started exploring this node, +  // since visiting the LHS didn't pan out. +  UserChain.resize(ChainLength); +    ConstantOffset = find(BO->getOperand(1), SignExtended, ZeroExtended,                          /* NonNegative */ false);    // If U is a sub operator, negate the constant offset found in the right    // operand.    if (BO->getOpcode() == Instruction::Sub)      ConstantOffset = -ConstantOffset; + +  // If RHS wasn't a suitable candidate either, reset the chain again. +  if (ConstantOffset == 0) +    UserChain.resize(ChainLength); +    return ConstantOffset;  } @@ -688,7 +704,7 @@ Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {    }    BinaryOperator *BO = cast<BinaryOperator>(UserChain[ChainIndex]); -  assert(BO->getNumUses() <= 1 && +  assert((BO->use_empty() || BO->hasOneUse()) &&           "distributeExtsAndCloneChain clones each BinaryOperator in "           "UserChain, so no one should be used more than "           "once"); @@ -1141,7 +1157,8 @@ bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) {  }  Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator( -    const SCEV *Key, Instruction *Dominatee) { +    const SCEV *Key, Instruction *Dominatee, +    DenseMap<const SCEV *, SmallVector<Instruction *, 2>> &DominatingExprs) {    auto Pos = DominatingExprs.find(Key);    if (Pos == DominatingExprs.end())      return nullptr; @@ -1169,12 +1186,23 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {    // If Dom can't sign overflow and Dom dominates I, optimize I to sext(Dom).    
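A concrete instance of the rewrite, sketched in comments:

    // Given a dominating no-wrap add
    //   %d = add nsw i32 %a, %b
    // recorded under the key getAddExpr(unknown(%a), unknown(%b)), a later
    //   %s = add i64 (sext i32 %a to i64), (sext i32 %b to i64)
    // can be rewritten to
    //   %s = sext i32 %d to i64
    // which is sound precisely because %d cannot sign-wrap.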
// TODO: handle zext    Value *LHS = nullptr, *RHS = nullptr; -  if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS)))) || -      match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { +  if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) {      if (LHS->getType() == RHS->getType()) {        const SCEV *Key =            SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); -      if (auto *Dom = findClosestMatchingDominator(Key, I)) { +      if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingAdds)) { +        Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); +        NewSExt->takeName(I); +        I->replaceAllUsesWith(NewSExt); +        RecursivelyDeleteTriviallyDeadInstructions(I); +        return true; +      } +    } +  } else if (match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { +    if (LHS->getType() == RHS->getType()) { +      const SCEV *Key = +          SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); +      if (auto *Dom = findClosestMatchingDominator(Key, I, DominatingSubs)) {          Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I);          NewSExt->takeName(I);          I->replaceAllUsesWith(NewSExt); @@ -1185,12 +1213,17 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {    }    // Add I to DominatingExprs if it's an add/sub that can't sign overflow. -  if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS))) || -      match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) { -    if (programUndefinedIfFullPoison(I)) { +  if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS)))) { +    if (programUndefinedIfPoison(I)) { +      const SCEV *Key = +          SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); +      DominatingAdds[Key].push_back(I); +    } +  } else if (match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) { +    if (programUndefinedIfPoison(I)) {        const SCEV *Key =            SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); -      DominatingExprs[Key].push_back(I); +      DominatingSubs[Key].push_back(I);      }    }    return false; @@ -1198,7 +1231,8 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) {  bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) {    bool Changed = false; -  DominatingExprs.clear(); +  DominatingAdds.clear(); +  DominatingSubs.clear();    for (const auto Node : depth_first(DT)) {      BasicBlock *BB = Node->getBlock();      for (auto I = BB->begin(); I != BB->end(); ) { diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index d7a34acb4318..6c6d6ca9cf65 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -26,7 +26,6 @@  #include "llvm/Analysis/LoopPass.h"  #include "llvm/Analysis/MemorySSA.h"  #include "llvm/Analysis/MemorySSAUpdater.h" -#include "llvm/Analysis/Utils/Local.h"  #include "llvm/IR/BasicBlock.h"  #include "llvm/IR/Constant.h"  #include "llvm/IR/Constants.h" @@ -36,6 +35,7 @@  #include "llvm/IR/Instruction.h"  #include "llvm/IR/Instructions.h"  #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h"  #include "llvm/IR/Use.h"  #include "llvm/IR/Value.h"  #include "llvm/InitializePasses.h" @@ -182,7 +182,7 @@ static void buildPartialUnswitchConditionalBranch(BasicBlock &BB,                                                    BasicBlock &UnswitchedSucc,                                                    BasicBlock &NormalSucc) {    IRBuilder<> IRB(&BB); -   +  
  Value *Cond = Direction ? IRB.CreateOr(Invariants) :      IRB.CreateAnd(Invariants);    IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, @@ -598,19 +598,36 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,    auto *ParentBB = SI.getParent(); +  // The same check must be used both for the default and the exit cases. We +  // should never leave edges from the switch instruction to a basic block that +  // we are unswitching, hence the condition used to determine the default case +  // needs to also be used to populate ExitCaseIndices, which is then used to +  // remove cases from the switch. +  auto IsTriviallyUnswitchableExitBlock = [&](BasicBlock &BBToCheck) { +    // BBToCheck is not an exit block if it is inside loop L. +    if (L.contains(&BBToCheck)) +      return false; +    // BBToCheck is not trivial to unswitch if its phis aren't loop invariant. +    if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, BBToCheck)) +      return false; +    // We do not unswitch a block that only has an unreachable statement, as +    // it's possible this is a previously unswitched block. Only unswitch if +    // either the terminator is not unreachable, or, if it is, it's not the only +    // instruction in the block. +    auto *TI = BBToCheck.getTerminator(); +    bool isUnreachable = isa<UnreachableInst>(TI); +    return !isUnreachable || +           (isUnreachable && (BBToCheck.getFirstNonPHIOrDbg() != TI)); +  }; +    SmallVector<int, 4> ExitCaseIndices; -  for (auto Case : SI.cases()) { -    auto *SuccBB = Case.getCaseSuccessor(); -    if (!L.contains(SuccBB) && -        areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB)) +  for (auto Case : SI.cases()) +    if (IsTriviallyUnswitchableExitBlock(*Case.getCaseSuccessor()))        ExitCaseIndices.push_back(Case.getCaseIndex()); -  }    BasicBlock *DefaultExitBB = nullptr;    SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight =        SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0); -  if (!L.contains(SI.getDefaultDest()) && -      areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && -      !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) { +  if (IsTriviallyUnswitchableExitBlock(*SI.getDefaultDest())) {      DefaultExitBB = SI.getDefaultDest();    } else if (ExitCaseIndices.empty())      return false; @@ -1557,6 +1574,11 @@ static void deleteDeadBlocksFromLoop(Loop &L,      // Check that the dominator tree has already been updated.      assert(!DT.getNode(BB) && "Should already have cleared domtree!");      LI.changeLoopFor(BB, nullptr); +    // Drop all uses of the instructions to make sure we won't have dangling +    // uses in other blocks. +    for (auto &I : *BB) +      if (!I.use_empty()) +        I.replaceAllUsesWith(UndefValue::get(I.getType()));      BB->dropAllReferences();    } @@ -2465,7 +2487,7 @@ turnGuardIntoBranch(IntrinsicInst *GI, Loop &L,  /// unswitch candidates, making adequate predictions instead of wild guesses.  /// That requires knowing not just the number of "remaining" candidates but  /// also costs of unswitching for each of these candidates. 
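For intuition only, a deliberately rough sketch of the damping; the real computation, with its caps on candidate and sibling-loop counts, is in the function below:

    // Each unswitching can double the amount of cloned loop code, so C
    // remaining candidates imply growth on the order of 2^C:
    //   Multiplier ~= min(UnswitchThreshold, 2^C)
    // and a candidate is taken only if Cost * Multiplier stays within the
    // same fixed UnswitchThreshold budget.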
-static int calculateUnswitchCostMultiplier( +static int CalculateUnswitchCostMultiplier(      Instruction &TI, Loop &L, LoopInfo &LI, DominatorTree &DT,      ArrayRef<std::pair<Instruction *, TinyPtrVector<Value *>>>          UnswitchCandidates) { @@ -2656,11 +2678,11 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,        if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))          return false; -      if (auto CS = CallSite(&I)) -        if (CS.isConvergent() || CS.cannotDuplicate()) +      if (auto *CB = dyn_cast<CallBase>(&I)) +        if (CB->isConvergent() || CB->cannotDuplicate())            return false; -      Cost += TTI.getUserCost(&I); +      Cost += TTI.getUserCost(&I, TargetTransformInfo::TCK_CodeSize);      }      assert(Cost >= 0 && "Must not have negative costs!");      LoopCost += Cost; @@ -2754,7 +2776,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,      // exponential behavior of loop-unswitch.      if (EnableUnswitchCostMultiplier) {        int CostMultiplier = -          calculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates); +          CalculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates);        assert(            (CostMultiplier > 0 && CostMultiplier <= UnswitchThreshold) &&            "cost multiplier needs to be in the range of 1..UnswitchThreshold"); @@ -2868,7 +2890,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,    // Save the current loop name in a variable so that we can report it even    // after it has been deleted. -  std::string LoopName = L.getName(); +  std::string LoopName = std::string(L.getName());    auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,                                          ArrayRef<Loop *> NewLoops) { @@ -2983,10 +3005,6 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {    if (MSSA && VerifyMemorySSA)      MSSA->verifyMemorySSA(); -  // If anything was unswitched, also clear any cached information about this -  // loop. -  LPM.deleteSimpleAnalysisLoop(L); -    // Historically this pass has had issues with the dominator tree so verify it    // in asserts builds.    assert(DT.verify(DominatorTree::VerificationLevel::Fast)); diff --git a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 623a8b711ed8..2e459c9a64d4 100644 --- a/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -104,6 +104,21 @@ static bool mergeEmptyReturnBlocks(Function &F) {        continue;      } +    // Skip merging if this would result in a CallBr instruction with a +    // duplicate destination. FIXME: See note in CodeGenPrepare.cpp. +    bool SkipCallBr = false; +    for (pred_iterator PI = pred_begin(&BB), E = pred_end(&BB); +         PI != E && !SkipCallBr; ++PI) { +      if (auto *CBI = dyn_cast<CallBrInst>((*PI)->getTerminator())) +        for (unsigned i = 0, e = CBI->getNumSuccessors(); i != e; ++i) +          if (RetBlock == CBI->getSuccessor(i)) { +            SkipCallBr = true; +            break; +          } +    } +    if (SkipCallBr) +      continue; +      // Otherwise, we found a duplicate return block.  Merge the two.      
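(The callbr check above matters because a callbr, used for asm-goto, should not end up listing the same basic block twice among its destinations. Merging two return blocks that both appear as successors of one callbr would create exactly that; a sketch:

    // before: callbr void asm ... to label %ret.a [label %ret.b]
    // after a naive merge: callbr void asm ... to label %ret [label %ret]

hence the skip.)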
Changed = true; @@ -266,6 +281,14 @@ struct CFGSimplifyPass : public FunctionPass {        return false;      Options.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +    if (F.hasFnAttribute(Attribute::OptForFuzzing)) { +      Options.setSimplifyCondBranch(false) +             .setFoldTwoEntryPHINode(false); +    } else { +      Options.setSimplifyCondBranch(true) +             .setFoldTwoEntryPHINode(true); +    } +      auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);      return simplifyFunctionCFG(F, TTI, Options);    } diff --git a/llvm/lib/Transforms/Scalar/Sink.cpp b/llvm/lib/Transforms/Scalar/Sink.cpp index 677d86f8c7b4..48f289c8f17d 100644 --- a/llvm/lib/Transforms/Scalar/Sink.cpp +++ b/llvm/lib/Transforms/Scalar/Sink.cpp @@ -166,8 +166,8 @@ static bool SinkInstruction(Instruction *Inst,    // dominated by one of the successors.    // Look at all the dominated blocks and see if we can sink it in one.    DomTreeNode *DTN = DT.getNode(Inst->getParent()); -  for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end(); -      I != E && SuccToSinkTo == nullptr; ++I) { +  for (auto I = DTN->begin(), E = DTN->end(); I != E && SuccToSinkTo == nullptr; +       ++I) {      BasicBlock *Candidate = (*I)->getBlock();      // A node always immediate-dominates its children on the dominator      // tree. diff --git a/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp index cd7bfb2f20dc..8258b92a716d 100644 --- a/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp +++ b/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -67,8 +67,8 @@ isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT,        return false;      } -    if (auto CS = ImmutableCallSite(UI)) { -      if (CS.isConvergent() || CS.cannotDuplicate()) { +    if (const auto *CS = dyn_cast<CallBase>(UI)) { +      if (CS->isConvergent() || CS->cannotDuplicate()) {          LLVM_DEBUG(dbgs() << "  Unsafe: convergent "                     "callsite cannot de duplicated: " << *UI << '\n');          return false; @@ -232,7 +232,8 @@ static bool isSafeAndProfitableToSpeculateAroundPHI(        continue;      int &MatCost = InsertResult.first->second.MatCost; -    MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType()); +    MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType(), +                                TargetTransformInfo::TCK_SizeAndLatency);      NonFreeMat |= MatCost != TTI.TCC_Free;    }    if (!NonFreeMat) { @@ -283,12 +284,15 @@ static bool isSafeAndProfitableToSpeculateAroundPHI(        int MatCost = IncomingConstantAndCostsAndCount.second.MatCost;        int &FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost;        if (IID) -        FoldedCost += TTI.getIntImmCostIntrin(IID, Idx, IncomingC->getValue(), -                                              IncomingC->getType()); +        FoldedCost += +          TTI.getIntImmCostIntrin(IID, Idx, IncomingC->getValue(), +                                  IncomingC->getType(), +                                  TargetTransformInfo::TCK_SizeAndLatency);        else          FoldedCost +=              TTI.getIntImmCostInst(UserI->getOpcode(), Idx, -                                  IncomingC->getValue(), IncomingC->getType()); +                                  IncomingC->getValue(), IncomingC->getType(), +                                  TargetTransformInfo::TCK_SizeAndLatency);        // If we accumulate more folded cost for this incoming 
constant than
       // materialized cost, then we'll regress any edge with this constant so
@@ -465,7 +469,7 @@ findProfitablePHIs(ArrayRef<PHINode *> PNs,
             if (CostMapIt != SpecCostMap.end())
               Cost += CostMapIt->second;
           }
-        Cost += TTI.getUserCost(I);
+        Cost += TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency);
         bool Inserted = SpecCostMap.insert({I, Cost}).second;
         (void)Inserted;
         assert(Inserted && "Must not re-insert a cost during the DFS!");
diff --git a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
index c8d899bb4871..f82a2936c762 100644
--- a/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
+++ b/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp
@@ -65,6 +65,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/InitializePasses.h"
@@ -244,19 +245,35 @@ static unsigned ComputeSpeculationCost(const Instruction *I,
     case Instruction::FNeg:
     case Instruction::ICmp:
     case Instruction::FCmp:
-      return TTI.getUserCost(I);
+      return TTI.getUserCost(I, TargetTransformInfo::TCK_SizeAndLatency);
     default:
-      return UINT_MAX; // Disallow anything not whitelisted.
+      return UINT_MAX; // Disallow anything not explicitly listed.
   }
 }
 
 bool SpeculativeExecutionPass::considerHoistingFromTo(
     BasicBlock &FromBlock, BasicBlock &ToBlock) {
   SmallPtrSet<const Instruction *, 8> NotHoisted;
-  const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) {
-    for (Value* V : U->operand_values()) {
-      if (Instruction *I = dyn_cast<Instruction>(V)) {
+  const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](const User *U) {
+    // A debug variable intrinsic holds its operand behind metadata, so check
+    // the underlying location value separately to make sure it was hoisted.
+    if (const auto *DVI = dyn_cast<DbgVariableIntrinsic>(U)) {
+      if (const auto *I =
+              dyn_cast_or_null<Instruction>(DVI->getVariableLocation()))
+        if (NotHoisted.count(I) == 0)
+          return true;
+      return false;
+    }
+
+    // Usually a debug label intrinsic corresponds to a label in the LLVM IR.
+    // In these cases we should not move it here.
+    // TODO: Possible special processing needed to detect it is related to a
+    // hoisted instruction.
+    if (isa<DbgLabelInst>(U))
+      return false;
+
+    for (const Value *V : U->operand_values()) {
+      if (const Instruction *I = dyn_cast<Instruction>(V)) {
         if (NotHoisted.count(I) > 0)
           return false;
       }
@@ -265,7 +282,8 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
   };
 
   unsigned TotalSpeculationCost = 0;
-  for (auto& I : FromBlock) {
+  unsigned NotHoistedInstCount = 0;
+  for (const auto &I : FromBlock) {
     const unsigned Cost = ComputeSpeculationCost(&I, *TTI);
     if (Cost != UINT_MAX && isSafeToSpeculativelyExecute(&I) &&
         AllPrecedingUsesFromBlockHoisted(&I)) {
@@ -273,15 +291,15 @@ bool SpeculativeExecutionPass::considerHoistingFromTo(
       if (TotalSpeculationCost > SpecExecMaxSpeculationCost)
         return false;  // too much to hoist
     } else {
-      NotHoisted.insert(&I);
-      if (NotHoisted.size() > SpecExecMaxNotHoisted)
+      // Debug info intrinsics should not be counted towards the threshold.
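Without this exclusion, a block peppered with debug intrinsics by -g could trip SpecExecMaxNotHoisted even though nothing of substance is left behind, making the pass behave differently with and without debug info. A sketch:

    // from:
    //   call void @llvm.dbg.value(...)   ; emitted many times under -g
    //   %t = add i32 %x, %y              ; the one real hoisting candidate
    // The dbg.values stay behind by design; counting them against the
    // threshold would spuriously abort hoisting %t.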
+      if (!isa<DbgInfoIntrinsic>(I)) +        NotHoistedInstCount++; +      if (NotHoistedInstCount > SpecExecMaxNotHoisted)          return false; // too much left behind +      NotHoisted.insert(&I);      }    } -  if (TotalSpeculationCost == 0) -    return false; // nothing to hoist -    for (auto I = FromBlock.begin(); I != FromBlock.end();) {      // We have to increment I before moving Current as moving Current      // changes the list that I is iterating through. diff --git a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp index 4ce4ce46f67a..c20e57b02c1a 100644 --- a/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -8,13 +8,12 @@  #include "llvm/ADT/DenseMap.h"  #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SCCIterator.h"  #include "llvm/ADT/STLExtras.h"  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/ADT/SmallVector.h"  #include "llvm/Analysis/InstructionSimplify.h"  #include "llvm/Analysis/LegacyDivergenceAnalysis.h" -#include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/RegionInfo.h"  #include "llvm/Analysis/RegionIterator.h"  #include "llvm/Analysis/RegionPass.h" @@ -34,6 +33,7 @@  #include "llvm/IR/Use.h"  #include "llvm/IR/User.h"  #include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h"  #include "llvm/InitializePasses.h"  #include "llvm/Pass.h"  #include "llvm/Support/Casting.h" @@ -43,6 +43,7 @@  #include "llvm/Support/raw_ostream.h"  #include "llvm/Transforms/Scalar.h"  #include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/Local.h"  #include "llvm/Transforms/Utils/SSAUpdater.h"  #include <algorithm>  #include <cassert> @@ -88,6 +89,59 @@ using BBPredicates = DenseMap<BasicBlock *, Value *>;  using PredMap = DenseMap<BasicBlock *, BBPredicates>;  using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>; +// A traits type that is intended to be used in graph algorithms. The graph +// traits starts at an entry node, and traverses the RegionNodes that are in +// the Nodes set. +struct SubGraphTraits { +  using NodeRef = std::pair<RegionNode *, SmallDenseSet<RegionNode *> *>; +  using BaseSuccIterator = GraphTraits<RegionNode *>::ChildIteratorType; + +  // This wraps a set of Nodes into the iterator, so we know which edges to +  // filter out. +  class WrappedSuccIterator +      : public iterator_adaptor_base< +            WrappedSuccIterator, BaseSuccIterator, +            typename std::iterator_traits<BaseSuccIterator>::iterator_category, +            NodeRef, std::ptrdiff_t, NodeRef *, NodeRef> { +    SmallDenseSet<RegionNode *> *Nodes; + +  public: +    WrappedSuccIterator(BaseSuccIterator It, SmallDenseSet<RegionNode *> *Nodes) +        : iterator_adaptor_base(It), Nodes(Nodes) {} + +    NodeRef operator*() const { return {*I, Nodes}; } +  }; + +  static bool filterAll(const NodeRef &N) { return true; } +  static bool filterSet(const NodeRef &N) { return N.second->count(N.first); } + +  using ChildIteratorType = +      filter_iterator<WrappedSuccIterator, bool (*)(const NodeRef &)>; + +  static NodeRef getEntryNode(Region *R) { +    return {GraphTraits<Region *>::getEntryNode(R), nullptr}; +  } + +  static NodeRef getEntryNode(NodeRef N) { return N; } + +  static iterator_range<ChildIteratorType> children(const NodeRef &N) { +    auto *filter = N.second ? 
&filterSet : &filterAll; +    return make_filter_range( +        make_range<WrappedSuccIterator>( +            {GraphTraits<RegionNode *>::child_begin(N.first), N.second}, +            {GraphTraits<RegionNode *>::child_end(N.first), N.second}), +        filter); +  } + +  static ChildIteratorType child_begin(const NodeRef &N) { +    return children(N).begin(); +  } + +  static ChildIteratorType child_end(const NodeRef &N) { +    return children(N).end(); +  } +}; +  /// Finds the nearest common dominator of a set of BasicBlocks.  ///  /// For every BB you add to the set, you can specify whether we "remember" the @@ -192,11 +246,11 @@ class StructurizeCFG : public RegionPass {    LegacyDivergenceAnalysis *DA;    DominatorTree *DT; -  LoopInfo *LI;    SmallVector<RegionNode *, 8> Order;    BBSet Visited; +  SmallVector<WeakVH, 8> AffectedPhis;    BBPhiMap DeletedPhis;    BB2BBVecMap AddedPhis; @@ -211,13 +265,8 @@ class StructurizeCFG : public RegionPass {    void orderNodes(); -  Loop *getAdjustedLoop(RegionNode *RN); -  unsigned getAdjustedLoopDepth(RegionNode *RN); -    void analyzeLoops(RegionNode *N); -  Value *invert(Value *Condition); -    Value *buildCondition(BranchInst *Term, unsigned Idx, bool Invert);    void gatherPredicates(RegionNode *N); @@ -232,6 +281,8 @@ class StructurizeCFG : public RegionPass {    void setPhiValues(); +  void simplifyAffectedPhis(); +    void killTerminator(BasicBlock *BB);    void changeExit(RegionNode *Node, BasicBlock *NewExit, @@ -279,7 +330,6 @@ public:        AU.addRequired<LegacyDivergenceAnalysis>();      AU.addRequiredID(LowerSwitchID);      AU.addRequired<DominatorTreeWrapperPass>(); -    AU.addRequired<LoopInfoWrapperPass>();      AU.addPreserved<DominatorTreeWrapperPass>();      RegionPass::getAnalysisUsage(AU); @@ -311,75 +361,60 @@ bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) {    return false;  } -/// Use the exit block to determine the loop if RN is a SubRegion. -Loop *StructurizeCFG::getAdjustedLoop(RegionNode *RN) { -  if (RN->isSubRegion()) { -    Region *SubRegion = RN->getNodeAs<Region>(); -    return LI->getLoopFor(SubRegion->getExit()); -  } - -  return LI->getLoopFor(RN->getEntry()); -} - -/// Use the exit block to determine the loop depth if RN is a SubRegion. -unsigned StructurizeCFG::getAdjustedLoopDepth(RegionNode *RN) { -  if (RN->isSubRegion()) { -    Region *SubR = RN->getNodeAs<Region>(); -    return LI->getLoopDepth(SubR->getExit()); -  } - -  return LI->getLoopDepth(RN->getEntry()); -} - -/// Build up the general order of nodes +/// Build up the general order of nodes, by performing a topological sort of the +/// parent region's nodes, while ensuring that there is no outer cycle node +/// between any two inner cycle nodes.  void StructurizeCFG::orderNodes() { -  ReversePostOrderTraversal<Region*> RPOT(ParentRegion); -  SmallDenseMap<Loop*, unsigned, 8> LoopBlocks; - -  // The reverse post-order traversal of the list gives us an ordering close -  // to what we want.  The only problem with it is that sometimes backedges -  // for outer loops will be visited before backedges for inner loops. 
-  for (RegionNode *RN : RPOT) {
-    Loop *Loop = getAdjustedLoop(RN);
-    ++LoopBlocks[Loop];
-  }
-
-  unsigned CurrentLoopDepth = 0;
-  Loop *CurrentLoop = nullptr;
-  for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) {
-    RegionNode *RN = cast<RegionNode>(*I);
-    unsigned LoopDepth = getAdjustedLoopDepth(RN);
-
-    if (is_contained(Order, *I))
-      continue;
-
-    if (LoopDepth < CurrentLoopDepth) {
-      // Make sure we have visited all blocks in this loop before moving back to
-      // the outer loop.
+  Order.resize(std::distance(GraphTraits<Region *>::nodes_begin(ParentRegion),
+                             GraphTraits<Region *>::nodes_end(ParentRegion)));
+  if (Order.empty())
+    return;
 
-      auto LoopI = I;
-      while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) {
-        LoopI++;
-        if (getAdjustedLoop(cast<RegionNode>(*LoopI)) == CurrentLoop) {
-          --BlockCount;
-          Order.push_back(*LoopI);
-        }
+  SmallDenseSet<RegionNode *> Nodes;
+  auto EntryNode = SubGraphTraits::getEntryNode(ParentRegion);
+
+  // A list of range indices of SCCs in Order, to be processed.
+  SmallVector<std::pair<unsigned, unsigned>, 8> WorkList;
+  unsigned I = 0, E = Order.size();
+  while (true) {
+    // Run through all the SCCs in the subgraph starting with Entry.
+    for (auto SCCI =
+             scc_iterator<SubGraphTraits::NodeRef, SubGraphTraits>::begin(
+                 EntryNode);
+         !SCCI.isAtEnd(); ++SCCI) {
+      auto &SCC = *SCCI;
+
+      // An SCC of size at most two consists of an entry (the last node) and
+      // possibly one additional node, so it is already in order and there is
+      // no need to add it to the work-list.
+      unsigned Size = SCC.size();
+      if (Size > 2)
+        WorkList.emplace_back(I, I + Size);
+
+      // Add the SCC nodes to the Order array.
+      for (auto &N : SCC) {
+        assert(I < E && "SCC size mismatch!");
+        Order[I++] = N.first;
       }
     }
+    assert(I == E && "SCC size mismatch!");
 
-    CurrentLoop = getAdjustedLoop(RN);
-    if (CurrentLoop)
-      LoopBlocks[CurrentLoop]--;
+    // If there are no more SCCs to order, then we are done.
+    if (WorkList.empty())
+      break;
 
-    CurrentLoopDepth = LoopDepth;
-    Order.push_back(*I);
-  }
+    std::tie(I, E) = WorkList.pop_back_val();
+
+    // Collect the set of nodes in the SCC's subgraph. These are only the
+    // possible child nodes; we do not add the entry (the last node), otherwise
+    // we would find the exact same SCC all over again.
+    Nodes.clear();
+    Nodes.insert(Order.begin() + I, Order.begin() + E - 1);
 
-  // This pass originally used a post-order traversal and then operated on
-  // the list in reverse. Now that we are using a reverse post-order traversal
-  // rather than re-working the whole pass to operate on the list in order,
-  // we just reverse the list and continue to operate on it in reverse.
-  std::reverse(Order.begin(), Order.end());
+    // Update the entry node.
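Re-running SCC discovery with the entry pinned and the candidate set restricted (done just below) peels one cycle level per iteration: the back edges into the pinned entry are filtered out, so nested cycles reappear as smaller SCCs. Abstractly, for a doubly nested loop:

    // pass 1: the whole outer loop {H, A, B, X} is one SCC, queued as a range
    // pass 2: entry pinned, set {A, B, X}: the inner loop {A, B} is an SCC
    // pass 3: entry pinned again, no SCC larger than two remains: done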
+    EntryNode.first = Order[E - 1]; +    EntryNode.second = &Nodes; +  }  }  /// Determine the end of the loops @@ -401,39 +436,6 @@ void StructurizeCFG::analyzeLoops(RegionNode *N) {    }  } -/// Invert the given condition -Value *StructurizeCFG::invert(Value *Condition) { -  // First: Check if it's a constant -  if (Constant *C = dyn_cast<Constant>(Condition)) -    return ConstantExpr::getNot(C); - -  // Second: If the condition is already inverted, return the original value -  Value *NotCondition; -  if (match(Condition, m_Not(m_Value(NotCondition)))) -    return NotCondition; - -  if (Instruction *Inst = dyn_cast<Instruction>(Condition)) { -    // Third: Check all the users for an invert -    BasicBlock *Parent = Inst->getParent(); -    for (User *U : Condition->users()) -      if (Instruction *I = dyn_cast<Instruction>(U)) -        if (I->getParent() == Parent && match(I, m_Not(m_Specific(Condition)))) -          return I; - -    // Last option: Create a new instruction -    return BinaryOperator::CreateNot(Condition, "", Parent->getTerminator()); -  } - -  if (Argument *Arg = dyn_cast<Argument>(Condition)) { -    BasicBlock &EntryBlock = Arg->getParent()->getEntryBlock(); -    return BinaryOperator::CreateNot(Condition, -                                     Arg->getName() + ".inv", -                                     EntryBlock.getTerminator()); -  } - -  llvm_unreachable("Unhandled condition to invert"); -} -  /// Build the condition for one edge  Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx,                                        bool Invert) { @@ -442,7 +444,7 @@ Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx,      Cond = Term->getCondition();      if (Idx != (unsigned)Invert) -      Cond = invert(Cond); +      Cond = invertCondition(Cond);    }    return Cond;  } @@ -520,8 +522,7 @@ void StructurizeCFG::collectInfos() {    for (RegionNode *RN : reverse(Order)) {      LLVM_DEBUG(dbgs() << "Visiting: "                        << (RN->isSubRegion() ? 
"SubRegion with entry: " : "") -                      << RN->getEntry()->getName() << " Loop Depth: " -                      << LI->getLoopDepth(RN->getEntry()) << "\n"); +                      << RN->getEntry()->getName() << "\n");      // Analyze all the conditions leading to a node      gatherPredicates(RN); @@ -585,9 +586,14 @@ void StructurizeCFG::insertConditions(bool Loops) {  void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) {    PhiMap &Map = DeletedPhis[To];    for (PHINode &Phi : To->phis()) { +    bool Recorded = false;      while (Phi.getBasicBlockIndex(From) != -1) {        Value *Deleted = Phi.removeIncomingValue(From, false);        Map[&Phi].push_back(std::make_pair(From, Deleted)); +      if (!Recorded) { +        AffectedPhis.push_back(&Phi); +        Recorded = true; +      }      }    }  } @@ -632,28 +638,29 @@ void StructurizeCFG::setPhiValues() {        for (BasicBlock *FI : From)          Phi->setIncomingValueForBlock(FI, Updater.GetValueAtEndOfBlock(FI)); +      AffectedPhis.push_back(Phi);      }      DeletedPhis.erase(To);    }    assert(DeletedPhis.empty()); -  // Simplify any phis inserted by the SSAUpdater if possible +  AffectedPhis.append(InsertedPhis.begin(), InsertedPhis.end()); +} + +void StructurizeCFG::simplifyAffectedPhis() {    bool Changed;    do {      Changed = false; -      SimplifyQuery Q(Func->getParent()->getDataLayout());      Q.DT = DT; -    for (size_t i = 0; i < InsertedPhis.size(); ++i) { -      PHINode *Phi = InsertedPhis[i]; -      if (Value *V = SimplifyInstruction(Phi, Q)) { -        Phi->replaceAllUsesWith(V); -        Phi->eraseFromParent(); -        InsertedPhis[i] = InsertedPhis.back(); -        InsertedPhis.pop_back(); -        i--; -        Changed = true; +    for (WeakVH VH : AffectedPhis) { +      if (auto Phi = dyn_cast_or_null<PHINode>(VH)) { +        if (auto NewValue = SimplifyInstruction(Phi, Q)) { +          Phi->replaceAllUsesWith(NewValue); +          Phi->eraseFromParent(); +          Changed = true; +        }        }      }    } while (Changed); @@ -886,6 +893,7 @@ void StructurizeCFG::createFlow() {    BasicBlock *Exit = ParentRegion->getExit();    bool EntryDominatesExit = DT->dominates(ParentRegion->getEntry(), Exit); +  AffectedPhis.clear();    DeletedPhis.clear();    AddedPhis.clear();    Conditions.clear(); @@ -1036,7 +1044,6 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) {    ParentRegion = R;    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); -  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();    orderNodes();    collectInfos(); @@ -1044,6 +1051,7 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) {    insertConditions(false);    insertConditions(true);    setPhiValues(); +  simplifyAffectedPhis();    rebuildSSA();    // Cleanup diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp index 9f0ab9103d42..5bb1d54d7d12 100644 --- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -64,7 +64,6 @@  #include "llvm/Analysis/PostDominators.h"  #include "llvm/Analysis/TargetTransformInfo.h"  #include "llvm/IR/CFG.h" -#include "llvm/IR/CallSite.h"  #include "llvm/IR/Constants.h"  #include "llvm/IR/DataLayout.h"  #include "llvm/IR/DerivedTypes.h" @@ -126,16 +125,16 @@ struct AllocaDerivedValueTracker {        switch (I->getOpcode()) {        case Instruction::Call:        case Instruction::Invoke: { -       
 CallSite CS(I); +        auto &CB = cast<CallBase>(*I);          // If the alloca-derived argument is passed byval it is not an escape          // point, or a use of an alloca. Calling with byval copies the contents          // of the alloca into argument registers or stack slots, which exist          // beyond the lifetime of the current frame. -        if (CS.isArgOperand(U) && CS.isByValArgument(CS.getArgumentNo(U))) +        if (CB.isArgOperand(U) && CB.isByValArgument(CB.getArgOperandNo(U)))            continue;          bool IsNocapture = -            CS.isDataOperand(U) && CS.doesNotCapture(CS.getDataOperandNo(U)); -        callUsesLocalStack(CS, IsNocapture); +            CB.isDataOperand(U) && CB.doesNotCapture(CB.getDataOperandNo(U)); +        callUsesLocalStack(CB, IsNocapture);          if (IsNocapture) {            // If the alloca-derived argument is passed in as nocapture, then it            // can't propagate to the call's return. That would be capturing. @@ -168,17 +167,17 @@ struct AllocaDerivedValueTracker {      }    } -  void callUsesLocalStack(CallSite CS, bool IsNocapture) { +  void callUsesLocalStack(CallBase &CB, bool IsNocapture) {      // Add it to the list of alloca users. -    AllocaUsers.insert(CS.getInstruction()); +    AllocaUsers.insert(&CB);      // If it's nocapture then it can't capture this alloca.      if (IsNocapture)        return;      // If it can write to memory, it can leak the alloca value. -    if (!CS.onlyReadsMemory()) -      EscapePoints.insert(CS.getInstruction()); +    if (!CB.onlyReadsMemory()) +      EscapePoints.insert(&CB);    }    SmallPtrSet<Instruction *, 32> AllocaUsers; @@ -342,7 +341,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {        const DataLayout &DL = L->getModule()->getDataLayout();        if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) ||            !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getType(), -                                       MaybeAlign(L->getAlignment()), DL, L)) +                                       L->getAlign(), DL, L))          return false;      }    } @@ -355,89 +354,23 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {    return !is_contained(I->operands(), CI);  } -/// Return true if the specified value is the same when the return would exit -/// as it was when the initial iteration of the recursive function was executed. -/// -/// We currently handle static constants and arguments that are not modified as -/// part of the recursion. -static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) { -  if (isa<Constant>(V)) return true; // Static constants are always dyn consts - -  // Check to see if this is an immutable argument, if so, the value -  // will be available to initialize the accumulator. -  if (Argument *Arg = dyn_cast<Argument>(V)) { -    // Figure out which argument number this is... -    unsigned ArgNo = 0; -    Function *F = CI->getParent()->getParent(); -    for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI) -      ++ArgNo; - -    // If we are passing this argument into call as the corresponding -    // argument operand, then the argument is dynamically constant. -    // Otherwise, we cannot transform this function safely. -    if (CI->getArgOperand(ArgNo) == Arg) -      return true; -  } - -  // Switch cases are always constant integers. 
If the value is being switched
-  // on and the return is only reachable from one of its cases, it's
-  // effectively constant.
-  if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor())
-    if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator()))
-      if (SI->getCondition() == V)
-        return SI->getDefaultDest() != RI->getParent();
-
-  // Not a constant or immutable argument, we can't safely transform.
-  return false;
-}
-
-/// Check to see if the function containing the specified tail call consistently
-/// returns the same runtime-constant value at all exit points except for
-/// IgnoreRI. If so, return the returned value.
-static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
-  Function *F = CI->getParent()->getParent();
-  Value *ReturnedValue = nullptr;
-
-  for (BasicBlock &BBI : *F) {
-    ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator());
-    if (RI == nullptr || RI == IgnoreRI) continue;
-
-    // We can only perform this transformation if the value returned is
-    // evaluatable at the start of the initial invocation of the function,
-    // instead of at the end of the evaluation.
-    //
-    Value *RetOp = RI->getOperand(0);
-    if (!isDynamicConstant(RetOp, CI, RI))
-      return nullptr;
-
-    if (ReturnedValue && RetOp != ReturnedValue)
-      return nullptr;     // Cannot transform if differing values are returned.
-    ReturnedValue = RetOp;
-  }
-  return ReturnedValue;
-}
+static bool canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
+  if (!I->isAssociative() || !I->isCommutative())
+    return false;
 
-/// If the specified instruction can be transformed using accumulator recursion
-/// elimination, return the constant which is the start of the accumulator
-/// value.  Otherwise return null.
-static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) {
-  if (!I->isAssociative() || !I->isCommutative()) return nullptr;
   assert(I->getNumOperands() == 2 &&
          "Associative/commutative operations should have 2 args!");
 
   // Exactly one operand should be the result of the call instruction.
   if ((I->getOperand(0) == CI && I->getOperand(1) == CI) ||
       (I->getOperand(0) != CI && I->getOperand(1) != CI))
-    return nullptr;
+    return false;
 
   // The only user of this instruction we allow is a single return instruction.
   if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back()))
-    return nullptr;
+    return false;
 
-  // Ok, now we have to check all of the other return instructions in this
-  // function.  If they return non-constants or differing values, then we cannot
-  // transform the function safely.
-  return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI);
+  return true;
 }
 
 static Instruction *firstNonDbg(BasicBlock::iterator I) {
@@ -446,11 +379,73 @@ static Instruction *firstNonDbg(BasicBlock::iterator I) {
   return &*I;
 }
 
-static CallInst *findTRECandidate(Instruction *TI,
-                                  bool CannotTailCallElimCallsMarkedTail,
-                                  const TargetTransformInfo *TTI) {
+namespace {
+class TailRecursionEliminator {
+  Function &F;
+  const TargetTransformInfo *TTI;
+  AliasAnalysis *AA;
+  OptimizationRemarkEmitter *ORE;
+  DomTreeUpdater &DTU;
+
+  // The fields below are shared state we want to have available when
+  // eliminating any calls in the function. These values should be populated by
+  // createTailRecurseLoopHeader the first time we find a call we can eliminate.
+  BasicBlock *HeaderBB = nullptr;
+  SmallVector<PHINode *, 8> ArgumentPHIs;
+  bool RemovableCallsMustBeMarkedTail = false;
+
+  // PHI node to store our return value.
+  PHINode *RetPN = nullptr;
+
+  // i1 PHI node to track if we have a valid return value stored in RetPN.
+  PHINode *RetKnownPN = nullptr;
+
+  // Vector of select instructions we inserted. These selects use RetKnownPN
+  // to either propagate RetPN or select a new return value.
+  SmallVector<SelectInst *, 8> RetSelects;
+
+  // The fields below are shared state needed when performing accumulator
+  // recursion. These values should be populated by insertAccumulator the
+  // first time we find an elimination that requires an accumulator.
+
+  // PHI node to store our current accumulated value.
+  PHINode *AccPN = nullptr;
+
+  // The instruction doing the accumulating.
+  Instruction *AccumulatorRecursionInstr = nullptr;
+
+  TailRecursionEliminator(Function &F, const TargetTransformInfo *TTI,
+                          AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
+                          DomTreeUpdater &DTU)
+      : F(F), TTI(TTI), AA(AA), ORE(ORE), DTU(DTU) {}
+
+  CallInst *findTRECandidate(Instruction *TI,
+                             bool CannotTailCallElimCallsMarkedTail);
+
+  void createTailRecurseLoopHeader(CallInst *CI);
+
+  void insertAccumulator(Instruction *AccRecInstr);
+
+  bool eliminateCall(CallInst *CI);
+
+  bool foldReturnAndProcessPred(ReturnInst *Ret,
+                                bool CannotTailCallElimCallsMarkedTail);
+
+  bool processReturningBlock(ReturnInst *Ret,
+                             bool CannotTailCallElimCallsMarkedTail);
+
+  void cleanupAndFinalize();
+
+public:
+  static bool eliminate(Function &F, const TargetTransformInfo *TTI,
+                        AliasAnalysis *AA, OptimizationRemarkEmitter *ORE,
+                        DomTreeUpdater &DTU);
+};
+} // namespace
+
+CallInst *TailRecursionEliminator::findTRECandidate(
+    Instruction *TI, bool CannotTailCallElimCallsMarkedTail) {
   BasicBlock *BB = TI->getParent();
-  Function *F = BB->getParent();
 
   if (&BB->front() == TI) // Make sure there is something before the terminator.
     return nullptr;
@@ -461,7 +456,7 @@ static CallInst *findTRECandidate(Instruction *TI,
   BasicBlock::iterator BBI(TI);
   while (true) {
     CI = dyn_cast<CallInst>(BBI);
-    if (CI && CI->getCalledFunction() == F)
+    if (CI && CI->getCalledFunction() == &F)
      break;
 
     if (BBI == BB->begin())
@@ -478,16 +473,14 @@ static CallInst *findTRECandidate(Instruction *TI,
   //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
   // and disable this xform in this case, because the code generator will
   // lower the call to fabs into inline code.
-  if (BB == &F->getEntryBlock() &&
+  if (BB == &F.getEntryBlock() &&
       firstNonDbg(BB->front().getIterator()) == CI &&
       firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() &&
       !TTI->isLoweredToCall(CI->getCalledFunction())) {
     // A single-block function with just a call and a return. Check that
     // the arguments match.
-    CallSite::arg_iterator I = CallSite(CI).arg_begin(), -                           E = CallSite(CI).arg_end(); -    Function::arg_iterator FI = F->arg_begin(), -                           FE = F->arg_end(); +    auto I = CI->arg_begin(), E = CI->arg_end(); +    Function::arg_iterator FI = F.arg_begin(), FE = F.arg_end();      for (; I != E && FI != FE; ++I, ++FI)        if (*I != &*FI) break;      if (I == E && FI == FE) @@ -497,27 +490,106 @@ static CallInst *findTRECandidate(Instruction *TI,    return CI;  } -static bool eliminateRecursiveTailCall( -    CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, -    bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, -    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { -  // If we are introducing accumulator recursion to eliminate operations after -  // the call instruction that are both associative and commutative, the initial -  // value for the accumulator is placed in this variable.  If this value is set -  // then we actually perform accumulator recursion elimination instead of -  // simple tail recursion elimination.  If the operation is an LLVM instruction -  // (eg: "add") then it is recorded in AccumulatorRecursionInstr.  If not, then -  // we are handling the case when the return instruction returns a constant C -  // which is different to the constant returned by other return instructions -  // (which is recorded in AccumulatorRecursionEliminationInitVal).  This is a -  // special case of accumulator recursion, the operation being "return C". -  Value *AccumulatorRecursionEliminationInitVal = nullptr; -  Instruction *AccumulatorRecursionInstr = nullptr; +void TailRecursionEliminator::createTailRecurseLoopHeader(CallInst *CI) { +  HeaderBB = &F.getEntryBlock(); +  BasicBlock *NewEntry = BasicBlock::Create(F.getContext(), "", &F, HeaderBB); +  NewEntry->takeName(HeaderBB); +  HeaderBB->setName("tailrecurse"); +  BranchInst *BI = BranchInst::Create(HeaderBB, NewEntry); +  BI->setDebugLoc(CI->getDebugLoc()); + +  // If this function has self recursive calls in the tail position where some +  // are marked tail and some are not, only transform one flavor or another. +  // We have to choose whether we move allocas in the entry block to the new +  // entry block or not, so we can't make a good choice for both. We make this +  // decision here based on whether the first call we found to remove is +  // marked tail. +  // NOTE: We could do slightly better here in the case that the function has +  // no entry block allocas. +  RemovableCallsMustBeMarkedTail = CI->isTailCall(); + +  // If this tail call is marked 'tail' and if there are any allocas in the +  // entry block, move them up to the new entry block. +  if (RemovableCallsMustBeMarkedTail) +    // Move all fixed sized allocas from HeaderBB to NewEntry. +    for (BasicBlock::iterator OEBI = HeaderBB->begin(), E = HeaderBB->end(), +                              NEBI = NewEntry->begin(); +         OEBI != E;) +      if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++)) +        if (isa<ConstantInt>(AI->getArraySize())) +          AI->moveBefore(&*NEBI); + +  // Now that we have created a new block, which jumps to the entry +  // block, insert a PHI node for each argument of the function. +  // For now, we initialize each PHI to only have the real arguments +  // which are passed in. 
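Concretely, for a self-recursive function this gives each argument a ".tr" PHI in the new loop header; sketched in IR comments (names illustrative, and the second incoming value is added later, once per eliminated call, by eliminateCall):

    // entry:
    //   br label %tailrecurse
    // tailrecurse:
    //   %n.tr = phi i32 [ %n, %entry ], [ %n.next, %recurse.block ]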
+  Instruction *InsertPos = &HeaderBB->front();
+  for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) {
+    PHINode *PN =
+        PHINode::Create(I->getType(), 2, I->getName() + ".tr", InsertPos);
+    I->replaceAllUsesWith(PN); // Everyone use the PHI node now!
+    PN->addIncoming(&*I, NewEntry);
+    ArgumentPHIs.push_back(PN);
+  }
+
+  // If the function doesn't return void, create the RetPN and RetKnownPN PHI
+  // nodes to track our return value. We initialize RetPN with undef and
+  // RetKnownPN with false since we can't know our return value at function
+  // entry.
+  Type *RetType = F.getReturnType();
+  if (!RetType->isVoidTy()) {
+    Type *BoolType = Type::getInt1Ty(F.getContext());
+    RetPN = PHINode::Create(RetType, 2, "ret.tr", InsertPos);
+    RetKnownPN = PHINode::Create(BoolType, 2, "ret.known.tr", InsertPos);
+
+    RetPN->addIncoming(UndefValue::get(RetType), NewEntry);
+    RetKnownPN->addIncoming(ConstantInt::getFalse(BoolType), NewEntry);
+  }
+
+  // The entry block was changed from HeaderBB to NewEntry.
+  // The forward DominatorTree needs to be recalculated when the EntryBB is
+  // changed. In this corner-case we recalculate the entire tree.
+  DTU.recalculate(*NewEntry->getParent());
+}
+
+void TailRecursionEliminator::insertAccumulator(Instruction *AccRecInstr) {
+  assert(!AccPN && "Trying to insert multiple accumulators");
+
+  AccumulatorRecursionInstr = AccRecInstr;
+
+  // Start by inserting a new PHI node for the accumulator.
+  pred_iterator PB = pred_begin(HeaderBB), PE = pred_end(HeaderBB);
+  AccPN = PHINode::Create(F.getReturnType(), std::distance(PB, PE) + 1,
+                          "accumulator.tr", &HeaderBB->front());
+
+  // Loop over all of the predecessors of the tail recursion block.  For the
+  // real entry into the function we seed the PHI with the identity constant for
+  // the accumulation operation.  For any other existing branches to this block
+  // (due to other tail recursions eliminated) the accumulator is not modified.
+  // Because we haven't added the branch in the current block to HeaderBB yet,
+  // it will not show up as a predecessor.
+  for (pred_iterator PI = PB; PI != PE; ++PI) {
+    BasicBlock *P = *PI;
+    if (P == &F.getEntryBlock()) {
+      Constant *Identity = ConstantExpr::getBinOpIdentity(
+          AccRecInstr->getOpcode(), AccRecInstr->getType());
+      AccPN->addIncoming(Identity, P);
+    } else {
+      AccPN->addIncoming(AccPN, P);
+    }
+  }
+
+  ++NumAccumAdded;
+}
+
+bool TailRecursionEliminator::eliminateCall(CallInst *CI) {
+  ReturnInst *Ret = cast<ReturnInst>(CI->getParent()->getTerminator());
 
   // Ok, we found a potential tail call.  We can currently only transform the
   // tail call if all of the instructions between the call and the return are
   // movable to above the call itself, leaving the call next to the return.
   // Check that this is the case now.
+  Instruction *AccRecInstr = nullptr;
   BasicBlock::iterator BBI(CI);
   for (++BBI; &*BBI != Ret; ++BBI) {
     if (canMoveAboveCall(&*BBI, CI, AA))
@@ -526,39 +598,16 @@ static bool eliminateRecursiveTailCall(
     // If we can't move the instruction above the call, it might be because it
     // is an associative and commutative operation that could be transformed
    // using accumulator recursion elimination.  Check to see if this is the
-    if ((AccumulatorRecursionEliminationInitVal = -             canTransformAccumulatorRecursion(&*BBI, CI))) { -      // Yes, this is accumulator recursion.  Remember which instruction -      // accumulates. -      AccumulatorRecursionInstr = &*BBI; -    } else { -      return false;   // Otherwise, we cannot eliminate the tail recursion! -    } -  } +    // case, and if so, remember which instruction accumulates for later. +    if (AccPN || !canTransformAccumulatorRecursion(&*BBI, CI)) +      return false; // We cannot eliminate the tail recursion! -  // We can only transform call/return pairs that either ignore the return value -  // of the call and return void, ignore the value of the call and return a -  // constant, return the value returned by the tail call, or that are being -  // accumulator recursion variable eliminated. -  if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI && -      !isa<UndefValue>(Ret->getReturnValue()) && -      AccumulatorRecursionEliminationInitVal == nullptr && -      !getCommonReturnValue(nullptr, CI)) { -    // One case remains that we are able to handle: the current return -    // instruction returns a constant, and all other return instructions -    // return a different constant. -    if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret)) -      return false; // Current return instruction does not return a constant. -    // Check that all other return instructions return a common constant.  If -    // so, record it in AccumulatorRecursionEliminationInitVal. -    AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI); -    if (!AccumulatorRecursionEliminationInitVal) -      return false; +    // Yes, this is accumulator recursion.  Remember which instruction +    // accumulates. +    AccRecInstr = &*BBI;    }    BasicBlock *BB = Ret->getParent(); -  Function *F = BB->getParent();    using namespace ore;    ORE->emit([&]() { @@ -568,51 +617,10 @@ static bool eliminateRecursiveTailCall(    // OK! We can transform this tail call.  If this is the first one found,    // create the new entry block, allowing us to branch back to the old entry. -  if (!OldEntry) { -    OldEntry = &F->getEntryBlock(); -    BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry); -    NewEntry->takeName(OldEntry); -    OldEntry->setName("tailrecurse"); -    BranchInst *BI = BranchInst::Create(OldEntry, NewEntry); -    BI->setDebugLoc(CI->getDebugLoc()); - -    // If this tail call is marked 'tail' and if there are any allocas in the -    // entry block, move them up to the new entry block. -    TailCallsAreMarkedTail = CI->isTailCall(); -    if (TailCallsAreMarkedTail) -      // Move all fixed sized allocas from OldEntry to NewEntry. -      for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(), -             NEBI = NewEntry->begin(); OEBI != E; ) -        if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++)) -          if (isa<ConstantInt>(AI->getArraySize())) -            AI->moveBefore(&*NEBI); - -    // Now that we have created a new block, which jumps to the entry -    // block, insert a PHI node for each argument of the function. -    // For now, we initialize each PHI to only have the real arguments -    // which are passed in. 
-    Instruction *InsertPos = &OldEntry->front(); -    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); -         I != E; ++I) { -      PHINode *PN = PHINode::Create(I->getType(), 2, -                                    I->getName() + ".tr", InsertPos); -      I->replaceAllUsesWith(PN); // Everyone use the PHI node now! -      PN->addIncoming(&*I, NewEntry); -      ArgumentPHIs.push_back(PN); -    } -    // The entry block was changed from OldEntry to NewEntry. -    // The forward DominatorTree needs to be recalculated when the EntryBB is -    // changed. In this corner-case we recalculate the entire tree. -    DTU.recalculate(*NewEntry->getParent()); -  } +  if (!HeaderBB) +    createTailRecurseLoopHeader(CI); -  // If this function has self recursive calls in the tail position where some -  // are marked tail and some are not, only transform one flavor or another.  We -  // have to choose whether we move allocas in the entry block to the new entry -  // block or not, so we can't make a good choice for both.  NOTE: We could do -  // slightly better here in the case that the function has no entry block -  // allocas. -  if (TailCallsAreMarkedTail && !CI->isTailCall()) +  if (RemovableCallsMustBeMarkedTail && !CI->isTailCall())      return false;    // Ok, now that we know we have a pseudo-entry block WITH all of the @@ -621,74 +629,53 @@ static bool eliminateRecursiveTailCall(    for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)      ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB); -  // If we are introducing an accumulator variable to eliminate the recursion, -  // do so now.  Note that we _know_ that no subsequent tail recursion -  // eliminations will happen on this function because of the way the -  // accumulator recursion predicate is set up. -  // -  if (AccumulatorRecursionEliminationInitVal) { -    Instruction *AccRecInstr = AccumulatorRecursionInstr; -    // Start by inserting a new PHI node for the accumulator. -    pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry); -    PHINode *AccPN = PHINode::Create( -        AccumulatorRecursionEliminationInitVal->getType(), -        std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front()); - -    // Loop over all of the predecessors of the tail recursion block.  For the -    // real entry into the function we seed the PHI with the initial value, -    // computed earlier.  For any other existing branches to this block (due to -    // other tail recursions eliminated) the accumulator is not modified. -    // Because we haven't added the branch in the current block to OldEntry yet, -    // it will not show up as a predecessor. -    for (pred_iterator PI = PB; PI != PE; ++PI) { -      BasicBlock *P = *PI; -      if (P == &F->getEntryBlock()) -        AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P); -      else -        AccPN->addIncoming(AccPN, P); -    } +  if (AccRecInstr) { +    insertAccumulator(AccRecInstr); -    if (AccRecInstr) { -      // Add an incoming argument for the current block, which is computed by -      // our associative and commutative accumulator instruction. -      AccPN->addIncoming(AccRecInstr, BB); +    // Rewrite the accumulator recursion instruction so that it does not use +    // the result of the call anymore, instead, use the PHI node we just +    // inserted. 
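The operand index expression below simply picks whichever of the two operands was the recursive call. End to end, accumulator elimination behaves roughly like this source-level rewrite (a sketch, not literal pass output):

    // int fac(int n) {                      // original
    //   return n <= 1 ? 1 : n * fac(n - 1);
    // }
    // int fac(int n) {                      // after TRE with an accumulator
    //   int acc = 1;                        // identity of '*' seeds AccPN
    //   while (!(n <= 1)) { acc = acc * n; n = n - 1; }
    //   return acc * 1;                     // the base-case value folds in
    // }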
+    AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+  }
 
-      // Next, rewrite the accumulator recursion instruction so that it does not
-      // use the result of the call anymore, instead, use the PHI node we just
-      // inserted.
-      AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+  // Update our return value tracking.
+  if (RetPN) {
+    if (Ret->getReturnValue() == CI || AccRecInstr) {
+      // Defer selecting a return value.
+      RetPN->addIncoming(RetPN, BB);
+      RetKnownPN->addIncoming(RetKnownPN, BB);
     } else {
-      // Add an incoming argument for the current block, which is just the
-      // constant returned by the current return instruction.
-      AccPN->addIncoming(Ret->getReturnValue(), BB);
+      // We found a return value we want to use; insert a select instruction to
+      // select it if we don't already know what our return value will be, and
+      // store the result in our return value PHI node.
+      SelectInst *SI = SelectInst::Create(
+          RetKnownPN, RetPN, Ret->getReturnValue(), "current.ret.tr", Ret);
+      RetSelects.push_back(SI);
+
+      RetPN->addIncoming(SI, BB);
+      RetKnownPN->addIncoming(ConstantInt::getTrue(RetKnownPN->getType()), BB);
     }
 
-    // Finally, rewrite any return instructions in the program to return the PHI
-    // node instead of the "initval" that they do currently.  This loop will
-    // actually rewrite the return value we are destroying, but that's ok.
-    for (BasicBlock &BBI : *F)
-      if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()))
-        RI->setOperand(0, AccPN);
-    ++NumAccumAdded;
+    if (AccPN)
+      AccPN->addIncoming(AccRecInstr ? AccRecInstr : AccPN, BB);
   }
 
   // Now that all of the PHI nodes are in place, remove the call and
   // ret instructions, replacing them with an unconditional branch.
-  BranchInst *NewBI = BranchInst::Create(OldEntry, Ret);
+  BranchInst *NewBI = BranchInst::Create(HeaderBB, Ret);
   NewBI->setDebugLoc(CI->getDebugLoc());
 
   BB->getInstList().erase(Ret);  // Remove return.
   BB->getInstList().erase(CI);   // Remove call.
-  DTU.applyUpdates({{DominatorTree::Insert, BB, OldEntry}});
+  DTU.applyUpdates({{DominatorTree::Insert, BB, HeaderBB}});
   ++NumEliminated;
   return true;
 }
 
-static bool foldReturnAndProcessPred(
-    BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry,
-    bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs,
-    bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI,
-    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) {
+bool TailRecursionEliminator::foldReturnAndProcessPred(
+    ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) {
+  BasicBlock *BB = Ret->getParent();
+
   bool Change = false;
 
   // Make sure this block is a trivial return block.
@@ -711,10 +698,11 @@ static bool foldReturnAndProcessPred(    while (!UncondBranchPreds.empty()) {      BranchInst *BI = UncondBranchPreds.pop_back_val();      BasicBlock *Pred = BI->getParent(); -    if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ +    if (CallInst *CI = +            findTRECandidate(BI, CannotTailCallElimCallsMarkedTail)) {        LLVM_DEBUG(dbgs() << "FOLDING: " << *BB                          << "INTO UNCOND BRANCH PRED: " << *Pred); -      ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); +      FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU);        // Cleanup: if all predecessors of BB have been eliminated by        // FoldReturnIntoUncondBranch, delete it.  It is important to empty it, @@ -723,8 +711,7 @@ static bool foldReturnAndProcessPred(        if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB))          DTU.deleteBB(BB); -      eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, -                                 ArgumentPHIs, AA, ORE, DTU); +      eliminateCall(CI);        ++NumRetDuped;        Change = true;      } @@ -733,23 +720,92 @@ static bool foldReturnAndProcessPred(    return Change;  } -static bool processReturningBlock( -    ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, -    SmallVectorImpl<PHINode *> &ArgumentPHIs, -    bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, -    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { -  CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); +bool TailRecursionEliminator::processReturningBlock( +    ReturnInst *Ret, bool CannotTailCallElimCallsMarkedTail) { +  CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail);    if (!CI)      return false; -  return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, -                                    ArgumentPHIs, AA, ORE, DTU); +  return eliminateCall(CI); +} + +void TailRecursionEliminator::cleanupAndFinalize() { +  // If we eliminated any tail recursions, it's possible that we inserted some +  // silly PHI nodes which just merge an initial value (the incoming operand) +  // with themselves.  Check to see if we did and clean up our mess if so.  This +  // occurs when a function passes an argument straight through to its tail +  // call. +  for (PHINode *PN : ArgumentPHIs) { +    // If the PHI Node is a dynamic constant, replace it with the value it is. +    if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) { +      PN->replaceAllUsesWith(PNV); +      PN->eraseFromParent(); +    } +  } + +  if (RetPN) { +    if (RetSelects.empty()) { +      // If we didn't insert any select instructions, then we know we didn't +      // store a return value and we can remove the PHI nodes we inserted. +      RetPN->dropAllReferences(); +      RetPN->eraseFromParent(); + +      RetKnownPN->dropAllReferences(); +      RetKnownPN->eraseFromParent(); + +      if (AccPN) { +        // We need to insert a copy of our accumulator instruction before any +        // return in the function, and return its result instead. 
+        Instruction *AccRecInstr = AccumulatorRecursionInstr; +        for (BasicBlock &BB : F) { +          ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator()); +          if (!RI) +            continue; + +          Instruction *AccRecInstrNew = AccRecInstr->clone(); +          AccRecInstrNew->setName("accumulator.ret.tr"); +          AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, +                                     RI->getOperand(0)); +          AccRecInstrNew->insertBefore(RI); +          RI->setOperand(0, AccRecInstrNew); +        } +      } +    } else { +      // We need to insert a select instruction before any return left in the +      // function to select our stored return value if we have one. +      for (BasicBlock &BB : F) { +        ReturnInst *RI = dyn_cast<ReturnInst>(BB.getTerminator()); +        if (!RI) +          continue; + +        SelectInst *SI = SelectInst::Create( +            RetKnownPN, RetPN, RI->getOperand(0), "current.ret.tr", RI); +        RetSelects.push_back(SI); +        RI->setOperand(0, SI); +      } + +      if (AccPN) { +        // We need to insert a copy of our accumulator instruction before any +        // of the selects we inserted, and select its result instead. +        Instruction *AccRecInstr = AccumulatorRecursionInstr; +        for (SelectInst *SI : RetSelects) { +          Instruction *AccRecInstrNew = AccRecInstr->clone(); +          AccRecInstrNew->setName("accumulator.ret.tr"); +          AccRecInstrNew->setOperand(AccRecInstr->getOperand(0) == AccPN, +                                     SI->getFalseValue()); +          AccRecInstrNew->insertBefore(SI); +          SI->setFalseValue(AccRecInstrNew); +        } +      } +    } +  }  } -static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, -                                   AliasAnalysis *AA, -                                   OptimizationRemarkEmitter *ORE, -                                   DomTreeUpdater &DTU) { +bool TailRecursionEliminator::eliminate(Function &F, +                                        const TargetTransformInfo *TTI, +                                        AliasAnalysis *AA, +                                        OptimizationRemarkEmitter *ORE, +                                        DomTreeUpdater &DTU) {    if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true")      return false; @@ -762,17 +818,15 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,    // If this function is a varargs function, we won't be able to PHI the args    // right, so don't even try to convert it...    if (F.getFunctionType()->isVarArg()) -    return false; - -  BasicBlock *OldEntry = nullptr; -  bool TailCallsAreMarkedTail = false; -  SmallVector<PHINode*, 8> ArgumentPHIs; +    return MadeChange;    // If false, we cannot perform TRE on tail calls marked with the 'tail'    // attribute, because doing so would cause the stack size to increase (real    // TRE would deallocate variable sized allocas, TRE doesn't).    bool CanTRETailMarkedCall = canTRE(F); +  TailRecursionEliminator TRE(F, TTI, AA, ORE, DTU); +    // Change any tail recursive calls to loops.    
//
    // FIXME: The code generator produces really bad code when an 'escaping
@@ -782,29 +836,14 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
   for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) {
     BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB.
     if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
-      bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
-                                          ArgumentPHIs, !CanTRETailMarkedCall,
-                                          TTI, AA, ORE, DTU);
+      bool Change = TRE.processReturningBlock(Ret, !CanTRETailMarkedCall);
       if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
-        Change = foldReturnAndProcessPred(
-            BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs,
-            !CanTRETailMarkedCall, TTI, AA, ORE, DTU);
+        Change = TRE.foldReturnAndProcessPred(Ret, !CanTRETailMarkedCall);
       MadeChange |= Change;
     }
   }
 
-  // If we eliminated any tail recursions, it's possible that we inserted some
-  // silly PHI nodes which just merge an initial value (the incoming operand)
-  // with themselves.  Check to see if we did and clean up our mess if so.  This
-  // occurs when a function passes an argument straight through to its tail
-  // call.
-  for (PHINode *PN : ArgumentPHIs) {
-    // If the PHI Node is a dynamic constant, replace it with the value it is.
-    if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) {
-      PN->replaceAllUsesWith(PNV);
-      PN->eraseFromParent();
-    }
-  }
+  TRE.cleanupAndFinalize();
 
   return MadeChange;
 }
@@ -838,7 +877,7 @@ struct TailCallElim : public FunctionPass {
     // UpdateStrategy to Lazy if we find it profitable later.
     DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
 
-    return eliminateTailRecursion(
+    return TailRecursionEliminator::eliminate(
         F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
         &getAnalysis<AAResultsWrapperPass>().getAAResults(),
         &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU);
@@ -871,7 +910,7 @@ PreservedAnalyses TailCallElimPass::run(Function &F,
   // UpdateStrategy based on some test results. It is feasible to switch the
   // UpdateStrategy to Lazy if we find it profitable later.
   DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager);
-  bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU);
+  bool Changed = TailRecursionEliminator::eliminate(F, &TTI, &AA, &ORE, DTU);
 
   if (!Changed)
     return PreservedAnalyses::all();
diff --git a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
index c8461fdc1608..7c81e6352dec 100644
--- a/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
+++ b/llvm/lib/Transforms/Scalar/WarnMissedTransforms.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar/WarnMissedTransforms.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
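A note for readers following the transformation itself rather than the refactoring into TailRecursionEliminator: the accumulator path in the hunks above turns a call-then-combine pattern into a loop. The C++ sketch below is illustrative only; fact, fact_after_tre, and the seed value are invented for the example (the actual seeding lives in insertAccumulator, whose body is not among the hunks shown), but the loop-header, argument-PHI, and accumulator-PHI structure mirrors what eliminateCall and cleanupAndFinalize emit.

// Illustrative input (not from the patch): the multiply executes after the
// recursive call returns, so the call is not literally in tail position, but
// canTransformAccumulatorRecursion accepts it because the multiply is an
// associative and commutative operation with the call as one operand.
unsigned fact(unsigned n) {
  if (n <= 1)
    return 1;
  return n * fact(n - 1); // the "accumulator recursion instruction"
}

// Rough shape of the output: a branch back to the new loop header, one PHI
// per argument ("n.tr") and one for the accumulator ("accumulator.tr"), with
// each remaining return rewritten through a clone of the accumulator
// instruction ("accumulator.ret.tr"). The seed value 1 is chosen here so the
// arithmetic works out; it is an assumption of this sketch, not patch code.
unsigned fact_after_tre(unsigned n) {
  unsigned accumulator = 1;
  while (n > 1) {
    accumulator = n * accumulator; // AccPN updated via AccRecInstr
    n = n - 1;                     // argument PHI takes the call's operand
  }
  return accumulator * 1; // base-case return combined with the accumulator
}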
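The RetPN/RetKnownPN pair added above implements deferred return-value selection: while rewriting calls, the pass does not yet know which value the function will ultimately return. A minimal stand-alone model of that bookkeeping, under invented names (DeferredReturn is not part of the patch's API):

// Stand-alone model of the deferred-return-value bookkeeping.
template <typename T> struct DeferredReturn {
  T Stored{};         // plays the role of RetPN
  bool Known = false; // plays the role of RetKnownPN

  // eliminateCall(): a return of the recursive call's own result (or of the
  // accumulator) defers the decision, with both PHIs feeding themselves
  // around the loop. Any other return value is captured by the
  // "current.ret.tr" select, i.e. select(Known, Stored, V).
  void captureReturn(T V) {
    Stored = Known ? Stored : V;
    Known = true;
  }

  // cleanupAndFinalize(): every return still left in the function is
  // rewritten to return select(Known, Stored, LocalValue) instead.
  T finalize(T LocalValue) const { return Known ? Stored : LocalValue; }
};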
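Finally, the operand-index idiom used twice above, setOperand(getOperand(0) != CI, AccPN) in eliminateCall and setOperand(getOperand(0) == AccPN, ...) in cleanupAndFinalize, relies on the bool-to-unsigned conversion to pick which of a binary operator's two operands to overwrite. The helpers below are hypothetical, written only to spell the trick out; they use the real Instruction::getOperand/setOperand API.

#include "llvm/IR/Instruction.h"
#include "llvm/IR/Value.h"

// Overwrites whichever operand of the two-operand instruction I currently
// equals Old: the comparison yields 0 when operand 0 is Old and 1 otherwise,
// which is exactly the index to replace (as in eliminateCall, where the
// operand that was the recursive call becomes the accumulator PHI).
static void replaceMatchingOperand(llvm::Instruction *I, llvm::Value *Old,
                                   llvm::Value *New) {
  I->setOperand(I->getOperand(0) != Old, New);
}

// Overwrites whichever operand of I is not Keep (as in cleanupAndFinalize,
// where the cloned accumulator instruction keeps AccPN and takes the
// return's local value as its other operand).
static void replaceOtherOperand(llvm::Instruction *I, llvm::Value *Keep,
                                llvm::Value *New) {
  I->setOperand(I->getOperand(0) == Keep, New);
}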
