aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LoopPredication.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPredication.cpp182
1 files changed, 95 insertions, 87 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index b327d38d2a84..49c0fff84d81 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -191,6 +191,7 @@
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
@@ -200,6 +201,7 @@
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+#include <optional>
#define DEBUG_TYPE "loop-predication"
@@ -233,6 +235,13 @@ static cl::opt<bool> PredicateWidenableBranchGuards(
"expressed as widenable branches to deoptimize blocks"),
cl::init(true));
+static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions(
+ "loop-predication-insert-assumes-of-predicated-guards-conditions",
+ cl::Hidden,
+ cl::desc("Whether or not we should insert assumes of conditions of "
+ "predicated guards"),
+ cl::init(true));
+
namespace {
/// Represents an induction variable check:
/// icmp Pred, <induction variable>, <loop invariant limit>
@@ -263,8 +272,8 @@ class LoopPredication {
LoopICmp LatchCheck;
bool isSupportedStep(const SCEV* Step);
- Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
- Optional<LoopICmp> parseLoopLatchICmp();
+ std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI);
+ std::optional<LoopICmp> parseLoopLatchICmp();
/// Return an insertion point suitable for inserting a safe to speculate
/// instruction whose only user will be 'User' which has operands 'Ops'. A
@@ -287,16 +296,17 @@ class LoopPredication {
ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS);
- Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
- Instruction *Guard);
- Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck,
- LoopICmp RangeCheck,
- SCEVExpander &Expander,
- Instruction *Guard);
- Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
- LoopICmp RangeCheck,
- SCEVExpander &Expander,
- Instruction *Guard);
+ std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI,
+ SCEVExpander &Expander,
+ Instruction *Guard);
+ std::optional<Value *>
+ widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
+ SCEVExpander &Expander,
+ Instruction *Guard);
+ std::optional<Value *>
+ widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck,
+ SCEVExpander &Expander,
+ Instruction *Guard);
unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
SCEVExpander &Expander, Instruction *Guard);
bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
@@ -376,18 +386,17 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
return PA;
}
-Optional<LoopICmp>
-LoopPredication::parseLoopICmp(ICmpInst *ICI) {
+std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) {
auto Pred = ICI->getPredicate();
auto *LHS = ICI->getOperand(0);
auto *RHS = ICI->getOperand(1);
const SCEV *LHSS = SE->getSCEV(LHS);
if (isa<SCEVCouldNotCompute>(LHSS))
- return None;
+ return std::nullopt;
const SCEV *RHSS = SE->getSCEV(RHS);
if (isa<SCEVCouldNotCompute>(RHSS))
- return None;
+ return std::nullopt;
// Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV
if (SE->isLoopInvariant(LHSS, L)) {
@@ -398,7 +407,7 @@ LoopPredication::parseLoopICmp(ICmpInst *ICI) {
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
if (!AR || AR->getLoop() != L)
- return None;
+ return std::nullopt;
return LoopICmp(Pred, AR, RHSS);
}
@@ -446,8 +455,8 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL,
Type *RangeCheckType) {
if (!EnableIVTruncation)
return false;
- assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedSize() >
- DL.getTypeSizeInBits(RangeCheckType).getFixedSize() &&
+ assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() >
+ DL.getTypeSizeInBits(RangeCheckType).getFixedValue() &&
"Expected latch check IV type to be larger than range check operand "
"type!");
// The start and end values of the IV should be known. This is to guarantee
@@ -467,7 +476,7 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL,
// guarantees that truncating the latch check to RangeCheckType is a safe
// operation.
auto RangeCheckTypeBitSize =
- DL.getTypeSizeInBits(RangeCheckType).getFixedSize();
+ DL.getTypeSizeInBits(RangeCheckType).getFixedValue();
return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
}
@@ -475,20 +484,20 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL,
// Return an LoopICmp describing a latch check equivlent to LatchCheck but with
// the requested type if safe to do so. May involve the use of a new IV.
-static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
- ScalarEvolution &SE,
- const LoopICmp LatchCheck,
- Type *RangeCheckType) {
+static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
+ ScalarEvolution &SE,
+ const LoopICmp LatchCheck,
+ Type *RangeCheckType) {
auto *LatchType = LatchCheck.IV->getType();
if (RangeCheckType == LatchType)
return LatchCheck;
// For now, bail out if latch type is narrower than range type.
- if (DL.getTypeSizeInBits(LatchType).getFixedSize() <
- DL.getTypeSizeInBits(RangeCheckType).getFixedSize())
- return None;
+ if (DL.getTypeSizeInBits(LatchType).getFixedValue() <
+ DL.getTypeSizeInBits(RangeCheckType).getFixedValue())
+ return std::nullopt;
if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType))
- return None;
+ return std::nullopt;
// We can now safely identify the truncated version of the IV and limit for
// RangeCheckType.
LoopICmp NewLatchCheck;
@@ -496,7 +505,7 @@ static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL,
NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
SE.getTruncateExpr(LatchCheck.IV, RangeCheckType));
if (!NewLatchCheck.IV)
- return None;
+ return std::nullopt;
NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType);
LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType
<< "can be represented as range check type:"
@@ -562,15 +571,15 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) {
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S))
if (const auto *LI = dyn_cast<LoadInst>(U->getValue()))
if (LI->isUnordered() && L->hasLoopInvariantOperands(LI))
- if (AA->pointsToConstantMemory(LI->getOperand(0)) ||
+ if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) ||
LI->hasMetadata(LLVMContext::MD_invariant_load))
return true;
return false;
}
-Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
- LoopICmp LatchCheck, LoopICmp RangeCheck,
- SCEVExpander &Expander, Instruction *Guard) {
+std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
+ LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
+ Instruction *Guard) {
auto *Ty = RangeCheck.IV->getType();
// Generate the widened condition for the forward loop:
// guardStart u< guardLimit &&
@@ -590,12 +599,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
!isLoopInvariantValue(LatchStart) ||
!isLoopInvariantValue(LatchLimit)) {
LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
- return None;
+ return std::nullopt;
}
if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
!Expander.isSafeToExpandAt(LatchLimit, Guard)) {
LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
- return None;
+ return std::nullopt;
}
// guardLimit - guardStart + latchStart - 1
@@ -617,9 +626,9 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
}
-Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
- LoopICmp LatchCheck, LoopICmp RangeCheck,
- SCEVExpander &Expander, Instruction *Guard) {
+std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
+ LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander,
+ Instruction *Guard) {
auto *Ty = RangeCheck.IV->getType();
const SCEV *GuardStart = RangeCheck.IV->getStart();
const SCEV *GuardLimit = RangeCheck.Limit;
@@ -633,12 +642,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
!isLoopInvariantValue(LatchStart) ||
!isLoopInvariantValue(LatchLimit)) {
LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
- return None;
+ return std::nullopt;
}
if (!Expander.isSafeToExpandAt(LatchStart, Guard) ||
!Expander.isSafeToExpandAt(LatchLimit, Guard)) {
LLVM_DEBUG(dbgs() << "Can't expand limit check!\n");
- return None;
+ return std::nullopt;
}
// The decrement of the latch check IV should be the same as the
// rangeCheckIV.
@@ -647,7 +656,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
<< *PostDecLatchCheckIV
<< " and RangeCheckIV: " << *RangeCheck.IV << "\n");
- return None;
+ return std::nullopt;
}
// Generate the widened condition for CountDownLoop:
@@ -676,13 +685,12 @@ static void normalizePredicate(ScalarEvolution *SE, Loop *L,
ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE;
}
-
/// If ICI can be widened to a loop invariant condition emits the loop
/// invariant condition in the loop preheader and return it, otherwise
-/// returns None.
-Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
- SCEVExpander &Expander,
- Instruction *Guard) {
+/// returns std::nullopt.
+std::optional<Value *>
+LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
+ Instruction *Guard) {
LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
LLVM_DEBUG(ICI->dump());
@@ -693,26 +701,26 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
auto RangeCheck = parseLoopICmp(ICI);
if (!RangeCheck) {
LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
- return None;
+ return std::nullopt;
}
LLVM_DEBUG(dbgs() << "Guard check:\n");
LLVM_DEBUG(RangeCheck->dump());
if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
LLVM_DEBUG(dbgs() << "Unsupported range check predicate("
<< RangeCheck->Pred << ")!\n");
- return None;
+ return std::nullopt;
}
auto *RangeCheckIV = RangeCheck->IV;
if (!RangeCheckIV->isAffine()) {
LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n");
- return None;
+ return std::nullopt;
}
auto *Step = RangeCheckIV->getStepRecurrence(*SE);
// We cannot just compare with latch IV step because the latch and range IVs
// may have different types.
if (!isSupportedStep(Step)) {
LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n");
- return None;
+ return std::nullopt;
}
auto *Ty = RangeCheckIV->getType();
auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty);
@@ -720,7 +728,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check "
"corresponding to range type: "
<< *Ty << "\n");
- return None;
+ return std::nullopt;
}
LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
@@ -731,7 +739,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
"Range and latch steps should be of same type!");
if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n");
- return None;
+ return std::nullopt;
}
if (Step->isOne())
@@ -756,17 +764,17 @@ unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
// resulting list of subconditions in Checks vector.
SmallVector<Value *, 4> Worklist(1, Condition);
SmallPtrSet<Value *, 4> Visited;
+ Visited.insert(Condition);
Value *WideableCond = nullptr;
do {
Value *Condition = Worklist.pop_back_val();
- if (!Visited.insert(Condition).second)
- continue;
-
Value *LHS, *RHS;
using namespace llvm::PatternMatch;
if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
- Worklist.push_back(LHS);
- Worklist.push_back(RHS);
+ if (Visited.insert(LHS).second)
+ Worklist.push_back(LHS);
+ if (Visited.insert(RHS).second)
+ Worklist.push_back(RHS);
continue;
}
@@ -817,6 +825,10 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
Value *AllChecks = Builder.CreateAnd(Checks);
auto *OldCond = Guard->getOperand(0);
Guard->setOperand(0, AllChecks);
+ if (InsertAssumesOfPredicatedGuardsConditions) {
+ Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard));
+ Builder.CreateAssumption(OldCond);
+ }
RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
@@ -829,6 +841,12 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
LLVM_DEBUG(dbgs() << "Processing guard:\n");
LLVM_DEBUG(BI->dump());
+ Value *Cond, *WC;
+ BasicBlock *IfTrueBB, *IfFalseBB;
+ bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB);
+ assert(Parsed && "Must be able to parse widenable branch");
+ (void)Parsed;
+
TotalConsidered++;
SmallVector<Value *, 4> Checks;
unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
@@ -843,6 +861,10 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
Value *AllChecks = Builder.CreateAnd(Checks);
auto *OldCond = BI->getCondition();
BI->setCondition(AllChecks);
+ if (InsertAssumesOfPredicatedGuardsConditions) {
+ Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt());
+ Builder.CreateAssumption(Cond);
+ }
RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
assert(isGuardAsWidenableBranch(BI) &&
"Stopped being a guard after transform?");
@@ -851,19 +873,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
return true;
}
-Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
+std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
using namespace PatternMatch;
BasicBlock *LoopLatch = L->getLoopLatch();
if (!LoopLatch) {
LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
- return None;
+ return std::nullopt;
}
auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
if (!BI || !BI->isConditional()) {
LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n");
- return None;
+ return std::nullopt;
}
BasicBlock *TrueDest = BI->getSuccessor(0);
assert(
@@ -873,12 +895,12 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
auto *ICI = dyn_cast<ICmpInst>(BI->getCondition());
if (!ICI) {
LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n");
- return None;
+ return std::nullopt;
}
auto Result = parseLoopICmp(ICI);
if (!Result) {
LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
- return None;
+ return std::nullopt;
}
if (TrueDest != L->getHeader())
@@ -888,13 +910,13 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
// recurrence.
if (!Result->IV->isAffine()) {
LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n");
- return None;
+ return std::nullopt;
}
auto *Step = Result->IV->getStepRecurrence(*SE);
if (!isSupportedStep(Step)) {
LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
- return None;
+ return std::nullopt;
}
auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
@@ -912,13 +934,12 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
if (IsUnsupportedPredicate(Step, Result->Pred)) {
LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
<< ")!\n");
- return None;
+ return std::nullopt;
}
return Result;
}
-
bool LoopPredication::isLoopProfitableToPredicate() {
if (SkipProfitabilityChecks)
return true;
@@ -954,37 +975,24 @@ bool LoopPredication::isLoopProfitableToPredicate() {
LatchExitBlock->getTerminatingDeoptimizeCall())
return false;
- auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) {
- if (!ProfileData || !ProfileData->getOperand(0))
- return false;
- if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0)))
- if (!MDS->getString().equals("branch_weights"))
- return false;
- if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
- return false;
- return true;
- };
- MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof);
// Latch terminator has no valid profile data, so nothing to check
// profitability on.
- if (!IsValidProfileData(LatchProfileData, LatchTerm))
+ if (!hasValidBranchWeightMD(*LatchTerm))
return true;
auto ComputeBranchProbability =
[&](const BasicBlock *ExitingBlock,
const BasicBlock *ExitBlock) -> BranchProbability {
auto *Term = ExitingBlock->getTerminator();
- MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
unsigned NumSucc = Term->getNumSuccessors();
- if (IsValidProfileData(ProfileData, Term)) {
- uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
- for (unsigned i = 0; i < NumSucc; i++) {
- ConstantInt *CI =
- mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1));
- ProfVal = CI->getValue().getZExtValue();
+ if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) {
+ SmallVector<uint32_t> Weights;
+ extractBranchWeights(ProfileData, Weights);
+ uint64_t Numerator = 0, Denominator = 0;
+ for (auto [i, Weight] : llvm::enumerate(Weights)) {
if (Term->getSuccessor(i) == ExitBlock)
- Numerator += ProfVal;
- Denominator += ProfVal;
+ Numerator += Weight;
+ Denominator += Weight;
}
return BranchProbability::getBranchProbability(Numerator, Denominator);
} else {