summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/LoopPredication.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r--lib/Transforms/Scalar/LoopPredication.cpp211
1 files changed, 145 insertions, 66 deletions
diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp
index 2e4c7b19e476..561ceea1d880 100644
--- a/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/lib/Transforms/Scalar/LoopPredication.cpp
@@ -155,7 +155,7 @@
// When S = -1 (i.e. reverse iterating loop), the transformation is supported
// when:
// * The loop has a single latch with the condition of the form:
-// B(X) = X <pred> latchLimit, where <pred> is u> or s>.
+// B(X) = X <pred> latchLimit, where <pred> is u>, u>=, s>, or s>=.
// * The guard condition is of the form
// G(X) = X - 1 u< guardLimit
//
@@ -171,9 +171,14 @@
// guardStart u< guardLimit && latchLimit u>= 1.
// Similarly for sgt condition the widened condition is:
// guardStart u< guardLimit && latchLimit s>= 1.
+// For uge condition the widened condition is:
+// guardStart u< guardLimit && latchLimit u> 1.
+// For sge condition the widened condition is:
+// guardStart u< guardLimit && latchLimit s> 1.
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Scalar/LoopPredication.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/ScalarEvolution.h"
@@ -198,6 +203,20 @@ static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation",
static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop",
cl::Hidden, cl::init(true));
+
+static cl::opt<bool>
+ SkipProfitabilityChecks("loop-predication-skip-profitability-checks",
+ cl::Hidden, cl::init(false));
+
+// This is the scale factor for the latch probability. We use this during
+// profitability analysis to find other exiting blocks that have a much higher
+// probability of exiting the loop instead of loop exiting via latch.
+// This value should be greater than 1 for a sane profitability check.
+static cl::opt<float> LatchExitProbabilityScale(
+ "loop-predication-latch-probability-scale", cl::Hidden, cl::init(2.0),
+ cl::desc("scale factor for the latch probability. Value should be greater "
+ "than 1. Lower values are ignored"));
+
namespace {
class LoopPredication {
/// Represents an induction variable check:
@@ -217,6 +236,7 @@ class LoopPredication {
};
ScalarEvolution *SE;
+ BranchProbabilityInfo *BPI;
Loop *L;
const DataLayout *DL;
@@ -250,6 +270,12 @@ class LoopPredication {
IRBuilder<> &Builder);
bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
+ // If the loop always exits through another block in the loop, we should not
+ // predicate based on the latch check. For example, the latch check can be a
+ // very coarse grained check and there can be more fine grained exit checks
+ // within the loop. We identify such unprofitable loops through BPI.
+ bool isLoopProfitableToPredicate();
+
// When the IV type is wider than the range operand type, we can still do loop
// predication, by generating SCEVs for the range and latch that are of the
// same type. We achieve this by generating a SCEV truncate expression for the
@@ -266,8 +292,10 @@ class LoopPredication {
// Return the loopLatchCheck corresponding to the RangeCheckType if safe to do
// so.
Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType);
+
public:
- LoopPredication(ScalarEvolution *SE) : SE(SE){};
+ LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI)
+ : SE(SE), BPI(BPI){};
bool runOnLoop(Loop *L);
};
@@ -279,6 +307,7 @@ public:
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<BranchProbabilityInfoWrapperPass>();
getLoopAnalysisUsage(AU);
}
@@ -286,7 +315,9 @@ public:
if (skipLoop(L))
return false;
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
- LoopPredication LP(SE);
+ BranchProbabilityInfo &BPI =
+ getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
+ LoopPredication LP(SE, &BPI);
return LP.runOnLoop(L);
}
};
@@ -296,6 +327,7 @@ char LoopPredicationLegacyPass::ID = 0;
INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
"Loop predication", false, false)
+INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopPass)
INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
"Loop predication", false, false)
@@ -307,7 +339,11 @@ Pass *llvm::createLoopPredicationPass() {
PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
- LoopPredication LP(&AR.SE);
+ const auto &FAM =
+ AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
+ Function *F = L.getHeader()->getParent();
+ auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F);
+ LoopPredication LP(&AR.SE, BPI);
if (!LP.runOnLoop(&L))
return PreservedAnalyses::all();
@@ -375,11 +411,11 @@ LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) {
if (!NewLatchCheck.IV)
return None;
NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType);
- DEBUG(dbgs() << "IV of type: " << *LatchType
- << "can be represented as range check type:" << *RangeCheckType
- << "\n");
- DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
- DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
+ LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
+ << "can be represented as range check type:"
+ << *RangeCheckType << "\n");
+ LLVM_DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
+ LLVM_DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
return NewLatchCheck;
}
@@ -412,30 +448,15 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
!CanExpand(LatchLimit) || !CanExpand(RHS)) {
- DEBUG(dbgs() << "Can't expand limit check!\n");
+ LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
return None;
}
- ICmpInst::Predicate LimitCheckPred;
- switch (LatchCheck.Pred) {
- case ICmpInst::ICMP_ULT:
- LimitCheckPred = ICmpInst::ICMP_ULE;
- break;
- case ICmpInst::ICMP_ULE:
- LimitCheckPred = ICmpInst::ICMP_ULT;
- break;
- case ICmpInst::ICMP_SLT:
- LimitCheckPred = ICmpInst::ICMP_SLE;
- break;
- case ICmpInst::ICMP_SLE:
- LimitCheckPred = ICmpInst::ICMP_SLT;
- break;
- default:
- llvm_unreachable("Unsupported loop latch!");
- }
+ auto LimitCheckPred =
+ ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
- DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
- DEBUG(dbgs() << "RHS: " << *RHS << "\n");
- DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
+ LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
+ LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n");
+ LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
Instruction *InsertAt = Preheader->getTerminator();
auto *LimitCheck =
@@ -454,16 +475,16 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
const SCEV *LatchLimit = LatchCheck.Limit;
if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
!CanExpand(LatchLimit)) {
- DEBUG(dbgs() << "Can't expand limit check!\n");
+ LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
return None;
}
// The decrement of the latch check IV should be the same as the
// rangeCheckIV.
auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
if (RangeCheck.IV != PostDecLatchCheckIV) {
- DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
- << *PostDecLatchCheckIV
- << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
+ LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
+ << *PostDecLatchCheckIV
+ << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
return None;
}
@@ -472,9 +493,8 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
// latchLimit <pred> 1.
// See the header comment for reasoning of the checks.
Instruction *InsertAt = Preheader->getTerminator();
- auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred)
- ? ICmpInst::ICMP_SGE
- : ICmpInst::ICMP_UGE;
+ auto LimitCheckPred =
+ ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
GuardStart, GuardLimit, InsertAt);
auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
@@ -488,8 +508,8 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
SCEVExpander &Expander,
IRBuilder<> &Builder) {
- DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
- DEBUG(ICI->dump());
+ LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
+ LLVM_DEBUG(ICI->dump());
// parseLoopStructure guarantees that the latch condition is:
// ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
@@ -497,34 +517,34 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
// i u< guardLimit
auto RangeCheck = parseLoopICmp(ICI);
if (!RangeCheck) {
- DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
+ LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
return None;
}
- DEBUG(dbgs() << "Guard check:\n");
- DEBUG(RangeCheck->dump());
+ LLVM_DEBUG(dbgs() << "Guard check:\n");
+ LLVM_DEBUG(RangeCheck->dump());
if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
- DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred
- << ")!\n");
+ LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
+ << RangeCheck->Pred << ")!\n");
return None;
}
auto *RangeCheckIV = RangeCheck->IV;
if (!RangeCheckIV->isAffine()) {
- DEBUG(dbgs() << "Range check IV is not affine!\n");
+ LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
return None;
}
auto *Step = RangeCheckIV->getStepRecurrence(*SE);
// We cannot just compare with latch IV step because the latch and range IVs
// may have different types.
if (!isSupportedStep(Step)) {
- DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
+ LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
return None;
}
auto *Ty = RangeCheckIV->getType();
auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty);
if (!CurrLatchCheckOpt) {
- DEBUG(dbgs() << "Failed to generate a loop latch check "
- "corresponding to range type: "
- << *Ty << "\n");
+ LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
+ "corresponding to range type: "
+ << *Ty << "\n");
return None;
}
@@ -535,7 +555,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
"Range and latch steps should be of same type!");
if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
- DEBUG(dbgs() << "Range and latch have different step values!\n");
+ LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
return None;
}
@@ -551,14 +571,14 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
SCEVExpander &Expander) {
- DEBUG(dbgs() << "Processing guard:\n");
- DEBUG(Guard->dump());
+ LLVM_DEBUG(dbgs() << "Processing guard:\n");
+ LLVM_DEBUG(Guard->dump());
IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
// The guard condition is expected to be in form of:
// cond1 && cond2 && cond3 ...
- // Iterate over subconditions looking for for icmp conditions which can be
+ // Iterate over subconditions looking for icmp conditions which can be
// widened across loop iterations. Widening these conditions remember the
// resulting list of subconditions in Checks vector.
SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
@@ -605,7 +625,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
LastCheck = Builder.CreateAnd(LastCheck, Check);
Guard->setOperand(0, LastCheck);
- DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
+ LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
return true;
}
@@ -614,7 +634,7 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
BasicBlock *LoopLatch = L->getLoopLatch();
if (!LoopLatch) {
- DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
+ LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
return None;
}
@@ -625,7 +645,7 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
if (!match(LoopLatch->getTerminator(),
m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest,
FalseDest))) {
- DEBUG(dbgs() << "Failed to match the latch terminator!\n");
+ LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
return None;
}
assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) &&
@@ -635,20 +655,20 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
auto Result = parseLoopICmp(Pred, LHS, RHS);
if (!Result) {
- DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
+ LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
return None;
}
// Check affine first, so if it's not we don't try to compute the step
// recurrence.
if (!Result->IV->isAffine()) {
- DEBUG(dbgs() << "The induction variable is not affine!\n");
+ LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
return None;
}
auto *Step = Result->IV->getStepRecurrence(*SE);
if (!isSupportedStep(Step)) {
- DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
+ LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
return None;
}
@@ -658,13 +678,14 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
} else {
assert(Step->isAllOnesValue() && "Step should be -1!");
- return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT;
+ return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT &&
+ Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_SGE;
}
};
if (IsUnsupportedPredicate(Step, Result->Pred)) {
- DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
- << ")!\n");
+ LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
+ << ")!\n");
return None;
}
return Result;
@@ -700,11 +721,65 @@ bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) {
Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
}
+bool LoopPredication::isLoopProfitableToPredicate() {
+ if (SkipProfitabilityChecks || !BPI)
+ return true;
+
+ SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges;
+ L->getExitEdges(ExitEdges);
+ // If there is only one exiting edge in the loop, it is always profitable to
+ // predicate the loop.
+ if (ExitEdges.size() == 1)
+ return true;
+
+ // Calculate the exiting probabilities of all exiting edges from the loop,
+ // starting with the LatchExitProbability.
+ // Heuristic for profitability: If any of the exiting blocks' probability of
+ // exiting the loop is larger than exiting through the latch block, it's not
+ // profitable to predicate the loop.
+ auto *LatchBlock = L->getLoopLatch();
+ assert(LatchBlock && "Should have a single latch at this point!");
+ auto *LatchTerm = LatchBlock->getTerminator();
+ assert(LatchTerm->getNumSuccessors() == 2 &&
+ "expected to be an exiting block with 2 succs!");
+ unsigned LatchBrExitIdx =
+ LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
+ BranchProbability LatchExitProbability =
+ BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
+
+ // Protect against degenerate inputs provided by the user. Providing a value
+ // less than one, can invert the definition of profitable loop predication.
+ float ScaleFactor = LatchExitProbabilityScale;
+ if (ScaleFactor < 1) {
+ LLVM_DEBUG(
+ dbgs()
+ << "Ignored user setting for loop-predication-latch-probability-scale: "
+ << LatchExitProbabilityScale << "\n");
+ LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
+ ScaleFactor = 1.0;
+ }
+ const auto LatchProbabilityThreshold =
+ LatchExitProbability * ScaleFactor;
+
+ for (const auto &ExitEdge : ExitEdges) {
+ BranchProbability ExitingBlockProbability =
+ BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
+ // Some exiting edge has higher probability than the latch exiting edge.
+ // No longer profitable to predicate.
+ if (ExitingBlockProbability > LatchProbabilityThreshold)
+ return false;
+ }
+ // Using BPI, we have concluded that the most probable way to exit from the
+ // loop is through the latch (or there's no profile information and all
+ // exits are equally likely).
+ return true;
+}
+
bool LoopPredication::runOnLoop(Loop *Loop) {
L = Loop;
- DEBUG(dbgs() << "Analyzing ");
- DEBUG(L->dump());
+ LLVM_DEBUG(dbgs() << "Analyzing ");
+ LLVM_DEBUG(L->dump());
Module *M = L->getHeader()->getModule();
@@ -725,9 +800,13 @@ bool LoopPredication::runOnLoop(Loop *Loop) {
return false;
LatchCheck = *LatchCheckOpt;
- DEBUG(dbgs() << "Latch check:\n");
- DEBUG(LatchCheck.dump());
+ LLVM_DEBUG(dbgs() << "Latch check:\n");
+ LLVM_DEBUG(LatchCheck.dump());
+ if (!isLoopProfitableToPredicate()) {
+ LLVM_DEBUG(dbgs() << "Loop not profitable to predicate!\n");
+ return false;
+ }
// Collect all the guards into a vector and process later, so as not
// to invalidate the instruction iterator.
SmallVector<IntrinsicInst *, 4> Guards;