diff options
Diffstat (limited to 'lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r-- | lib/Transforms/Scalar/LoopPredication.cpp | 524 |
1 files changed, 353 insertions, 171 deletions
diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index 5983c804c0c1..507a1e251ca6 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -1,9 +1,8 @@ //===-- LoopPredication.cpp - Guard based loop predication pass -----------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -179,6 +178,7 @@ #include "llvm/Transforms/Scalar/LoopPredication.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" @@ -194,6 +194,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #define DEBUG_TYPE "loop-predication" @@ -222,24 +223,31 @@ static cl::opt<float> LatchExitProbabilityScale( cl::desc("scale factor for the latch probability. Value should be greater " "than 1. 
Lower values are ignored")); +static cl::opt<bool> PredicateWidenableBranchGuards( + "loop-predication-predicate-widenable-branches-to-deopt", cl::Hidden, + cl::desc("Whether or not we should predicate guards " + "expressed as widenable branches to deoptimize blocks"), + cl::init(true)); + namespace { -class LoopPredication { - /// Represents an induction variable check: - /// icmp Pred, <induction variable>, <loop invariant limit> - struct LoopICmp { - ICmpInst::Predicate Pred; - const SCEVAddRecExpr *IV; - const SCEV *Limit; - LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, - const SCEV *Limit) - : Pred(Pred), IV(IV), Limit(Limit) {} - LoopICmp() {} - void dump() { - dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV - << ", Limit = " << *Limit << "\n"; - } - }; +/// Represents an induction variable check: +/// icmp Pred, <induction variable>, <loop invariant limit> +struct LoopICmp { + ICmpInst::Predicate Pred; + const SCEVAddRecExpr *IV; + const SCEV *Limit; + LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, + const SCEV *Limit) + : Pred(Pred), IV(IV), Limit(Limit) {} + LoopICmp() {} + void dump() { + dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV + << ", Limit = " << *Limit << "\n"; + } +}; +class LoopPredication { + AliasAnalysis *AA; ScalarEvolution *SE; BranchProbabilityInfo *BPI; @@ -249,58 +257,53 @@ class LoopPredication { LoopICmp LatchCheck; bool isSupportedStep(const SCEV* Step); - Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) { - return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0), - ICI->getOperand(1)); - } - Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, - Value *RHS); - + Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); Optional<LoopICmp> parseLoopLatchICmp(); - bool CanExpand(const SCEV* S); - Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - Instruction *InsertAt); + /// Return an insertion 
point suitable for inserting a safe to speculate + /// instruction whose only user will be 'User' which has operands 'Ops'. A + /// trivial result would be at the User itself, but we try to return a + /// loop invariant location if possible. + Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops); + /// Same as above, *except* that this uses the SCEV definition of invariant + /// which is that an expression *can be made* invariant via SCEVExpander. + /// Thus, this version is only suitable for finding an insert point to be + /// passed to SCEVExpander! + Instruction *findInsertPt(Instruction *User, ArrayRef<const SCEV*> Ops); + + /// Return true if the value is known to produce a single fixed value across + /// all iterations on which it executes. Note that this does not imply + /// speculation safety. That must be established separately. + bool isLoopInvariantValue(const SCEV* S); + + Value *expandCheck(SCEVExpander &Expander, Instruction *Guard, + ICmpInst::Predicate Pred, const SCEV *LHS, + const SCEV *RHS); Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, LoopICmp RangeCheck, SCEVExpander &Expander, - IRBuilder<> &Builder); + Instruction *Guard); + unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition, + SCEVExpander &Expander, Instruction *Guard); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); - + bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander); // If the loop always exits through another block in the loop, we should not // predicate based on the latch check. 
For example, the latch check can be a // very coarse grained check and there can be more fine grained exit checks // within the loop. We identify such unprofitable loops through BPI. bool isLoopProfitableToPredicate(); - // When the IV type is wider than the range operand type, we can still do loop - // predication, by generating SCEVs for the range and latch that are of the - // same type. We achieve this by generating a SCEV truncate expression for the - // latch IV. This is done iff truncation of the IV is a safe operation, - // without loss of information. - // Another way to achieve this is by generating a wider type SCEV for the - // range check operand, however, this needs a more involved check that - // operands do not overflow. This can lead to loss of information when the - // range operand is of the form: add i32 %offset, %iv. We need to prove that - // sext(x + y) is same as sext(x) + sext(y). - // This function returns true if we can safely represent the IV type in - // the RangeCheckType without loss of information. - bool isSafeToTruncateWideIVType(Type *RangeCheckType); - // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do - // so. 
- Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); - public: - LoopPredication(ScalarEvolution *SE, BranchProbabilityInfo *BPI) - : SE(SE), BPI(BPI){}; + LoopPredication(AliasAnalysis *AA, ScalarEvolution *SE, + BranchProbabilityInfo *BPI) + : AA(AA), SE(SE), BPI(BPI){}; bool runOnLoop(Loop *L); }; @@ -322,7 +325,8 @@ public: auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); BranchProbabilityInfo &BPI = getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); - LoopPredication LP(SE, &BPI); + auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + LoopPredication LP(AA, SE, &BPI); return LP.runOnLoop(L); } }; @@ -348,16 +352,19 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); Function *F = L.getHeader()->getParent(); auto *BPI = FAM.getCachedResult<BranchProbabilityAnalysis>(*F); - LoopPredication LP(&AR.SE, BPI); + LoopPredication LP(&AR.AA, &AR.SE, BPI); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); } -Optional<LoopPredication::LoopICmp> -LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, - Value *RHS) { +Optional<LoopICmp> +LoopPredication::parseLoopICmp(ICmpInst *ICI) { + auto Pred = ICI->getPredicate(); + auto *LHS = ICI->getOperand(0); + auto *RHS = ICI->getOperand(1); + const SCEV *LHSS = SE->getSCEV(LHS); if (isa<SCEVCouldNotCompute>(LHSS)) return None; @@ -380,42 +387,98 @@ LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, } Value *LoopPredication::expandCheck(SCEVExpander &Expander, - IRBuilder<> &Builder, + Instruction *Guard, ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, Instruction *InsertAt) { - // TODO: we can check isLoopEntryGuardedByCond before emitting the check - + const SCEV *RHS) { Type *Ty = LHS->getType(); assert(Ty == RHS->getType() && "expandCheck operands have different types?"); - if 
(SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) - return Builder.getTrue(); + if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) { + IRBuilder<> Builder(Guard); + if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) + return Builder.getTrue(); + if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred), + LHS, RHS)) + return Builder.getFalse(); + } - Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); - Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt); + Value *LHSV = Expander.expandCodeFor(LHS, Ty, findInsertPt(Guard, {LHS})); + Value *RHSV = Expander.expandCodeFor(RHS, Ty, findInsertPt(Guard, {RHS})); + IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV})); return Builder.CreateICmp(Pred, LHSV, RHSV); } -Optional<LoopPredication::LoopICmp> -LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) { + +// Returns true if its safe to truncate the IV to RangeCheckType. +// When the IV type is wider than the range operand type, we can still do loop +// predication, by generating SCEVs for the range and latch that are of the +// same type. We achieve this by generating a SCEV truncate expression for the +// latch IV. This is done iff truncation of the IV is a safe operation, +// without loss of information. +// Another way to achieve this is by generating a wider type SCEV for the +// range check operand, however, this needs a more involved check that +// operands do not overflow. This can lead to loss of information when the +// range operand is of the form: add i32 %offset, %iv. We need to prove that +// sext(x + y) is same as sext(x) + sext(y). +// This function returns true if we can safely represent the IV type in +// the RangeCheckType without loss of information. 
+static bool isSafeToTruncateWideIVType(const DataLayout &DL, + ScalarEvolution &SE, + const LoopICmp LatchCheck, + Type *RangeCheckType) { + if (!EnableIVTruncation) + return false; + assert(DL.getTypeSizeInBits(LatchCheck.IV->getType()) > + DL.getTypeSizeInBits(RangeCheckType) && + "Expected latch check IV type to be larger than range check operand " + "type!"); + // The start and end values of the IV should be known. This is to guarantee + // that truncating the wide type will not lose information. + auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit); + auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart()); + if (!Limit || !Start) + return false; + // This check makes sure that the IV does not change sign during loop + // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, + // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the + // IV wraps around, and the truncation of the IV would lose the range of + // iterations between 2^32 and 2^64. + bool Increasing; + if (!SE.isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing)) + return false; + // The active bits should be less than the bits in the RangeCheckType. This + // guarantees that truncating the latch check to RangeCheckType is a safe + // operation. + auto RangeCheckTypeBitSize = DL.getTypeSizeInBits(RangeCheckType); + return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && + Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; +} + + +// Return a LoopICmp describing a latch check equivalent to LatchCheck but with +// the requested type if safe to do so. May involve the use of a new IV. +static Optional<LoopICmp> generateLoopLatchCheck(const DataLayout &DL, + ScalarEvolution &SE, + const LoopICmp LatchCheck, + Type *RangeCheckType) { auto *LatchType = LatchCheck.IV->getType(); if (RangeCheckType == LatchType) return LatchCheck; // For now, bail out if latch type is narrower than range type. 
- if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType)) + if (DL.getTypeSizeInBits(LatchType) < DL.getTypeSizeInBits(RangeCheckType)) return None; - if (!isSafeToTruncateWideIVType(RangeCheckType)) + if (!isSafeToTruncateWideIVType(DL, SE, LatchCheck, RangeCheckType)) return None; // We can now safely identify the truncated version of the IV and limit for // RangeCheckType. LoopICmp NewLatchCheck; NewLatchCheck.Pred = LatchCheck.Pred; NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( - SE->getTruncateExpr(LatchCheck.IV, RangeCheckType)); + SE.getTruncateExpr(LatchCheck.IV, RangeCheckType)); if (!NewLatchCheck.IV) return None; - NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType); + NewLatchCheck.Limit = SE.getTruncateExpr(LatchCheck.Limit, RangeCheckType); LLVM_DEBUG(dbgs() << "IV of type: " << *LatchType << "can be represented as range check type:" << *RangeCheckType << "\n"); @@ -428,13 +491,66 @@ bool LoopPredication::isSupportedStep(const SCEV* Step) { return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop); } -bool LoopPredication::CanExpand(const SCEV* S) { - return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); +Instruction *LoopPredication::findInsertPt(Instruction *Use, + ArrayRef<Value*> Ops) { + for (Value *Op : Ops) + if (!L->isLoopInvariant(Op)) + return Use; + return Preheader->getTerminator(); +} + +Instruction *LoopPredication::findInsertPt(Instruction *Use, + ArrayRef<const SCEV*> Ops) { + // Subtlety: SCEV considers things to be invariant if the value produced is + // the same across iterations. This is not the same as being able to + // evaluate outside the loop, which is what we actually need here. 
+ for (const SCEV *Op : Ops) + if (!SE->isLoopInvariant(Op, L) || + !isSafeToExpandAt(Op, Preheader->getTerminator(), *SE)) + return Use; + return Preheader->getTerminator(); +} + +bool LoopPredication::isLoopInvariantValue(const SCEV* S) { + // Handling expressions which produce invariant results, but *haven't* yet + // been removed from the loop serves two important purposes. + // 1) Most importantly, it resolves a pass ordering cycle which would + // otherwise need us to iterate licm, loop-predication, and either + // loop-unswitch or loop-peeling to make progress on examples with lots of + // predicable range checks in a row. (Since, in the general case, we can't + // hoist the length checks until the dominating checks have been discharged + // as we can't prove doing so is safe.) + // 2) As a nice side effect, this exposes the value of peeling or unswitching + // much more obviously in the IR. Otherwise, the cost modeling for other + // transforms would end up needing to duplicate all of this logic to model a + // check which becomes predictable based on a modeled peel or unswitch. + // + // The cost of doing so in the worst case is an extra fill from the stack in + // the loop to materialize the loop invariant test value instead of checking + // against the original IV which is presumably in a register inside the loop. + // Such cases are presumably rare, and hint at missing opportunities for + // other passes. + + if (SE->isLoopInvariant(S, L)) + // Note: This is the SCEV variant, so the original Value* may be within the + // loop even though SCEV has proven it is loop invariant. + return true; + + // Handle a particularly important case which SCEV doesn't yet know about which + // shows up in range checks on arrays with immutable lengths. + // TODO: This should be sunk inside SCEV. 
+ if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) + if (const auto *LI = dyn_cast<LoadInst>(U->getValue())) + if (LI->isUnordered() && L->hasLoopInvariantOperands(LI)) + if (AA->pointsToConstantMemory(LI->getOperand(0)) || + LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr) + return true; + return false; } Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( - LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, - SCEVExpander &Expander, IRBuilder<> &Builder) { + LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); // Generate the widened condition for the forward loop: // guardStart u< guardLimit && @@ -446,40 +562,61 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( const SCEV *GuardLimit = RangeCheck.Limit; const SCEV *LatchStart = LatchCheck.IV->getStart(); const SCEV *LatchLimit = LatchCheck.Limit; + // Subtlety: We need all the values to be *invariant* across all iterations, + // but we only need to check expansion safety for those which *aren't* + // already guaranteed to dominate the guard. 
+ if (!isLoopInvariantValue(GuardStart) || + !isLoopInvariantValue(GuardLimit) || + !isLoopInvariantValue(LatchStart) || + !isLoopInvariantValue(LatchLimit)) { + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } + if (!isSafeToExpandAt(LatchStart, Guard, *SE) || + !isSafeToExpandAt(LatchLimit, Guard, *SE)) { + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } // guardLimit - guardStart + latchStart - 1 const SCEV *RHS = SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart), SE->getMinusSCEV(LatchStart, SE->getOne(Ty))); - if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || - !CanExpand(LatchLimit) || !CanExpand(RHS)) { - LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); - return None; - } auto LimitCheckPred = ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); LLVM_DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); LLVM_DEBUG(dbgs() << "RHS: " << *RHS << "\n"); LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); - - Instruction *InsertAt = Preheader->getTerminator(); + auto *LimitCheck = - expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt); - auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred, - GuardStart, GuardLimit, InsertAt); + expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS); + auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred, + GuardStart, GuardLimit); + IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( - LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, - SCEVExpander &Expander, IRBuilder<> &Builder) { + LoopICmp LatchCheck, LoopICmp RangeCheck, + SCEVExpander &Expander, Instruction *Guard) { auto *Ty = RangeCheck.IV->getType(); const SCEV *GuardStart = RangeCheck.IV->getStart(); const SCEV *GuardLimit = RangeCheck.Limit; + const 
SCEV *LatchStart = LatchCheck.IV->getStart(); const SCEV *LatchLimit = LatchCheck.Limit; - if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || - !CanExpand(LatchLimit)) { + // Subtlety: We need all the values to be *invariant* across all iterations, + // but we only need to check expansion safety for those which *aren't* + // already guaranteed to dominate the guard. + if (!isLoopInvariantValue(GuardStart) || + !isLoopInvariantValue(GuardLimit) || + !isLoopInvariantValue(LatchStart) || + !isLoopInvariantValue(LatchLimit)) { + LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } + if (!isSafeToExpandAt(LatchStart, Guard, *SE) || + !isSafeToExpandAt(LatchLimit, Guard, *SE)) { LLVM_DEBUG(dbgs() << "Can't expand limit check!\n"); return None; } @@ -497,22 +634,35 @@ Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( // guardStart u< guardLimit && // latchLimit <pred> 1. // See the header comment for reasoning of the checks. - Instruction *InsertAt = Preheader->getTerminator(); auto LimitCheckPred = ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred); - auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT, - GuardStart, GuardLimit, InsertAt); - auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, - SE->getOne(Ty), InsertAt); + auto *FirstIterationCheck = expandCheck(Expander, Guard, + ICmpInst::ICMP_ULT, + GuardStart, GuardLimit); + auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, + SE->getOne(Ty)); + IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck})); return Builder.CreateAnd(FirstIterationCheck, LimitCheck); } +static void normalizePredicate(ScalarEvolution *SE, Loop *L, + LoopICmp& RC) { + // LFTR canonicalizes checks to the ICMP_NE/EQ form; normalize back to the + // ULT/UGE form for ease of handling by our caller. 
+ if (ICmpInst::isEquality(RC.Pred) && + RC.IV->getStepRecurrence(*SE)->isOne() && + SE->isKnownPredicate(ICmpInst::ICMP_ULE, RC.IV->getStart(), RC.Limit)) + RC.Pred = RC.Pred == ICmpInst::ICMP_NE ? + ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE; +} + + /// If ICI can be widened to a loop invariant condition emits the loop /// invariant condition in the loop preheader and return it, otherwise /// returns None. Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, - IRBuilder<> &Builder) { + Instruction *Guard) { LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); LLVM_DEBUG(ICI->dump()); @@ -545,7 +695,7 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, return None; } auto *Ty = RangeCheckIV->getType(); - auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty); + auto CurrLatchCheckOpt = generateLoopLatchCheck(*DL, *SE, LatchCheck, Ty); if (!CurrLatchCheckOpt) { LLVM_DEBUG(dbgs() << "Failed to generate a loop latch check " "corresponding to range type: " @@ -566,34 +716,27 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, if (Step->isOne()) return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck, - Expander, Builder); + Expander, Guard); else { assert(Step->isAllOnesValue() && "Step should be -1!"); return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck, - Expander, Builder); + Expander, Guard); } } -bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, - SCEVExpander &Expander) { - LLVM_DEBUG(dbgs() << "Processing guard:\n"); - LLVM_DEBUG(Guard->dump()); - - TotalConsidered++; - - IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); - +unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks, + Value *Condition, + SCEVExpander &Expander, + Instruction *Guard) { + unsigned NumWidened = 0; // The guard condition is expected to be in form of: // cond1 && cond2 && cond3 ... 
// Iterate over subconditions looking for icmp conditions which can be // widened across loop iterations. Widening these conditions remember the // resulting list of subconditions in Checks vector. - SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0)); + SmallVector<Value *, 4> Worklist(1, Condition); SmallPtrSet<Value *, 4> Visited; - - SmallVector<Value *, 4> Checks; - - unsigned NumWidened = 0; + Value *WideableCond = nullptr; do { Value *Condition = Worklist.pop_back_val(); if (!Visited.insert(Condition).second) @@ -607,8 +750,16 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, continue; } + if (match(Condition, + m_Intrinsic<Intrinsic::experimental_widenable_condition>())) { + // Pick any, we don't care which + WideableCond = Condition; + continue; + } + if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) { - if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) { + if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, + Guard)) { Checks.push_back(NewRangeCheck.getValue()); NumWidened++; continue; @@ -617,28 +768,70 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, // Save the condition as is if we can't widen it Checks.push_back(Condition); - } while (Worklist.size() != 0); + } while (!Worklist.empty()); + // At the moment, our matching logic for wideable conditions implicitly + // assumes we preserve the form: (br (and Cond, WC())). FIXME + // Note that if there were multiple calls to wideable condition in the + // traversal, we only need to keep one, and which one is arbitrary. 
+ if (WideableCond) + Checks.push_back(WideableCond); + return NumWidened; +} + +bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, + SCEVExpander &Expander) { + LLVM_DEBUG(dbgs() << "Processing guard:\n"); + LLVM_DEBUG(Guard->dump()); + TotalConsidered++; + SmallVector<Value *, 4> Checks; + unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander, + Guard); + if (NumWidened == 0) + return false; + + TotalWidened += NumWidened; + + // Emit the new guard condition + IRBuilder<> Builder(findInsertPt(Guard, Checks)); + Value *AllChecks = Builder.CreateAnd(Checks); + auto *OldCond = Guard->getOperand(0); + Guard->setOperand(0, AllChecks); + RecursivelyDeleteTriviallyDeadInstructions(OldCond); + + LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); + return true; +} + +bool LoopPredication::widenWidenableBranchGuardConditions( + BranchInst *BI, SCEVExpander &Expander) { + assert(isGuardAsWidenableBranch(BI) && "Must be!"); + LLVM_DEBUG(dbgs() << "Processing guard:\n"); + LLVM_DEBUG(BI->dump()); + + TotalConsidered++; + SmallVector<Value *, 4> Checks; + unsigned NumWidened = collectChecks(Checks, BI->getCondition(), + Expander, BI); if (NumWidened == 0) return false; TotalWidened += NumWidened; // Emit the new guard condition - Builder.SetInsertPoint(Guard); - Value *LastCheck = nullptr; - for (auto *Check : Checks) - if (!LastCheck) - LastCheck = Check; - else - LastCheck = Builder.CreateAnd(LastCheck, Check); - Guard->setOperand(0, LastCheck); + IRBuilder<> Builder(findInsertPt(BI, Checks)); + Value *AllChecks = Builder.CreateAnd(Checks); + auto *OldCond = BI->getCondition(); + BI->setCondition(AllChecks); + assert(isGuardAsWidenableBranch(BI) && + "Stopped being a guard after transform?"); + RecursivelyDeleteTriviallyDeadInstructions(OldCond); LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); return true; } -Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { +Optional<LoopICmp> 
LoopPredication::parseLoopLatchICmp() { using namespace PatternMatch; BasicBlock *LoopLatch = L->getLoopLatch(); @@ -647,27 +840,30 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { return None; } - ICmpInst::Predicate Pred; - Value *LHS, *RHS; - BasicBlock *TrueDest, *FalseDest; - - if (!match(LoopLatch->getTerminator(), - m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest, - FalseDest))) { + auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); + if (!BI || !BI->isConditional()) { LLVM_DEBUG(dbgs() << "Failed to match the latch terminator!\n"); return None; } - assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) && - "One of the latch's destinations must be the header"); - if (TrueDest != L->getHeader()) - Pred = ICmpInst::getInversePredicate(Pred); - - auto Result = parseLoopICmp(Pred, LHS, RHS); + BasicBlock *TrueDest = BI->getSuccessor(0); + assert( + (TrueDest == L->getHeader() || BI->getSuccessor(1) == L->getHeader()) && + "One of the latch's destinations must be the header"); + + auto *ICI = dyn_cast<ICmpInst>(BI->getCondition()); + if (!ICI) { + LLVM_DEBUG(dbgs() << "Failed to match the latch condition!\n"); + return None; + } + auto Result = parseLoopICmp(ICI); if (!Result) { LLVM_DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } + if (TrueDest != L->getHeader()) + Result->Pred = ICmpInst::getInversePredicate(Result->Pred); + // Check affine first, so if it's not we don't try to compute the step // recurrence. if (!Result->IV->isAffine()) { @@ -692,49 +888,22 @@ Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { } }; + normalizePredicate(SE, L, *Result); if (IsUnsupportedPredicate(Step, Result->Pred)) { LLVM_DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred << ")!\n"); return None; } + return Result; } -// Returns true if its safe to truncate the IV to RangeCheckType. 
-bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) { - if (!EnableIVTruncation) - return false; - assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) > - DL->getTypeSizeInBits(RangeCheckType) && - "Expected latch check IV type to be larger than range check operand " - "type!"); - // The start and end values of the IV should be known. This is to guarantee - // that truncating the wide type will not lose information. - auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit); - auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart()); - if (!Limit || !Start) - return false; - // This check makes sure that the IV does not change sign during loop - // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, - // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the - // IV wraps around, and the truncation of the IV would lose the range of - // iterations between 2^32 and 2^64. - bool Increasing; - if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing)) - return false; - // The active bits should be less than the bits in the RangeCheckType. This - // guarantees that truncating the latch check to RangeCheckType is a safe - // operation. - auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType); - return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && - Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; -} bool LoopPredication::isLoopProfitableToPredicate() { if (SkipProfitabilityChecks || !BPI) return true; - SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 8> ExitEdges; + SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges; L->getExitEdges(ExitEdges); // If there is only one exiting edge in the loop, it is always profitable to // predicate the loop. 
@@ -795,7 +964,12 @@ bool LoopPredication::runOnLoop(Loop *Loop) { // There is nothing to do if the module doesn't use guards auto *GuardDecl = M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard)); - if (!GuardDecl || GuardDecl->use_empty()) + bool HasIntrinsicGuards = GuardDecl && !GuardDecl->use_empty(); + auto *WCDecl = M->getFunction( + Intrinsic::getName(Intrinsic::experimental_widenable_condition)); + bool HasWidenableConditions = + PredicateWidenableBranchGuards && WCDecl && !WCDecl->use_empty(); + if (!HasIntrinsicGuards && !HasWidenableConditions) return false; DL = &M->getDataLayout(); @@ -819,12 +993,18 @@ bool LoopPredication::runOnLoop(Loop *Loop) { // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; - for (const auto BB : L->blocks()) + SmallVector<BranchInst *, 4> GuardsAsWidenableBranches; + for (const auto BB : L->blocks()) { for (auto &I : *BB) if (isGuard(&I)) Guards.push_back(cast<IntrinsicInst>(&I)); + if (PredicateWidenableBranchGuards && + isGuardAsWidenableBranch(BB->getTerminator())) + GuardsAsWidenableBranches.push_back( + cast<BranchInst>(BB->getTerminator())); + } - if (Guards.empty()) + if (Guards.empty() && GuardsAsWidenableBranches.empty()) return false; SCEVExpander Expander(*SE, *DL, "loop-predication"); @@ -832,6 +1012,8 @@ bool LoopPredication::runOnLoop(Loop *Loop) { bool Changed = false; for (auto *Guard : Guards) Changed |= widenGuardConditions(Guard, Expander); + for (auto *Guard : GuardsAsWidenableBranches) + Changed |= widenWidenableBranchGuardConditions(Guard, Expander); return Changed; } |