diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r-- | llvm/lib/Transforms/Scalar/LoopPredication.cpp | 182 |
1 files changed, 95 insertions, 87 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index b327d38d2a84..49c0fff84d81 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -191,6 +191,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -200,6 +201,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" +#include <optional> #define DEBUG_TYPE "loop-predication" @@ -233,6 +235,13 @@ static cl::opt<bool> PredicateWidenableBranchGuards( "expressed as widenable branches to deoptimize blocks"), cl::init(true)); +static cl::opt<bool> InsertAssumesOfPredicatedGuardsConditions( + "loop-predication-insert-assumes-of-predicated-guards-conditions", + cl::Hidden, + cl::desc("Whether or not we should insert assumes of conditions of " + "predicated guards"), + cl::init(true)); + namespace { /// Represents an induction variable check: /// icmp Pred, <induction variable>, <loop invariant limit> @@ -263,8 +272,8 @@ class LoopPredication { LoopICmp LatchCheck; bool isSupportedStep(const SCEV* Step); - Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); - Optional<LoopICmp> parseLoopLatchICmp(); + std::optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); + std::optional<LoopICmp> parseLoopLatchICmp(); /// Return an insertion point suitable for inserting a safe to speculate /// instruction whose only user will be 'User' which has operands 'Ops'. A @@ -287,16 +296,17 @@ class LoopPredication { ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); - Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - Instruction *Guard); - Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, - LoopICmp RangeCheck, - SCEVExpander &Expander, - Instruction *Guard); - Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, - LoopICmp RangeCheck, - SCEVExpander &Expander, - Instruction *Guard); + std::optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, + SCEVExpander &Expander, + Instruction *Guard); + std::optional<Value *> + widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, + Instruction *Guard); + std::optional<Value *> + widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, + Instruction *Guard); unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition, SCEVExpander &Expander, Instruction *Guard); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); @@ -376,18 +386,17 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, return PA; } -Optional<LoopICmp> -LoopPredication::parseLoopICmp(ICmpInst *ICI) { +std::optional<LoopICmp> LoopPredication::parseLoopICmp(ICmpInst *ICI) { auto Pred = ICI->getPredicate(); auto *LHS = ICI->getOperand(0); auto *RHS = ICI->getOperand(1); const SCEV *LHSS = SE->getSCEV(LHS); if (isa<SCEVCouldNotCompute>(LHSS)) - return None; + return std::nullopt; const SCEV *RHSS = SE->getSCEV(RHS); if (isa<SCEVCouldNotCompute>(RHSS)) - return None; + return std::nullopt; // Canonicalize RHS to be loop invariant bound, LHS - a loop computable IV if (SE->isLoopInvariant(LHSS, L)) { @@ -398,7 +407,7 @@ LoopPredication::parseLoopICmp(ICmpInst *ICI) { const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS); if (!AR || AR->getLoop() != L) - return None; + return std::nullopt; return LoopICmp(Pred, AR, RHSS); } @@ -446,8 +455,8 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL, Type *RangeCheckType) { if (!EnableIVTruncation) return false; - assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedSize() > - DL.getTypeSizeInBits(RangeCheckType).getFixedSize() && + assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()).getFixedValue() > + DL.getTypeSizeInBits(RangeCheckType).getFixedValue() && "Expected latch check IV type to be larger than range check operand " "type!"); // The start and end values of the IV should be known. This is to guarantee @@ -467,7 +476,7 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL, // guarantees that truncating the latch check to RangeCheckType is a safe // operation. auto RangeCheckTypeBitSize = - DL.getTypeSizeInBits(RangeCheckType).getFixedSize(); + DL.getTypeSizeInBits(RangeCheckType).getFixedValue(); return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; } @@ -475,20 +484,20 @@ static bool isSafeToTruncateWideIVType(const DataLayout &DL, // Return an LoopICmp describing a latch check equivlent to LatchCheck but with // the requested type if safe to do so. May involve the use of a new IV. -static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, - ScalarEvolution &SE, - const LoopICmp LatchCheck, - Type *RangeCheckType) { +static std::optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, + ScalarEvolution &SE, + const LoopICmp LatchCheck, + Type *RangeCheckType) { auto *LatchType = LatchCheck.IV->getType(); if (RangeCheckType == LatchType) return LatchCheck; // For now, bail out if latch type is narrower than range type. - if (DL.getTypeSizeInBits(LatchType).getFixedSize() < - DL.getTypeSizeInBits(RangeCheckType).getFixedSize()) - return None; + if (DL.getTypeSizeInBits(LatchType).getFixedValue() < + DL.getTypeSizeInBits(RangeCheckType).getFixedValue()) + return std::nullopt; if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType)) - return None; + return std::nullopt; // We can now safely identify the truncated version of the IV and limit for // RangeCheckType. LoopICmp NewLatchCheck; @@ -496,7 +505,7 @@ static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( SE.getTruncateExpr(LatchCheck.IV, RangeCheckType)); if (!NewLatchCheck.IV) - return None; + return std::nullopt; NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType); LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType << "can be represented as range check type:" @@ -562,15 +571,15 @@ bool LoopPredication::isLoopInvariantValue(const SCEV* S) { if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) if (const auto *LI = dyn_cast<LoadInst>(U->getValue())) if (LI->isUnordered() && L->hasLoopInvariantOperands(LI)) - if (AA->pointsToConstantMemory(LI->getOperand(0)) || + if (!isModSet(AA->getModRefInfoMask(LI->getOperand(0))) || LI->hasMetadata(LLVMContext::MD_invariant_load)) return true; return false; } -Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( - LoopICmp LatchCheck, LoopICmp RangeCheck, - SCEVExpander &Expander, Instruction *Guard) { +std::optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( + LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, + Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); // Generate the widened condition for the forward loop: // guardStart u< guardLimit && @@ -590,12 +599,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( !isLoopInvariantValue(LatchStart) || !isLoopInvariantValue(LatchLimit)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } if (!Expander.isSafeToExpandAt(LatchStart, Guard) || !Expander.isSafeToExpandAt(LatchLimit, Guard)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } // guardLimit - guardStart + latchStart - 1 @@ -617,9 +626,9 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } -Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( - LoopICmp LatchCheck, LoopICmp RangeCheck, - SCEVExpander &Expander, Instruction *Guard) { +std::optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( + LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, + Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); const SCEV *GuardStart = RangeCheck.IV->getStart(); const SCEV *GuardLimit = RangeCheck.Limit; @@ -633,12 +642,12 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( !isLoopInvariantValue(LatchStart) || !isLoopInvariantValue(LatchLimit)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } if (!Expander.isSafeToExpandAt(LatchStart, Guard) || !Expander.isSafeToExpandAt(LatchLimit, Guard)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; + return std::nullopt; } // The decrement of the latch check IV should be the same as the // rangeCheckIV. @@ -647,7 +656,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( LLVM_DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " << *PostDecLatchCheckIV << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); - return None; + return std::nullopt; } // Generate the widened condition for CountDownLoop: @@ -676,13 +685,12 @@ static void normalizePredicate(ScalarEvolution *SE, Loop *L, ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; } - /// If ICI can be widened to a loop invariant condition emits the loop /// invariant condition in the loop preheader and return it, otherwise -/// returns None. -Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, - SCEVExpander &Expander, - Instruction *Guard) { +/// returns std::nullopt. +std::optional<Value *> +LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, + Instruction *Guard) { LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); LLVM_DEBUG(ICI->dump()); @@ -693,26 +701,26 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, auto RangeCheck = parseLoopICmp(ICI); if (!RangeCheck) { LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); - return None; + return std::nullopt; } LLVM_DEBUG(dbgs() << "Guard check:\n"); LLVM_DEBUG(RangeCheck->dump()); if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { LLVM_DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred << ")!\n"); - return None; + return std::nullopt; } auto *RangeCheckIV = RangeCheck->IV; if (!RangeCheckIV->isAffine()) { LLVM_DEBUG(dbgs() << "Range check IV is not affine!\n"); - return None; + return std::nullopt; } auto *Step = RangeCheckIV->getStepRecurrence(*SE); // We cannot just compare with latch IV step because the latch and range IVs // may have different types. if (!isSupportedStep(Step)) { LLVM_DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); - return None; + return std::nullopt; } auto *Ty = RangeCheckIV->getType(); auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty); @@ -720,7 +728,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " "corresponding to range type: " << *Ty << "\n"); - return None; + return std::nullopt; } LoopICmp CurrLatchCheck = *CurrLatchCheckOpt; @@ -731,7 +739,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, "Range and latch steps should be of same type!"); if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) { LLVM_DEBUG(dbgs() << "Range and latch have different step values!\n"); - return None; + return std::nullopt; } if (Step->isOne()) @@ -756,17 +764,17 @@ unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks, // resulting list of subconditions in Checks vector. SmallVector<Value *, 4> Worklist(1, Condition); SmallPtrSet<Value *, 4> Visited; + Visited.insert(Condition); Value *WideableCond = nullptr; do { Value *Condition = Worklist.pop_back_val(); - if (!Visited.insert(Condition).second) - continue; - Value *LHS, *RHS; using namespace llvm::PatternMatch; if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) { - Worklist.push_back(LHS); - Worklist.push_back(RHS); + if (Visited.insert(LHS).second) + Worklist.push_back(LHS); + if (Visited.insert(RHS).second) + Worklist.push_back(RHS); continue; } @@ -817,6 +825,10 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = Guard->getOperand(0); Guard->setOperand(0, AllChecks); + if (InsertAssumesOfPredicatedGuardsConditions) { + Builder.SetInsertPoint(&*++BasicBlock::iterator(Guard)); + Builder.CreateAssumption(OldCond); + } RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); @@ -829,6 +841,12 @@ bool LoopPredication::widenWidenableBranchGuardConditions( LLVM_DEBUG(dbgs() << "Processing guard:\n"); LLVM_DEBUG(BI->dump()); + Value *Cond, *WC; + BasicBlock *IfTrueBB, *IfFalseBB; + bool Parsed = parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB); + assert(Parsed && "Must be able to parse widenable branch"); + (void)Parsed; + TotalConsidered++; SmallVector<Value *, 4> Checks; unsigned NumWidened = collectChecks(Checks, BI->getCondition(), @@ -843,6 +861,10 @@ bool LoopPredication::widenWidenableBranchGuardConditions( Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = BI->getCondition(); BI->setCondition(AllChecks); + if (InsertAssumesOfPredicatedGuardsConditions) { + Builder.SetInsertPoint(IfTrueBB, IfTrueBB->getFirstInsertionPt()); + Builder.CreateAssumption(Cond); + } RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); assert(isGuardAsWidenableBranch(BI) && "Stopped being a guard after transform?"); @@ -851,19 +873,19 @@ bool LoopPredication::widenWidenableBranchGuardConditions( return true; } -Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { +std::optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { using namespace PatternMatch; BasicBlock *LoopLatch = L->getLoopLatch(); if (!LoopLatch) { LLVM_DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); - return None; + return std::nullopt; } auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); if (!BI || !BI->isConditional()) { LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n"); - return None; + return std::nullopt; } BasicBlock *TrueDest = BI->getSuccessor(0); assert( @@ -873,12 +895,12 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { auto *ICI = dyn_cast<ICmpInst>(BI->getCondition()); if (!ICI) { LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n"); - return None; + return std::nullopt; } auto Result = parseLoopICmp(ICI); if (!Result) { LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); - return None; + return std::nullopt; } if (TrueDest != L->getHeader()) @@ -888,13 +910,13 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { // recurrence. if (!Result->IV->isAffine()) { LLVM_DEBUG(dbgs() << "The induction variable is not affine!\n"); - return None; + return std::nullopt; } auto *Step = Result->IV->getStepRecurrence(*SE); if (!isSupportedStep(Step)) { LLVM_DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); - return None; + return std::nullopt; } auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) { @@ -912,13 +934,12 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { if (IsUnsupportedPredicate(Step, Result->Pred)) { LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred << ")!\n"); - return None; + return std::nullopt; } return Result; } - bool LoopPredication::isLoopProfitableToPredicate() { if (SkipProfitabilityChecks) return true; @@ -954,37 +975,24 @@ bool LoopPredication::isLoopProfitableToPredicate() { LatchExitBlock->getTerminatingDeoptimizeCall()) return false; - auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) { - if (!ProfileData || !ProfileData->getOperand(0)) - return false; - if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0))) - if (!MDS->getString().equals("branch_weights")) - return false; - if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors()) - return false; - return true; - }; - MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof); // Latch terminator has no valid profile data, so nothing to check // profitability on. - if (!IsValidProfileData(LatchProfileData, LatchTerm)) + if (!hasValidBranchWeightMD(*LatchTerm)) return true; auto ComputeBranchProbability = [&](const BasicBlock *ExitingBlock, const BasicBlock *ExitBlock) -> BranchProbability { auto *Term = ExitingBlock->getTerminator(); - MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof); unsigned NumSucc = Term->getNumSuccessors(); - if (IsValidProfileData(ProfileData, Term)) { - uint64_t Numerator = 0, Denominator = 0, ProfVal = 0; - for (unsigned i = 0; i < NumSucc; i++) { - ConstantInt *CI = - mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1)); - ProfVal = CI->getValue().getZExtValue(); + if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) { + SmallVector<uint32_t> Weights; + extractBranchWeights(ProfileData, Weights); + uint64_t Numerator = 0, Denominator = 0; + for (auto [i, Weight] : llvm::enumerate(Weights)) { if (Term->getSuccessor(i) == ExitBlock) - Numerator += ProfVal; - Denominator += ProfVal; + Numerator += Weight; + Denominator += Weight; } return BranchProbability::getBranchProbability(Numerator, Denominator); } else { |