diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2018-07-28 10:51:19 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2018-07-28 10:51:19 +0000 |
commit | eb11fae6d08f479c0799db45860a98af528fa6e7 (patch) | |
tree | 44d492a50c8c1a7eb8e2d17ea3360ec4d066f042 /lib/Transforms/Scalar/LoopPredication.cpp | |
parent | b8a2042aa938069e862750553db0e4d82d25822c (diff) |
Notes
Diffstat (limited to 'lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r-- | lib/Transforms/Scalar/LoopPredication.cpp | 211 |
1 files changed, 145 insertions, 66 deletions
diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index 2e4c7b19e476..561ceea1d880 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -155,7 +155,7 @@ // When S = -1 (i.e. reverse iterating loop), the transformation is supported // when: // * The loop has a single latch with the condition of the form: -// B(X) = X <pred> latchLimit, where <pred> is u> or s>. +// B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=. // * The guard condition is of the form // G(X) = X - 1 u< guardLimit // @@ -171,9 +171,14 @@ // guardStart u< guardLimit && latchLimit u>= 1. // Similarly for sgt condition the widened condition is: // guardStart u< guardLimit && latchLimit s>= 1. +// For uge condition the widened condition is: +// guardStart u< guardLimit && latchLimit u> 1. +// For sge condition the widened condition is: +// guardStart u< guardLimit && latchLimit s> 1. //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -198,6 +203,20 @@ static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation", static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop", cl::Hidden, cl::init(true)); + +static cl::opt<bool> + SkipProfitabilityChecks("loop-predication-skip-profitability-checks", + cl::Hidden, cl::init(false)); + +// This is the scale factor for the latch probability. We use this during +// profitability analysis to find other exiting blocks that have a much higher +// probability of exiting the loop instead of loop exiting via latch. +// This value should be greater than 1 for a sane profitability check. +static cl::opt<float> LatchExitProbabilityScale( + "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0), + cl::desc("scale factor for the latch probability. Value should be greater " + "than 1. Lower values are ignored")); + namespace { class LoopPredication { /// Represents an induction variable check: @@ -217,6 +236,7 @@ class LoopPredication { }; ScalarEvolution *SE; + BranchProbabilityInfo *BPI; Loop *L; const DataLayout *DL; @@ -250,6 +270,12 @@ class LoopPredication { IRBuilder<> &Builder); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + // If the loop always exits through another block in the loop, we should not + // predicate based on the latch check. For example, the latch check can be a + // very coarse grained check and there can be more fine grained exit checks + // within the loop. We identify such unprofitable loops through BPI. + bool isLoopProfitableToPredicate(); + // When the IV type is wider than the range operand type, we can still do loop // predication, by generating SCEVs for the range and latch that are of the // same type. We achieve this by generating a SCEV truncate expression for the @@ -266,8 +292,10 @@ class LoopPredication { // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do // so. Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); + public: - LoopPredication(ScalarEvolution *SE) : SE(SE){}; + LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI) + : SE(SE), BPI(BPI){}; bool runOnLoop(Loop *L); }; @@ -279,6 +307,7 @@ public: } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -286,7 +315,9 @@ public: if (skipLoop(L)) return false; auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LoopPredication LP(SE); + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + LoopPredication LP(SE, &BPI); return LP.runOnLoop(L); } }; @@ -296,6 +327,7 @@ char LoopPredicationLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication", "Loop predication", false, false) @@ -307,7 +339,11 @@ Pass *llvm::createLoopPredicationPass() { PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - LoopPredication LP(&AR.SE); + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); + LoopPredication LP(&AR.SE, BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); @@ -375,11 +411,11 @@ LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) { if (!NewLatchCheck.IV) return None; NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType); - DEBUG(dbgs() << "IV of type: " << *LatchType - << "can be represented as range check type:" << *RangeCheckType - << "\n"); - DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n"); - DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n"); + LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType + << "can be represented as range check type:" + << *RangeCheckType << "\n"); + LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n"); + LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n"); return NewLatchCheck; } @@ -412,30 +448,15 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( SE->getMinusSCEV(LatchStart, SE->getOne(Ty))); if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || !CanExpand(LatchLimit) || !CanExpand(RHS)) { - DEBUG(dbgs() << "Can't expand limit check!\n"); + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } - ICmpInst::Predicate LimitCheckPred; - switch (LatchCheck.Pred) { - case ICmpInst::ICMP_ULT: - LimitCheckPred = ICmpInst::ICMP_ULE; - break; - case ICmpInst::ICMP_ULE: - LimitCheckPred = ICmpInst::ICMP_ULT; - break; - case ICmpInst::ICMP_SLT: - LimitCheckPred = ICmpInst::ICMP_SLE; - break; - case ICmpInst::ICMP_SLE: - LimitCheckPred = ICmpInst::ICMP_SLT; - break; - default: - llvm_unreachable("Unsupported loop latch!"); - } + auto LimitCheckPred = + ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); - DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); - DEBUG(dbgs() << "RHS: " << *RHS << "\n"); - DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); + LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); + LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n"); + LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); Instruction *InsertAt = Preheader->getTerminator(); auto *LimitCheck = @@ -454,16 +475,16 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( const SCEV *LatchLimit = LatchCheck.Limit; if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || !CanExpand(LatchLimit)) { - DEBUG(dbgs() << "Can't expand limit check!\n"); + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } // The decrement of the latch check IV should be the same as the // rangeCheckIV. auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE); if (RangeCheck.IV != PostDecLatchCheckIV) { - DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " - << *PostDecLatchCheckIV - << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); + LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " + << *PostDecLatchCheckIV + << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); return None; } @@ -472,9 +493,8 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( // latchLimit <pred> 1. // See the header comment for reasoning of the checks. Instruction *InsertAt = Preheader->getTerminator(); - auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred) - ? ICmpInst::ICMP_SGE - : ICmpInst::ICMP_UGE; + auto LimitCheckPred = + ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT, GuardStart, GuardLimit, InsertAt); auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, @@ -488,8 +508,8 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, IRBuilder<> &Builder) { - DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); - DEBUG(ICI->dump()); + LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); + LLVM_DEBUG(ICI->dump()); // parseLoopStructure guarantees that the latch condition is: // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. @@ -497,34 +517,34 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, // i u< guardLimit auto RangeCheck = parseLoopICmp(ICI); if (!RangeCheck) { - DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } - DEBUG(dbgs() << "Guard check:\n"); - DEBUG(RangeCheck->dump()); + LLVM_DEBUG(dbgs() << "Guard check:\n"); + LLVM_DEBUG(RangeCheck->dump()); if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { - DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred - << ")!\n"); + LLVM_DEBUG(dbgs() << "Unsupported range check predicate(" + << RangeCheck->Pred << ")!\n"); return None; } auto *RangeCheckIV = RangeCheck->IV; if (!RangeCheckIV->isAffine()) { - DEBUG(dbgs() << "Range check IV is not affine!\n"); + LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n"); return None; } auto *Step = RangeCheckIV->getStepRecurrence(*SE); // We cannot just compare with latch IV step because the latch and range IVs // may have different types. if (!isSupportedStep(Step)) { - DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); + LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); return None; } auto *Ty = RangeCheckIV->getType(); auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty); if (!CurrLatchCheckOpt) { - DEBUG(dbgs() << "Failed to generate a loop latch check " - "corresponding to range type: " - << *Ty << "\n"); + LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " + "corresponding to range type: " + << *Ty << "\n"); return None; } @@ -535,7 +555,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() && "Range and latch steps should be of same type!"); if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) { - DEBUG(dbgs() << "Range and latch have different step values!\n"); + LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n"); return None; } @@ -551,14 +571,14 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, SCEVExpander &Expander) { - DEBUG(dbgs() << "Processing guard:\n"); - DEBUG(Guard->dump()); + LLVM_DEBUG(dbgs() << "Processing guard:\n"); + LLVM_DEBUG(Guard->dump()); IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); // The guard condition is expected to be in form of: // cond1 && cond2 && cond3 ... - // Iterate over subconditions looking for for icmp conditions which can be + // Iterate over subconditions looking for icmp conditions which can be // widened across loop iterations. Widening these conditions remember the // resulting list of subconditions in Checks vector. SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); @@ -605,7 +625,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, LastCheck = Builder.CreateAnd(LastCheck, Check); Guard->setOperand(0, LastCheck); - DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); return true; } @@ -614,7 +634,7 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { BasicBlock *LoopLatch = L->getLoopLatch(); if (!LoopLatch) { - DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); + LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); return None; } @@ -625,7 +645,7 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { if (!match(LoopLatch->getTerminator(), m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest, FalseDest))) { - DEBUG(dbgs() << "Failed to match the latch terminator!\n"); + LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n"); return None; } assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) && @@ -635,20 +655,20 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { auto Result = parseLoopICmp(Pred, LHS, RHS); if (!Result) { - DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } // Check affine first, so if it's not we don't try to compute the step // recurrence. if (!Result->IV->isAffine()) { - DEBUG(dbgs() << "The induction variable is not affine!\n"); + LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n"); return None; } auto *Step = Result->IV->getStepRecurrence(*SE); if (!isSupportedStep(Step)) { - DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); + LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); return None; } @@ -658,13 +678,14 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE; } else { assert(Step->isAllOnesValue() && "Step should be -1!"); - return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT; + return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT && + Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE; } }; if (IsUnsupportedPredicate(Step, Result->Pred)) { - DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred - << ")!\n"); + LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred + << ")!\n"); return None; } return Result; @@ -700,11 +721,65 @@ bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) { Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; } +bool LoopPredication::isLoopProfitableToPredicate() { + if (SkipProfitabilityChecks || !BPI) + return true; + + SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges; + L->getExitEdges(ExitEdges); + // If there is only one exiting edge in the loop, it is always profitable to + // predicate the loop. + if (ExitEdges.size() == 1) + return true; + + // Calculate the exiting probabilities of all exiting edges from the loop, + // starting with the LatchExitProbability. + // Heuristic for profitability: If any of the exiting blocks' probability of + // exiting the loop is larger than exiting through the latch block, it's not + // profitable to predicate the loop. + auto *LatchBlock = L->getLoopLatch(); + assert(LatchBlock && "Should have a single latch at this point!"); + auto *LatchTerm = LatchBlock->getTerminator(); + assert(LatchTerm->getNumSuccessors() == 2 && + "expected to be an exiting block with 2 succs!"); + unsigned LatchBrExitIdx = + LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0; + BranchProbability LatchExitProbability = + BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx); + + // Protect against degenerate inputs provided by the user. Providing a value + // less than one, can invert the definition of profitable loop predication. + float ScaleFactor = LatchExitProbabilityScale; + if (ScaleFactor < 1) { + LLVM_DEBUG( + dbgs() + << "Ignored user setting for loop-predication-latch-probability-scale: " + << LatchExitProbabilityScale << "\n"); + LLVM_DEBUG(dbgs() << "The value is set to 1.0\n"); + ScaleFactor = 1.0; + } + const auto LatchProbabilityThreshold = + LatchExitProbability * ScaleFactor; + + for (const auto &ExitEdge : ExitEdges) { + BranchProbability ExitingBlockProbability = + BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second); + // Some exiting edge has higher probability than the latch exiting edge. + // No longer profitable to predicate. + if (ExitingBlockProbability > LatchProbabilityThreshold) + return false; + } + // Using BPI, we have concluded that the most probable way to exit from the + // loop is through the latch (or there's no profile information and all + // exits are equally likely). + return true; +} + bool LoopPredication::runOnLoop(Loop *Loop) { L = Loop; - DEBUG(dbgs() << "Analyzing "); - DEBUG(L->dump()); + LLVM_DEBUG(dbgs() << "Analyzing "); + LLVM_DEBUG(L->dump()); Module *M = L->getHeader()->getModule(); @@ -725,9 +800,13 @@ bool LoopPredication::runOnLoop(Loop *Loop) { return false; LatchCheck = *LatchCheckOpt; - DEBUG(dbgs() << "Latch check:\n"); - DEBUG(LatchCheck.dump()); + LLVM_DEBUG(dbgs() << "Latch check:\n"); + LLVM_DEBUG(LatchCheck.dump()); + if (!isLoopProfitableToPredicate()) { + LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n"); + return false; + } // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; |