diff options
Diffstat (limited to 'lib/Transforms/Vectorize/SLPVectorizer.cpp')
| -rw-r--r-- | lib/Transforms/Vectorize/SLPVectorizer.cpp | 265 | 
1 files changed, 87 insertions, 178 deletions
diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 1c7cbc7edf9a3..328f270029604 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -4026,40 +4026,36 @@ bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) {    if (!V)      return false; -  Value *P = V->getParent(); - -  // Vectorize in current basic block only. -  auto *Op0 = dyn_cast<Instruction>(V->getOperand(0)); -  auto *Op1 = dyn_cast<Instruction>(V->getOperand(1)); -  if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P) -    return false; -    // Try to vectorize V. -  if (tryToVectorizePair(Op0, Op1, R)) +  if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))      return true; -  auto *A = dyn_cast<BinaryOperator>(Op0); -  auto *B = dyn_cast<BinaryOperator>(Op1); +  BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0)); +  BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));    // Try to skip B.    if (B && B->hasOneUse()) { -    auto *B0 = dyn_cast<BinaryOperator>(B->getOperand(0)); -    auto *B1 = dyn_cast<BinaryOperator>(B->getOperand(1)); -    if (B0 && B0->getParent() == P && tryToVectorizePair(A, B0, R)) +    BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0)); +    BinaryOperator *B1 = dyn_cast<BinaryOperator>(B->getOperand(1)); +    if (tryToVectorizePair(A, B0, R)) {        return true; -    if (B1 && B1->getParent() == P && tryToVectorizePair(A, B1, R)) +    } +    if (tryToVectorizePair(A, B1, R)) {        return true; +    }    }    // Try to skip A.    if (A && A->hasOneUse()) { -    auto *A0 = dyn_cast<BinaryOperator>(A->getOperand(0)); -    auto *A1 = dyn_cast<BinaryOperator>(A->getOperand(1)); -    if (A0 && A0->getParent() == P && tryToVectorizePair(A0, B, R)) +    BinaryOperator *A0 = dyn_cast<BinaryOperator>(A->getOperand(0)); +    BinaryOperator *A1 = dyn_cast<BinaryOperator>(A->getOperand(1)); +    if (tryToVectorizePair(A0, B, R)) {        return true; -    if (A1 && A1->getParent() == P && tryToVectorizePair(A1, B, R)) +    } +    if (tryToVectorizePair(A1, B, R)) {        return true; +    }    } -  return false; +  return 0;  }  /// \brief Generate a shuffle mask to be used in a reduction tree. @@ -4511,143 +4507,29 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P,    return nullptr;  } -namespace { -/// Tracks instructons and its children. -class WeakVHWithLevel final : public CallbackVH { -  /// Operand index of the instruction currently beeing analized. -  unsigned Level = 0; -  /// Is this the instruction that should be vectorized, or are we now -  /// processing children (i.e. operands of this instruction) for potential -  /// vectorization? -  bool IsInitial = true; - -public: -  explicit WeakVHWithLevel() = default; -  WeakVHWithLevel(Value *V) : CallbackVH(V){}; -  /// Restart children analysis each time it is repaced by the new instruction. -  void allUsesReplacedWith(Value *New) override { -    setValPtr(New); -    Level = 0; -    IsInitial = true; -  } -  /// Check if the instruction was not deleted during vectorization. -  bool isValid() const { return !getValPtr(); } -  /// Is the istruction itself must be vectorized? -  bool isInitial() const { return IsInitial; } -  /// Try to vectorize children. -  void clearInitial() { IsInitial = false; } -  /// Are all children processed already? -  bool isFinal() const { -    assert(getValPtr() && -           (isa<Instruction>(getValPtr()) && -            cast<Instruction>(getValPtr())->getNumOperands() >= Level)); -    return getValPtr() && -           cast<Instruction>(getValPtr())->getNumOperands() == Level; -  } -  /// Get next child operation. -  Value *nextOperand() { -    assert(getValPtr() && isa<Instruction>(getValPtr()) && -           cast<Instruction>(getValPtr())->getNumOperands() > Level); -    return cast<Instruction>(getValPtr())->getOperand(Level++); -  } -  virtual ~WeakVHWithLevel() = default; -}; -} // namespace -  /// \brief Attempt to reduce a horizontal reduction.  /// If it is legal to match a horizontal reduction feeding -/// the phi node P with reduction operators Root in a basic block BB, then check -/// if it can be done. +/// the phi node P with reduction operators BI, then check if it +/// can be done.  /// \returns true if a horizontal reduction was matched and reduced.  /// \returns false if a horizontal reduction was not matched. -static bool canBeVectorized( -    PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, -    TargetTransformInfo *TTI, -    const function_ref<bool(BinaryOperator *, BoUpSLP &)> Vectorize) { +static bool canMatchHorizontalReduction(PHINode *P, BinaryOperator *BI, +                                        BoUpSLP &R, TargetTransformInfo *TTI, +                                        unsigned MinRegSize) {    if (!ShouldVectorizeHor)      return false; -  if (!Root) -    return false; - -  if (Root->getParent() != BB) +  HorizontalReduction HorRdx(MinRegSize); +  if (!HorRdx.matchAssociativeReduction(P, BI))      return false; -  SmallVector<WeakVHWithLevel, 8> Stack(1, Root); -  SmallSet<Value *, 8> VisitedInstrs; -  bool Res = false; -  while (!Stack.empty()) { -    Value *V = Stack.back(); -    if (!V) { -      Stack.pop_back(); -      continue; -    } -    auto *Inst = dyn_cast<Instruction>(V); -    if (!Inst || isa<PHINode>(Inst)) { -      Stack.pop_back(); -      continue; -    } -    if (Stack.back().isInitial()) { -      Stack.back().clearInitial(); -      if (auto *BI = dyn_cast<BinaryOperator>(Inst)) { -        HorizontalReduction HorRdx(R.getMinVecRegSize()); -        if (HorRdx.matchAssociativeReduction(P, BI)) { -          // If there is a sufficient number of reduction values, reduce -          // to a nearby power-of-2. Can safely generate oversized -          // vectors and rely on the backend to split them to legal sizes. -          HorRdx.ReduxWidth = -              std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues())); - -          if (HorRdx.tryToReduce(R, TTI)) { -            Res = true; -            P = nullptr; -            continue; -          } -        } -        if (P) { -          Inst = dyn_cast<Instruction>(BI->getOperand(0)); -          if (Inst == P) -            Inst = dyn_cast<Instruction>(BI->getOperand(1)); -          if (!Inst) { -            P = nullptr; -            continue; -          } -        } -      } -      P = nullptr; -      if (Vectorize(dyn_cast<BinaryOperator>(Inst), R)) { -        Res = true; -        continue; -      } -    } -    if (Stack.back().isFinal()) { -      Stack.pop_back(); -      continue; -    } -    if (auto *NextV = dyn_cast<Instruction>(Stack.back().nextOperand())) -      if (NextV->getParent() == BB && VisitedInstrs.insert(NextV).second && -          Stack.size() < RecursionMaxDepth) -        Stack.push_back(NextV); -  } -  return Res; -} - -bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V, -                                                 BasicBlock *BB, BoUpSLP &R, -                                                 TargetTransformInfo *TTI) { -  if (!V) -    return false; -  auto *I = dyn_cast<Instruction>(V); -  if (!I) -    return false; +  // If there is a sufficient number of reduction values, reduce +  // to a nearby power-of-2. Can safely generate oversized +  // vectors and rely on the backend to split them to legal sizes. +  HorRdx.ReduxWidth = +    std::max((uint64_t)4, PowerOf2Floor(HorRdx.numReductionValues())); -  if (!isa<BinaryOperator>(I)) -    P = nullptr; -  // Try to match and vectorize a horizontal reduction. -  return canBeVectorized(P, I, BB, R, TTI, -                         [this](BinaryOperator *BI, BoUpSLP &R) -> bool { -                           return tryToVectorize(BI, R); -                         }); +  return HorRdx.tryToReduce(R, TTI);  }  bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { @@ -4717,42 +4599,67 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {        if (P->getNumIncomingValues() != 2)          return Changed; +      Value *Rdx = getReductionValue(DT, P, BB, LI); + +      // Check if this is a Binary Operator. +      BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx); +      if (!BI) +        continue; +        // Try to match and vectorize a horizontal reduction. -      if (vectorizeRootInstruction(P, getReductionValue(DT, P, BB, LI), BB, R, -                                   TTI)) { +      if (canMatchHorizontalReduction(P, BI, R, TTI, R.getMinVecRegSize())) {          Changed = true;          it = BB->begin();          e = BB->end();          continue;        } + +     Value *Inst = BI->getOperand(0); +      if (Inst == P) +        Inst = BI->getOperand(1); + +      if (tryToVectorize(dyn_cast<BinaryOperator>(Inst), R)) { +        // We would like to start over since some instructions are deleted +        // and the iterator may become invalid value. +        Changed = true; +        it = BB->begin(); +        e = BB->end(); +        continue; +      } +        continue;      } -    if (ShouldStartVectorizeHorAtStore) { -      if (StoreInst *SI = dyn_cast<StoreInst>(it)) { -        // Try to match and vectorize a horizontal reduction. -        if (vectorizeRootInstruction(nullptr, SI->getValueOperand(), BB, R, -                                     TTI)) { -          Changed = true; -          it = BB->begin(); -          e = BB->end(); -          continue; +    if (ShouldStartVectorizeHorAtStore) +      if (StoreInst *SI = dyn_cast<StoreInst>(it)) +        if (BinaryOperator *BinOp = +                dyn_cast<BinaryOperator>(SI->getValueOperand())) { +          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI, +                                          R.getMinVecRegSize()) || +              tryToVectorize(BinOp, R)) { +            Changed = true; +            it = BB->begin(); +            e = BB->end(); +            continue; +          }          } -      } -    }      // Try to vectorize horizontal reductions feeding into a return. -    if (ReturnInst *RI = dyn_cast<ReturnInst>(it)) { -      if (RI->getNumOperands() != 0) { -        // Try to match and vectorize a horizontal reduction. -        if (vectorizeRootInstruction(nullptr, RI->getOperand(0), BB, R, TTI)) { -          Changed = true; -          it = BB->begin(); -          e = BB->end(); -          continue; +    if (ReturnInst *RI = dyn_cast<ReturnInst>(it)) +      if (RI->getNumOperands() != 0) +        if (BinaryOperator *BinOp = +                dyn_cast<BinaryOperator>(RI->getOperand(0))) { +          DEBUG(dbgs() << "SLP: Found a return to vectorize.\n"); +          if (canMatchHorizontalReduction(nullptr, BinOp, R, TTI, +                                          R.getMinVecRegSize()) || +              tryToVectorizePair(BinOp->getOperand(0), BinOp->getOperand(1), +                                 R)) { +            Changed = true; +            it = BB->begin(); +            e = BB->end(); +            continue; +          }          } -      } -    }      // Try to vectorize trees that start at compare instructions.      if (CmpInst *CI = dyn_cast<CmpInst>(it)) { @@ -4765,14 +4672,16 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) {          continue;        } -      for (int I = 0; I < 2; ++I) { -        if (vectorizeRootInstruction(nullptr, CI->getOperand(I), BB, R, TTI)) { -          Changed = true; -          // We would like to start over since some instructions are deleted -          // and the iterator may become invalid value. -          it = BB->begin(); -          e = BB->end(); -          break; +      for (int i = 0; i < 2; ++i) { +        if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i))) { +          if (tryToVectorizePair(BI->getOperand(0), BI->getOperand(1), R)) { +            Changed = true; +            // We would like to start over since some instructions are deleted +            // and the iterator may become invalid value. +            it = BB->begin(); +            e = BB->end(); +            break; +          }          }        }        continue;  | 
