diff options
Diffstat (limited to 'lib/Analysis/LoopAccessAnalysis.cpp')
-rw-r--r--  lib/Analysis/LoopAccessAnalysis.cpp  248
1 files changed, 216 insertions, 32 deletions
diff --git a/lib/Analysis/LoopAccessAnalysis.cpp b/lib/Analysis/LoopAccessAnalysis.cpp index 4ba12583ff839..ed8e5e8cc489f 100644 --- a/lib/Analysis/LoopAccessAnalysis.cpp +++ b/lib/Analysis/LoopAccessAnalysis.cpp @@ -29,7 +29,7 @@  #include "llvm/Analysis/LoopAnalysisManager.h"  #include "llvm/Analysis/LoopInfo.h"  #include "llvm/Analysis/MemoryLocation.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h"  #include "llvm/Analysis/ScalarEvolution.h"  #include "llvm/Analysis/ScalarEvolutionExpander.h"  #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -522,6 +522,21 @@ public:      Accesses.insert(MemAccessInfo(Ptr, true));    } +  /// \brief Check if we can emit a run-time no-alias check for \p Access. +  /// +  /// Returns true if we can emit a run-time no alias check for \p Access. +  /// If we can check this access, this also adds it to a dependence set and +  /// adds a run-time check for it to \p RtCheck. If \p Assume is true, +  /// we will attempt to use additional run-time checks in order to get +  /// the bounds of the pointer. +  bool createCheckForAccess(RuntimePointerChecking &RtCheck, +                            MemAccessInfo Access, +                            const ValueToValueMap &Strides, +                            DenseMap<Value *, unsigned> &DepSetId, +                            Loop *TheLoop, unsigned &RunningDepId, +                            unsigned ASId, bool ShouldCheckStride, +                            bool Assume); +    /// \brief Check whether we can check the pointers at runtime for    /// non-intersection.    /// @@ -597,9 +612,11 @@ private:  } // end anonymous namespace  /// \brief Check whether a pointer can participate in a runtime bounds check. +/// If \p Assume, try harder to prove that we can compute the bounds of \p Ptr +/// by adding run-time checks (overflow checks) if necessary.  
static bool hasComputableBounds(PredicatedScalarEvolution &PSE,                                  const ValueToValueMap &Strides, Value *Ptr, -                                Loop *L) { +                                Loop *L, bool Assume) {    const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);    // The bounds for loop-invariant pointer is trivial. @@ -607,6 +624,10 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE,      return true;    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); + +  if (!AR && Assume) +    AR = PSE.getAsAddRec(Ptr); +    if (!AR)      return false; @@ -621,9 +642,53 @@ static bool isNoWrap(PredicatedScalarEvolution &PSE,      return true;    int64_t Stride = getPtrStride(PSE, Ptr, L, Strides); -  return Stride == 1; +  if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW)) +    return true; + +  return false;  } +bool AccessAnalysis::createCheckForAccess(RuntimePointerChecking &RtCheck, +                                          MemAccessInfo Access, +                                          const ValueToValueMap &StridesMap, +                                          DenseMap<Value *, unsigned> &DepSetId, +                                          Loop *TheLoop, unsigned &RunningDepId, +                                          unsigned ASId, bool ShouldCheckWrap, +                                          bool Assume) { +  Value *Ptr = Access.getPointer(); + +  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume)) +    return false; + +  // When we run after a failing dependency check we have to make sure +  // we don't have wrapping pointers. +  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) { +    auto *Expr = PSE.getSCEV(Ptr); +    if (!Assume || !isa<SCEVAddRecExpr>(Expr)) +      return false; +    PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); +  } + +  // The id of the dependence set. 
+  unsigned DepId; + +  if (isDependencyCheckNeeded()) { +    Value *Leader = DepCands.getLeaderValue(Access).getPointer(); +    unsigned &LeaderId = DepSetId[Leader]; +    if (!LeaderId) +      LeaderId = RunningDepId++; +    DepId = LeaderId; +  } else +    // Each access has its own dependence set. +    DepId = RunningDepId++; + +  bool IsWrite = Access.getInt(); +  RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE); +  DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); + +  return true; + } +  bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,                                       ScalarEvolution *SE, Loop *TheLoop,                                       const ValueToValueMap &StridesMap, @@ -643,12 +708,15 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,    for (auto &AS : AST) {      int NumReadPtrChecks = 0;      int NumWritePtrChecks = 0; +    bool CanDoAliasSetRT = true;      // We assign consecutive id to access from different dependence sets.      // Accesses within the same set don't need a runtime check.      unsigned RunningDepId = 1;      DenseMap<Value *, unsigned> DepSetId; +    SmallVector<MemAccessInfo, 4> Retries; +      for (auto A : AS) {        Value *Ptr = A.getValue();        bool IsWrite = Accesses.count(MemAccessInfo(Ptr, true)); @@ -659,29 +727,11 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,        else          ++NumReadPtrChecks; -      if (hasComputableBounds(PSE, StridesMap, Ptr, TheLoop) && -          // When we run after a failing dependency check we have to make sure -          // we don't have wrapping pointers. -          (!ShouldCheckWrap || isNoWrap(PSE, StridesMap, Ptr, TheLoop))) { -        // The id of the dependence set. 
-        unsigned DepId; - -        if (IsDepCheckNeeded) { -          Value *Leader = DepCands.getLeaderValue(Access).getPointer(); -          unsigned &LeaderId = DepSetId[Leader]; -          if (!LeaderId) -            LeaderId = RunningDepId++; -          DepId = LeaderId; -        } else -          // Each access has its own dependence set. -          DepId = RunningDepId++; - -        RtCheck.insert(TheLoop, Ptr, IsWrite, DepId, ASId, StridesMap, PSE); - -        DEBUG(dbgs() << "LAA: Found a runtime check ptr:" << *Ptr << '\n'); -      } else { +      if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, TheLoop, +                                RunningDepId, ASId, ShouldCheckWrap, false)) {          DEBUG(dbgs() << "LAA: Can't find bounds for ptr:" << *Ptr << '\n'); -        CanDoRT = false; +        Retries.push_back(Access); +        CanDoAliasSetRT = false;        }      } @@ -693,10 +743,29 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,      // For example CanDoRT=false, NeedRTCheck=false means that we have a pointer      // for which we couldn't find the bounds but we don't actually need to emit      // any checks so it does not matter. -    if (!(IsDepCheckNeeded && CanDoRT && RunningDepId == 2)) -      NeedRTCheck |= (NumWritePtrChecks >= 2 || (NumReadPtrChecks >= 1 && -                                                 NumWritePtrChecks >= 1)); +    bool NeedsAliasSetRTCheck = false; +    if (!(IsDepCheckNeeded && CanDoAliasSetRT && RunningDepId == 2)) +      NeedsAliasSetRTCheck = (NumWritePtrChecks >= 2 || +                             (NumReadPtrChecks >= 1 && NumWritePtrChecks >= 1)); + +    // We need to perform run-time alias checks, but some pointers had bounds +    // that couldn't be checked. +    if (NeedsAliasSetRTCheck && !CanDoAliasSetRT) { +      // Reset the CanDoSetRt flag and retry all accesses that have failed. 
+      // We know that we need these checks, so we can now be more aggressive +      // and add further checks if required (overflow checks). +      CanDoAliasSetRT = true; +      for (auto Access : Retries) +        if (!createCheckForAccess(RtCheck, Access, StridesMap, DepSetId, +                                  TheLoop, RunningDepId, ASId, +                                  ShouldCheckWrap, /*Assume=*/true)) { +          CanDoAliasSetRT = false; +          break; +        } +    } +    CanDoRT &= CanDoAliasSetRT; +    NeedRTCheck |= NeedsAliasSetRTCheck;      ++ASId;    } @@ -1038,6 +1107,77 @@ static unsigned getAddressSpaceOperand(Value *I) {    return -1;  } +// TODO:This API can be improved by using the permutation of given width as the +// accesses are entered into the map. +bool llvm::sortLoadAccesses(ArrayRef<Value *> VL, const DataLayout &DL, +                           ScalarEvolution &SE, +                           SmallVectorImpl<Value *> &Sorted, +                           SmallVectorImpl<unsigned> *Mask) { +  SmallVector<std::pair<int64_t, Value *>, 4> OffValPairs; +  OffValPairs.reserve(VL.size()); +  Sorted.reserve(VL.size()); + +  // Walk over the pointers, and map each of them to an offset relative to +  // first pointer in the array. +  Value *Ptr0 = getPointerOperand(VL[0]); +  const SCEV *Scev0 = SE.getSCEV(Ptr0); +  Value *Obj0 = GetUnderlyingObject(Ptr0, DL); +  PointerType *PtrTy = dyn_cast<PointerType>(Ptr0->getType()); +  uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); + +  for (auto *Val : VL) { +    // The only kind of access we care about here is load. +    if (!isa<LoadInst>(Val)) +      return false; + +    Value *Ptr = getPointerOperand(Val); +    assert(Ptr && "Expected value to have a pointer operand."); +    // If a pointer refers to a different underlying object, bail - the +    // pointers are by definition incomparable. 
+    Value *CurrObj = GetUnderlyingObject(Ptr, DL); +    if (CurrObj != Obj0) +      return false; + +    const SCEVConstant *Diff = +        dyn_cast<SCEVConstant>(SE.getMinusSCEV(SE.getSCEV(Ptr), Scev0)); +    // The pointers may not have a constant offset from each other, or SCEV +    // may just not be smart enough to figure out they do. Regardless, +    // there's nothing we can do. +    if (!Diff || static_cast<unsigned>(Diff->getAPInt().abs().getSExtValue()) > +                     (VL.size() - 1) * Size) +      return false; + +    OffValPairs.emplace_back(Diff->getAPInt().getSExtValue(), Val); +  } +  SmallVector<unsigned, 4> UseOrder(VL.size()); +  for (unsigned i = 0; i < VL.size(); i++) { +    UseOrder[i] = i; +  } + +  // Sort the memory accesses and keep the order of their uses in UseOrder. +  std::sort(UseOrder.begin(), UseOrder.end(), +            [&OffValPairs](unsigned Left, unsigned Right) { +            return OffValPairs[Left].first < OffValPairs[Right].first; +            }); + +  for (unsigned i = 0; i < VL.size(); i++) +    Sorted.emplace_back(OffValPairs[UseOrder[i]].second); + +  // Sort UseOrder to compute the Mask. +  if (Mask) { +    Mask->reserve(VL.size()); +    for (unsigned i = 0; i < VL.size(); i++) +      Mask->emplace_back(i); +    std::sort(Mask->begin(), Mask->end(), +              [&UseOrder](unsigned Left, unsigned Right) { +              return UseOrder[Left] < UseOrder[Right]; +              }); +  } + +  return true; +} + +  /// Returns true if the memory operations \p A and \p B are consecutive.  
bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL,                                 ScalarEvolution &SE, bool CheckType) { @@ -1471,10 +1611,11 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,        couldPreventStoreLoadForward(Distance, TypeByteSize))      return Dependence::BackwardVectorizableButPreventsForwarding; +  uint64_t MaxVF = MaxSafeDepDistBytes / (TypeByteSize * Stride);    DEBUG(dbgs() << "LAA: Positive distance " << Val.getSExtValue() -               << " with max VF = " -               << MaxSafeDepDistBytes / (TypeByteSize * Stride) << '\n'); - +               << " with max VF = " << MaxVF << '\n'); +  uint64_t MaxVFInBits = MaxVF * TypeByteSize * 8; +  MaxSafeRegisterWidth = std::min(MaxSafeRegisterWidth, MaxVFInBits);    return Dependence::BackwardVectorizable;  } @@ -2066,8 +2207,51 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {    if (!Stride)      return; -  DEBUG(dbgs() << "LAA: Found a strided access that we can version"); +  DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for " +                  "versioning:");    DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); + +  // Avoid adding the "Stride == 1" predicate when we know that  +  // Stride >= Trip-Count. Such a predicate will effectively optimize a single +  // or zero iteration loop, as Trip-Count <= Stride == 1. +  //  +  // TODO: We are currently not making a very informed decision on when it is +  // beneficial to apply stride versioning. 
It might make more sense that the +// users of this analysis (such as the vectorizer) will trigger it, based on  +// their specific cost considerations; For example, in cases where stride  +// versioning does not help resolving memory accesses/dependences, the +// vectorizer should evaluate the cost of the runtime test, and the benefit  +// of various possible stride specializations, considering the alternatives  +// of using gather/scatters (if available).  +   +  const SCEV *StrideExpr = PSE->getSCEV(Stride); +  const SCEV *BETakenCount = PSE->getBackedgeTakenCount();   + +  // Match the types so we can compare the stride and the BETakenCount. +  // The Stride can be positive/negative, so we sign extend Stride;  +  // The backedgeTakenCount is non-negative, so we zero extend BETakenCount. +  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); +  uint64_t StrideTypeSize = DL.getTypeAllocSize(StrideExpr->getType()); +  uint64_t BETypeSize = DL.getTypeAllocSize(BETakenCount->getType()); +  const SCEV *CastedStride = StrideExpr; +  const SCEV *CastedBECount = BETakenCount; +  ScalarEvolution *SE = PSE->getSE(); +  if (BETypeSize >= StrideTypeSize) +    CastedStride = SE->getNoopOrSignExtend(StrideExpr, BETakenCount->getType()); +  else +    CastedBECount = SE->getZeroExtendExpr(BETakenCount, StrideExpr->getType()); +  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount); +  // Since TripCount == BackEdgeTakenCount + 1, checking: +  // "Stride >= TripCount" is equivalent to checking:  +  // Stride - BETakenCount > 0 +  if (SE->isKnownPositive(StrideMinusBETaken)) { +    DEBUG(dbgs() << "LAA: Stride>=TripCount; No point in versioning as the " +                    "Stride==1 predicate will imply that the loop executes " +                    "at most once.\n"); +    return; +  }   +  DEBUG(dbgs() << "LAA: Found a strided access that we can version."); +    SymbolicStrides[Ptr] = Stride;    
StrideSet.insert(Stride);  }  | 
