diff options
Diffstat (limited to 'lib/Analysis/LoopInfo.cpp')
| -rw-r--r-- | lib/Analysis/LoopInfo.cpp | 135 | 
1 files changed, 94 insertions, 41 deletions
diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp index 3f78456b3586c..ef2b1257015ce 100644 --- a/lib/Analysis/LoopInfo.cpp +++ b/lib/Analysis/LoopInfo.cpp @@ -26,6 +26,7 @@  #include "llvm/IR/Constants.h"  #include "llvm/IR/DebugLoc.h"  #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRPrintingPasses.h"  #include "llvm/IR/Instructions.h"  #include "llvm/IR/LLVMContext.h"  #include "llvm/IR/Metadata.h" @@ -213,33 +214,21 @@ bool Loop::isSafeToClone() const {  MDNode *Loop::getLoopID() const {    MDNode *LoopID = nullptr; -  if (BasicBlock *Latch = getLoopLatch()) { -    LoopID = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop); -  } else { -    assert(!getLoopLatch() && -           "The loop should have no single latch at this point"); -    // Go through each predecessor of the loop header and check the -    // terminator for the metadata. -    BasicBlock *H = getHeader(); -    for (BasicBlock *BB : this->blocks()) { -      TerminatorInst *TI = BB->getTerminator(); -      MDNode *MD = nullptr; - -      // Check if this terminator branches to the loop header. -      for (BasicBlock *Successor : TI->successors()) { -        if (Successor == H) { -          MD = TI->getMetadata(LLVMContext::MD_loop); -          break; -        } -      } -      if (!MD) -        return nullptr; -      if (!LoopID) -        LoopID = MD; -      else if (MD != LoopID) -        return nullptr; -    } +  // Go through the latch blocks and check the terminator for the metadata. +  SmallVector<BasicBlock *, 4> LatchesBlocks; +  getLoopLatches(LatchesBlocks); +  for (BasicBlock *BB : LatchesBlocks) { +    Instruction *TI = BB->getTerminator(); +    MDNode *MD = TI->getMetadata(LLVMContext::MD_loop); + +    if (!MD) +      return nullptr; + +    if (!LoopID) +      LoopID = MD; +    else if (MD != LoopID) +      return nullptr;    }    if (!LoopID || LoopID->getNumOperands() == 0 ||        LoopID->getOperand(0) != LoopID) @@ -248,23 +237,19 @@ MDNode *Loop::getLoopID() const {  }  void Loop::setLoopID(MDNode *LoopID) const { -  assert(LoopID && "Loop ID should not be null"); -  assert(LoopID->getNumOperands() > 0 && "Loop ID needs at least one operand"); -  assert(LoopID->getOperand(0) == LoopID && "Loop ID should refer to itself"); +  assert((!LoopID || LoopID->getNumOperands() > 0) && +         "Loop ID needs at least one operand"); +  assert((!LoopID || LoopID->getOperand(0) == LoopID) && +         "Loop ID should refer to itself"); -  if (BasicBlock *Latch = getLoopLatch()) { -    Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); -    return; -  } - -  assert(!getLoopLatch() && -         "The loop should have no single latch at this point");    BasicBlock *H = getHeader();    for (BasicBlock *BB : this->blocks()) { -    TerminatorInst *TI = BB->getTerminator(); -    for (BasicBlock *Successor : TI->successors()) { -      if (Successor == H) +    Instruction *TI = BB->getTerminator(); +    for (BasicBlock *Successor : successors(TI)) { +      if (Successor == H) {          TI->setMetadata(LLVMContext::MD_loop, LoopID); +        break; +      }      }    }  } @@ -308,16 +293,50 @@ bool Loop::isAnnotatedParallel() const {    if (!DesiredLoopIdMetadata)      return false; +  MDNode *ParallelAccesses = +      findOptionMDForLoop(this, "llvm.loop.parallel_accesses"); +  SmallPtrSet<MDNode *, 4> +      ParallelAccessGroups; // For scalable 'contains' check. +  if (ParallelAccesses) { +    for (const MDOperand &MD : drop_begin(ParallelAccesses->operands(), 1)) { +      MDNode *AccGroup = cast<MDNode>(MD.get()); +      assert(isValidAsAccessGroup(AccGroup) && +             "List item must be an access group"); +      ParallelAccessGroups.insert(AccGroup); +    } +  } +    // The loop branch contains the parallel loop metadata. In order to ensure    // that any parallel-loop-unaware optimization pass hasn't added loop-carried    // dependencies (thus converted the loop back to a sequential loop), check -  // that all the memory instructions in the loop contain parallelism metadata -  // that point to the same unique "loop id metadata" the loop branch does. +  // that all the memory instructions in the loop belong to an access group that +  // is parallel to this loop.    for (BasicBlock *BB : this->blocks()) {      for (Instruction &I : *BB) {        if (!I.mayReadOrWriteMemory())          continue; +      if (MDNode *AccessGroup = I.getMetadata(LLVMContext::MD_access_group)) { +        auto ContainsAccessGroup = [&ParallelAccessGroups](MDNode *AG) -> bool { +          if (AG->getNumOperands() == 0) { +            assert(isValidAsAccessGroup(AG) && "Item must be an access group"); +            return ParallelAccessGroups.count(AG); +          } + +          for (const MDOperand &AccessListItem : AG->operands()) { +            MDNode *AccGroup = cast<MDNode>(AccessListItem.get()); +            assert(isValidAsAccessGroup(AccGroup) && +                   "List item must be an access group"); +            if (ParallelAccessGroups.count(AccGroup)) +              return true; +          } +          return false; +        }; + +        if (ContainsAccessGroup(AccessGroup)) +          continue; +      } +        // The memory instruction can refer to the loop identifier metadata        // directly or indirectly through another list metadata (in case of        // nested parallel loops). The loop identifier metadata refers to @@ -708,6 +727,40 @@ void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) {    }  } +MDNode *llvm::findOptionMDForLoopID(MDNode *LoopID, StringRef Name) { +  // No loop metadata node, no loop properties. +  if (!LoopID) +    return nullptr; + +  // First operand should refer to the metadata node itself, for legacy reasons. +  assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); +  assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + +  // Iterate over the metdata node operands and look for MDString metadata. +  for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { +    MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); +    if (!MD || MD->getNumOperands() < 1) +      continue; +    MDString *S = dyn_cast<MDString>(MD->getOperand(0)); +    if (!S) +      continue; +    // Return the operand node if MDString holds expected metadata. +    if (Name.equals(S->getString())) +      return MD; +  } + +  // Loop property not found. +  return nullptr; +} + +MDNode *llvm::findOptionMDForLoop(const Loop *TheLoop, StringRef Name) { +  return findOptionMDForLoopID(TheLoop->getLoopID(), Name); +} + +bool llvm::isValidAsAccessGroup(MDNode *Node) { +  return Node->getNumOperands() == 0 && Node->isDistinct(); +} +  //===----------------------------------------------------------------------===//  // LoopInfo implementation  //  | 
