diff options
Diffstat (limited to 'lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp')
-rw-r--r-- | lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 681 |
1 files changed, 449 insertions, 232 deletions
diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 99b4458ea0fa..5c4d55bfbb2b 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -1,4 +1,4 @@ -//===-- InductiveRangeCheckElimination.cpp - ------------------------------===// +//===- InductiveRangeCheckElimination.cpp - -------------------------------===// // // The LLVM Compiler Infrastructure // @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// // The InductiveRangeCheckElimination pass splits a loop's iteration space into // three disjoint ranges. It does that in a way such that the loop running in // the middle loop provably does not need range checks. As an example, it will @@ -39,30 +40,61 @@ // throw_out_of_bounds(); // } // } +// //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" 
+#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <limits> +#include <utility> +#include <vector> using namespace llvm; +using namespace llvm::PatternMatch; static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden, cl::init(64)); @@ -79,6 +111,9 @@ static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal", static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", cl::Hidden, cl::init(false)); +static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch", + cl::Hidden, cl::init(true)); + static const char *ClonedLoopTag = "irce.loop.clone"; #define DEBUG_TYPE "irce" @@ -114,15 +149,16 @@ class InductiveRangeCheck { static StringRef rangeCheckKindToStr(RangeCheckKind); - const SCEV *Offset = nullptr; - const SCEV *Scale = nullptr; - Value *Length = nullptr; + const SCEV *Begin = nullptr; + const SCEV *Step = nullptr; + const SCEV *End = nullptr; Use *CheckUse = nullptr; RangeCheckKind Kind = RANGE_CHECK_UNKNOWN; + bool IsSigned = true; static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, - Value *&Length); + Value *&Length, bool &IsSigned); static void extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse, @@ -130,20 +166,21 @@ class InductiveRangeCheck { SmallPtrSetImpl<Value *> &Visited); public: - 
const SCEV *getOffset() const { return Offset; } - const SCEV *getScale() const { return Scale; } - Value *getLength() const { return Length; } + const SCEV *getBegin() const { return Begin; } + const SCEV *getStep() const { return Step; } + const SCEV *getEnd() const { return End; } + bool isSigned() const { return IsSigned; } void print(raw_ostream &OS) const { OS << "InductiveRangeCheck:\n"; OS << " Kind: " << rangeCheckKindToStr(Kind) << "\n"; - OS << " Offset: "; - Offset->print(OS); - OS << " Scale: "; - Scale->print(OS); - OS << " Length: "; - if (Length) - Length->print(OS); + OS << " Begin: "; + Begin->print(OS); + OS << " Step: "; + Step->print(OS); + OS << " End: "; + if (End) + End->print(OS); else OS << "(null)"; OS << "\n CheckUse: "; @@ -173,6 +210,14 @@ public: Type *getType() const { return Begin->getType(); } const SCEV *getBegin() const { return Begin; } const SCEV *getEnd() const { return End; } + bool isEmpty(ScalarEvolution &SE, bool IsSigned) const { + if (Begin == End) + return true; + if (IsSigned) + return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End); + else + return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End); + } }; /// This is the value the condition of the branch needs to evaluate to for the @@ -183,7 +228,8 @@ public: /// check is redundant and can be constant-folded away. The induction /// variable is not required to be the canonical {0,+,1} induction variable. Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, - const SCEVAddRecExpr *IndVar) const; + const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const; /// Parse out a set of inductive range checks from \p BI and append them to \p /// Checks. 
@@ -199,6 +245,7 @@ public: class InductiveRangeCheckElimination : public LoopPass { public: static char ID; + InductiveRangeCheckElimination() : LoopPass(ID) { initializeInductiveRangeCheckEliminationPass( *PassRegistry::getPassRegistry()); @@ -212,8 +259,9 @@ public: bool runOnLoop(Loop *L, LPPassManager &LPM) override; }; +} // end anonymous namespace + char InductiveRangeCheckElimination::ID = 0; -} INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce", "Inductive range check elimination", false, false) @@ -247,12 +295,10 @@ StringRef InductiveRangeCheck::rangeCheckKindToStr( /// range checked, and set `Length` to the upper limit `Index` is being range /// checked with if (and only if) the range check type is stronger or equal to /// RANGE_CHECK_UPPER. -/// InductiveRangeCheck::RangeCheckKind InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, - Value *&Length) { - + Value *&Length, bool &IsSigned) { auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { const SCEV *S = SE.getSCEV(V); if (isa<SCEVCouldNotCompute>(S)) @@ -262,8 +308,6 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, SE.isKnownNonNegative(S); }; - using namespace llvm::PatternMatch; - ICmpInst::Predicate Pred = ICI->getPredicate(); Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); @@ -276,6 +320,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, std::swap(LHS, RHS); LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGE: + IsSigned = true; if (match(RHS, m_ConstantInt<0>())) { Index = LHS; return RANGE_CHECK_LOWER; @@ -286,6 +331,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, std::swap(LHS, RHS); LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGT: + IsSigned = true; if (match(RHS, m_ConstantInt<-1>())) { Index = LHS; return RANGE_CHECK_LOWER; @@ -302,6 +348,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, std::swap(LHS, RHS); LLVM_FALLTHROUGH; case 
ICmpInst::ICMP_UGT: + IsSigned = false; if (IsNonNegativeAndNotLoopVarying(LHS)) { Index = RHS; Length = LHS; @@ -317,42 +364,16 @@ void InductiveRangeCheck::extractRangeChecksFromCond( Loop *L, ScalarEvolution &SE, Use &ConditionUse, SmallVectorImpl<InductiveRangeCheck> &Checks, SmallPtrSetImpl<Value *> &Visited) { - using namespace llvm::PatternMatch; - Value *Condition = ConditionUse.get(); if (!Visited.insert(Condition).second) return; + // TODO: Do the same for OR, XOR, NOT etc? if (match(Condition, m_And(m_Value(), m_Value()))) { - SmallVector<InductiveRangeCheck, 8> SubChecks; extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0), - SubChecks, Visited); + Checks, Visited); extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1), - SubChecks, Visited); - - if (SubChecks.size() == 2) { - // Handle a special case where we know how to merge two checks separately - // checking the upper and lower bounds into a full range check. - const auto &RChkA = SubChecks[0]; - const auto &RChkB = SubChecks[1]; - if ((RChkA.Length == RChkB.Length || !RChkA.Length || !RChkB.Length) && - RChkA.Offset == RChkB.Offset && RChkA.Scale == RChkB.Scale) { - - // If RChkA.Kind == RChkB.Kind then we just found two identical checks. - // But if one of them is a RANGE_CHECK_LOWER and the other is a - // RANGE_CHECK_UPPER (only possibility if they're different) then - // together they form a RANGE_CHECK_BOTH. - SubChecks[0].Kind = - (InductiveRangeCheck::RangeCheckKind)(RChkA.Kind | RChkB.Kind); - SubChecks[0].Length = RChkA.Length ? RChkA.Length : RChkB.Length; - SubChecks[0].CheckUse = &ConditionUse; - - // We updated one of the checks in place, now erase the other. 
- SubChecks.pop_back(); - } - } - - Checks.insert(Checks.end(), SubChecks.begin(), SubChecks.end()); + Checks, Visited); return; } @@ -361,7 +382,8 @@ void InductiveRangeCheck::extractRangeChecksFromCond( return; Value *Length = nullptr, *Index; - auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length); + bool IsSigned; + auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned); if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) return; @@ -373,18 +395,18 @@ void InductiveRangeCheck::extractRangeChecksFromCond( return; InductiveRangeCheck IRC; - IRC.Length = Length; - IRC.Offset = IndexAddRec->getStart(); - IRC.Scale = IndexAddRec->getStepRecurrence(SE); + IRC.End = Length ? SE.getSCEV(Length) : nullptr; + IRC.Begin = IndexAddRec->getStart(); + IRC.Step = IndexAddRec->getStepRecurrence(SE); IRC.CheckUse = &ConditionUse; IRC.Kind = RCKind; + IRC.IsSigned = IsSigned; Checks.push_back(IRC); } void InductiveRangeCheck::extractRangeChecksFromBranch( BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI, SmallVectorImpl<InductiveRangeCheck> &Checks) { - if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) return; @@ -435,16 +457,16 @@ namespace { // kinds of loops we can deal with -- ones that have a single latch that is also // an exiting block *and* have a canonical induction variable. struct LoopStructure { - const char *Tag; + const char *Tag = ""; - BasicBlock *Header; - BasicBlock *Latch; + BasicBlock *Header = nullptr; + BasicBlock *Latch = nullptr; // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th // successor is `LatchExit', the exit block of the loop. 
- BranchInst *LatchBr; - BasicBlock *LatchExit; - unsigned LatchBrExitIdx; + BranchInst *LatchBr = nullptr; + BasicBlock *LatchExit = nullptr; + unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max(); // The loop represented by this instance of LoopStructure is semantically // equivalent to: @@ -452,18 +474,17 @@ struct LoopStructure { // intN_ty inc = IndVarIncreasing ? 1 : -1; // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT; // - // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarNext) + // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase) // ... body ... - Value *IndVarNext; - Value *IndVarStart; - Value *LoopExitAt; - bool IndVarIncreasing; + Value *IndVarBase = nullptr; + Value *IndVarStart = nullptr; + Value *IndVarStep = nullptr; + Value *LoopExitAt = nullptr; + bool IndVarIncreasing = false; + bool IsSignedPredicate = true; - LoopStructure() - : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr), - LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr), - IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {} + LoopStructure() = default; template <typename M> LoopStructure map(M Map) const { LoopStructure Result; @@ -473,10 +494,12 @@ struct LoopStructure { Result.LatchBr = cast<BranchInst>(Map(LatchBr)); Result.LatchExit = cast<BasicBlock>(Map(LatchExit)); Result.LatchBrExitIdx = LatchBrExitIdx; - Result.IndVarNext = Map(IndVarNext); + Result.IndVarBase = Map(IndVarBase); Result.IndVarStart = Map(IndVarStart); + Result.IndVarStep = Map(IndVarStep); Result.LoopExitAt = Map(LoopExitAt); Result.IndVarIncreasing = IndVarIncreasing; + Result.IsSignedPredicate = IsSignedPredicate; return Result; } @@ -494,7 +517,6 @@ struct LoopStructure { /// loops to run any remaining iterations. The pre loop runs any iterations in /// which the induction variable is < Begin, and the post loop runs any /// iterations in which the induction variable is >= End. 
-/// class LoopConstrainer { // The representation of a clone of the original loop we started out with. struct ClonedLoop { @@ -511,13 +533,12 @@ class LoopConstrainer { // Result of rewriting the range of a loop. See changeIterationSpaceEnd for // more details on what these fields mean. struct RewrittenRangeInfo { - BasicBlock *PseudoExit; - BasicBlock *ExitSelector; + BasicBlock *PseudoExit = nullptr; + BasicBlock *ExitSelector = nullptr; std::vector<PHINode *> PHIValuesAtPseudoExit; - PHINode *IndVarEnd; + PHINode *IndVarEnd = nullptr; - RewrittenRangeInfo() - : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {} + RewrittenRangeInfo() = default; }; // Calculated subranges we restrict the iteration space of the main loop to. @@ -541,14 +562,12 @@ class LoopConstrainer { // Compute a safe set of limits for the main loop to run in -- effectively the // intersection of `Range' and the iteration space of the original loop. // Return None if unable to compute the set of subranges. - // - Optional<SubRanges> calculateSubRanges() const; + Optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; // Clone `OriginalLoop' and return the result in CLResult. The IR after // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- // the PHI nodes say that there is an incoming edge from `OriginalPreheader` // but there is no such edge. - // void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; // Create the appropriate loop structure needed to describe a cloned copy of @@ -577,7 +596,6 @@ class LoopConstrainer { // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate // preheader because it is made to branch to the loop header only // conditionally. - // RewrittenRangeInfo changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, Value *ExitLoopAt, @@ -585,7 +603,6 @@ class LoopConstrainer { // The loop denoted by `LS' has `OldPreheader' as its preheader. 
This // function creates a new preheader for `LS' and returns it. - // BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, const char *Tag) const; @@ -613,12 +630,13 @@ class LoopConstrainer { // Information about the original loop we started out with. Loop &OriginalLoop; - const SCEV *LatchTakenCount; - BasicBlock *OriginalPreheader; + + const SCEV *LatchTakenCount = nullptr; + BasicBlock *OriginalPreheader = nullptr; // The preheader of the main loop. This may or may not be different from // `OriginalPreheader'. - BasicBlock *MainLoopPreheader; + BasicBlock *MainLoopPreheader = nullptr; // The range we need to run the main loop in. InductiveRangeCheck::Range Range; @@ -632,15 +650,14 @@ public: const LoopStructure &LS, ScalarEvolution &SE, DominatorTree &DT, InductiveRangeCheck::Range R) : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), - SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), - LatchTakenCount(nullptr), OriginalPreheader(nullptr), - MainLoopPreheader(nullptr), Range(R), MainLoopStructure(LS) {} + SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), Range(R), + MainLoopStructure(LS) {} // Entry point for the algorithm. Returns true on success. bool run(); }; -} +} // end anonymous namespace void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, BasicBlock *ReplaceBy) { @@ -649,22 +666,55 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, PN->setIncomingBlock(i, ReplaceBy); } -static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) { - APInt SMax = - APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(SMax) && - SE.getUnsignedRange(S).contains(SMax); +static bool CanBeMax(ScalarEvolution &SE, const SCEV *S, bool Signed) { + APInt Max = Signed ? 
+ APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()) : + APInt::getMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(Max) && + SE.getUnsignedRange(S).contains(Max); +} + +static bool SumCanReachMax(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, + bool Signed) { + // S1 < INT_MAX - S2 ===> S1 + S2 < INT_MAX. + assert(SE.isKnownNonNegative(S2) && + "We expected the 2nd arg to be non-negative!"); + const SCEV *Max = SE.getConstant( + Signed ? APInt::getSignedMaxValue( + cast<IntegerType>(S1->getType())->getBitWidth()) + : APInt::getMaxValue( + cast<IntegerType>(S1->getType())->getBitWidth())); + const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); + return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + S1, CapForS1); } -static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) { - APInt SMin = - APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(SMin) && - SE.getUnsignedRange(S).contains(SMin); +static bool CanBeMin(ScalarEvolution &SE, const SCEV *S, bool Signed) { + APInt Min = Signed ? + APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()) : + APInt::getMinValue(cast<IntegerType>(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(Min) && + SE.getUnsignedRange(S).contains(Min); +} + +static bool SumCanReachMin(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, + bool Signed) { + // S1 > INT_MIN - S2 ===> S1 + S2 > INT_MIN. + assert(SE.isKnownNonPositive(S2) && + "We expected the 2nd arg to be non-positive!"); + const SCEV *Max = SE.getConstant( + Signed ? APInt::getSignedMinValue( + cast<IntegerType>(S1->getType())->getBitWidth()) + : APInt::getMinValue( + cast<IntegerType>(S1->getType())->getBitWidth())); + const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); + return !SE.isKnownPredicate(Signed ? 
ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, + S1, CapForS1); } Optional<LoopStructure> -LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI, +LoopStructure::parseLoopStructure(ScalarEvolution &SE, + BranchProbabilityInfo &BPI, Loop &L, const char *&FailureReason) { if (!L.isLoopSimplifyForm()) { FailureReason = "loop not in LoopSimplify form"; @@ -766,7 +816,11 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; }; - auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) { + // Here we check whether the suggested AddRec is an induction variable that + // can be handled (i.e. with known constant step), and if yes, calculate its + // step and identify whether it is increasing or decreasing. + auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing, + ConstantInt *&StepCI) { if (!AR->isAffine()) return false; @@ -778,11 +832,10 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP if (const SCEVConstant *StepExpr = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { - ConstantInt *StepCI = StepExpr->getValue(); - if (StepCI->isOne() || StepCI->isMinusOne()) { - IsIncreasing = StepCI->isOne(); - return true; - } + StepCI = StepExpr->getValue(); + assert(!StepCI->isZero() && "Zero step?"); + IsIncreasing = !StepCI->isNegative(); + return true; } return false; @@ -791,59 +844,87 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP // `ICI` is interpreted as taking the backedge if the *next* value of the // induction variable satisfies some constraint. 
- const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV); + const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); bool IsIncreasing = false; - if (!IsInductionVar(IndVarNext, IsIncreasing)) { + bool IsSignedPredicate = true; + ConstantInt *StepCI; + if (!IsInductionVar(IndVarBase, IsIncreasing, StepCI)) { FailureReason = "LHS in icmp not induction variable"; return None; } - const SCEV *StartNext = IndVarNext->getStart(); - const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); + const SCEV *StartNext = IndVarBase->getStart(); + const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); + const SCEV *Step = SE.getSCEV(StepCI); ConstantInt *One = ConstantInt::get(IndVarTy, 1); - // TODO: generalize the predicates here to also match their unsigned variants. if (IsIncreasing) { bool DecreasedRightValueByOne = false; - // Try to turn eq/ne predicates to those we can work with. - if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) - // while (++i != len) { while (++i < len) { - // ... ---> ... - // } } - Pred = ICmpInst::ICMP_SLT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeSMin(SE, RightSCEV)) { - // while (true) { while (true) { - // if (++i == len) ---> if (++i > len - 1) - // break; break; - // ... ... - // } } - Pred = ICmpInst::ICMP_SGT; - RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); - DecreasedRightValueByOne = true; + if (StepCI->isOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (++i != len) { while (++i < len) { + // ... ---> ... + // } } + // If both parts are known non-negative, it is profitable to use + // unsigned comparison in increasing loop. This allows us to make the + // comparison check against "RightSCEV + 1" more optimistic. 
+ if (SE.isKnownNonNegative(IndVarStart) && + SE.isKnownNonNegative(RightSCEV)) + Pred = ICmpInst::ICMP_ULT; + else + Pred = ICmpInst::ICMP_SLT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeMin(SE, RightSCEV, /* IsSignedPredicate */ true)) { + // while (true) { while (true) { + // if (++i == len) ---> if (++i > len - 1) + // break; break; + // ... ... + // } } + // TODO: Insert ICMP_UGT if both are non-negative? + Pred = ICmpInst::ICMP_SGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } } + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); bool FoundExpectedPred = - (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) || - (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0); + (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0); if (!FoundExpectedPred) { FailureReason = "expected icmp slt semantically, found something else"; return None; } + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return None; + } + + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSignedPredicate ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + if (LatchBrExitIdx == 0) { - if (CanBeSMax(SE, RightSCEV)) { + const SCEV *StepMinusOne = SE.getMinusSCEV(Step, + SE.getOne(Step->getType())); + if (SumCanReachMax(SE, RightSCEV, StepMinusOne, IsSignedPredicate)) { // TODO: this restriction is easily removable -- we just have to // remember that the icmp was an slt and not an sle. 
- FailureReason = "limit may overflow when coercing sle to slt"; + FailureReason = "limit may overflow when coercing le to lt"; return None; } if (!SE.isLoopEntryGuardedByCond( - &L, CmpInst::ICMP_SLT, IndVarStart, - SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())))) { + &L, BoundPred, IndVarStart, + SE.getAddExpr(RightSCEV, Step))) { FailureReason = "Induction variable start not bounded by upper limit"; return None; } @@ -855,8 +936,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP RightValue = B.CreateAdd(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SLT, IndVarStart, - RightSCEV)) { + if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { FailureReason = "Induction variable start not bounded by upper limit"; return None; } @@ -865,43 +945,65 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP } } else { bool IncreasedRightValueByOne = false; - // Try to turn eq/ne predicates to those we can work with. - if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) - // while (--i != len) { while (--i > len) { - // ... ---> ... - // } } - Pred = ICmpInst::ICMP_SGT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeSMax(SE, RightSCEV)) { - // while (true) { while (true) { - // if (--i == len) ---> if (--i < len + 1) - // break; break; - // ... ... - // } } - Pred = ICmpInst::ICMP_SLT; - RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - IncreasedRightValueByOne = true; + if (StepCI->isMinusOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (--i != len) { while (--i > len) { + // ... ---> ... + // } } + // We intentionally don't turn the predicate into UGT even if we know + // that both operands are non-negative, because it will only pessimize + // our check against "RightSCEV - 1". 
+ Pred = ICmpInst::ICMP_SGT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeMax(SE, RightSCEV, /* IsSignedPredicate */ true)) { + // while (true) { while (true) { + // if (--i == len) ---> if (--i < len + 1) + // break; break; + // ... ... + // } } + // TODO: Insert ICMP_ULT if both are non-negative? + Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } } + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); + bool FoundExpectedPred = - (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || - (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0); + (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0); if (!FoundExpectedPred) { FailureReason = "expected icmp sgt semantically, found something else"; return None; } + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return None; + } + + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSignedPredicate ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + if (LatchBrExitIdx == 0) { - if (CanBeSMin(SE, RightSCEV)) { + const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); + if (SumCanReachMin(SE, RightSCEV, StepPlusOne, IsSignedPredicate)) { // TODO: this restriction is easily removable -- we just have to // remember that the icmp was an sgt and not an sge. 
- FailureReason = "limit may overflow when coercing sge to sgt"; + FailureReason = "limit may overflow when coercing ge to gt"; return None; } if (!SE.isLoopEntryGuardedByCond( - &L, CmpInst::ICMP_SGT, IndVarStart, + &L, BoundPred, IndVarStart, SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { FailureReason = "Induction variable start not bounded by lower limit"; return None; @@ -914,8 +1016,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP RightValue = B.CreateSub(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SGT, IndVarStart, - RightSCEV)) { + if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { FailureReason = "Induction variable start not bounded by lower limit"; return None; } @@ -923,7 +1024,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP "Right value can be increased only for LatchBrExitIdx == 0!"); } } - BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); assert(SE.getLoopDisposition(LatchCount, &L) == @@ -946,9 +1046,11 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP Result.LatchExit = LatchExit; Result.LatchBrExitIdx = LatchBrExitIdx; Result.IndVarStart = IndVarStartV; - Result.IndVarNext = LeftValue; + Result.IndVarStep = StepCI; + Result.IndVarBase = LeftValue; Result.IndVarIncreasing = IsIncreasing; Result.LoopExitAt = RightValue; + Result.IsSignedPredicate = IsSignedPredicate; FailureReason = nullptr; @@ -956,7 +1058,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP } Optional<LoopConstrainer::SubRanges> -LoopConstrainer::calculateSubRanges() const { +LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); if (Range.getType() != Ty) @@ -999,26 +1101,31 @@ LoopConstrainer::calculateSubRanges() const { // that case, `Clamp` will always return `Smallest` 
and // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`) // will be an empty range. Returning an empty range is always safe. - // Smallest = SE.getAddExpr(End, One); Greatest = SE.getAddExpr(Start, One); GreatestSeen = Start; } - auto Clamp = [this, Smallest, Greatest](const SCEV *S) { - return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)); + auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) { + return IsSignedPredicate + ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)) + : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S)); }; - // In some cases we can prove that we don't need a pre or post loop + // In some cases we can prove that we don't need a pre or post loop. + ICmpInst::Predicate PredLE = + IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; + ICmpInst::Predicate PredLT = + IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; bool ProvablyNoPreloop = - SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest); + SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest); if (!ProvablyNoPreloop) Result.LowLimit = Clamp(Range.getBegin()); bool ProvablyNoPostLoop = - SE.isKnownPredicate(ICmpInst::ICMP_SLT, GreatestSeen, Range.getEnd()); + SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd()); if (!ProvablyNoPostLoop) Result.HighLimit = Clamp(Range.getEnd()); @@ -1082,7 +1189,6 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, BasicBlock *ContinuationBlock) const { - // We start with a loop with a single latch: // // +--------------------+ @@ -1153,7 +1259,6 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // | original exit <----+ // | | // +--------------------+ - // RewrittenRangeInfo RRI; @@ -1165,22 +1270,35 @@ LoopConstrainer::RewrittenRangeInfo 
LoopConstrainer::changeIterationSpaceEnd( BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); bool Increasing = LS.IndVarIncreasing; + bool IsSignedPredicate = LS.IsSignedPredicate; IRBuilder<> B(PreheaderJump); // EnterLoopCond - is it okay to start executing this `LS'? - Value *EnterLoopCond = Increasing - ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) - : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt); + Value *EnterLoopCond = nullptr; + if (Increasing) + EnterLoopCond = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpULT(LS.IndVarStart, ExitSubloopAt); + else + EnterLoopCond = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpUGT(LS.IndVarStart, ExitSubloopAt); B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); PreheaderJump->eraseFromParent(); LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); - Value *TakeBackedgeLoopCond = - Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt) - : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt); + Value *TakeBackedgeLoopCond = nullptr; + if (Increasing) + TakeBackedgeLoopCond = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarBase, ExitSubloopAt) + : B.CreateICmpULT(LS.IndVarBase, ExitSubloopAt); + else + TakeBackedgeLoopCond = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarBase, ExitSubloopAt) + : B.CreateICmpUGT(LS.IndVarBase, ExitSubloopAt); Value *CondForBranch = LS.LatchBrExitIdx == 1 ? TakeBackedgeLoopCond : B.CreateNot(TakeBackedgeLoopCond); @@ -1192,9 +1310,15 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. - Value *IterationsLeft = Increasing - ? 
B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt) - : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt); + Value *IterationsLeft = nullptr; + if (Increasing) + IterationsLeft = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarBase, LS.LoopExitAt) + : B.CreateICmpULT(LS.IndVarBase, LS.LoopExitAt); + else + IterationsLeft = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarBase, LS.LoopExitAt) + : B.CreateICmpUGT(LS.IndVarBase, LS.LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); BranchInst *BranchToContinuation = @@ -1217,10 +1341,10 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RRI.PHIValuesAtPseudoExit.push_back(NewPHI); } - RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end", + RRI.IndVarEnd = PHINode::Create(LS.IndVarBase->getType(), 2, "indvar.end", BranchToContinuation); RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); - RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector); + RRI.IndVarEnd->addIncoming(LS.IndVarBase, RRI.ExitSelector); // The latch exit now has a branch from `RRI.ExitSelector' instead of // `LS.Latch'. The PHI nodes need to be updated to reflect that. 
@@ -1237,7 +1361,6 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( void LoopConstrainer::rewriteIncomingValuesForPHIs( LoopStructure &LS, BasicBlock *ContinuationBlock, const LoopConstrainer::RewrittenRangeInfo &RRI) const { - unsigned PHIIndex = 0; for (Instruction &I : *LS.Header) { auto *PN = dyn_cast<PHINode>(&I); @@ -1255,7 +1378,6 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, const char *Tag) const { - BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); BranchInst::Create(LS.Header, Preheader); @@ -1282,7 +1404,7 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, ValueToValueMapTy &VM) { - Loop &New = *new Loop(); + Loop &New = *LI.AllocateLoop(); if (Parent) Parent->addChildLoop(&New); else @@ -1311,7 +1433,8 @@ bool LoopConstrainer::run() { OriginalPreheader = Preheader; MainLoopPreheader = Preheader; - Optional<SubRanges> MaybeSR = calculateSubRanges(); + bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; + Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); if (!MaybeSR.hasValue()) { DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; @@ -1320,7 +1443,7 @@ bool LoopConstrainer::run() { SubRanges SR = MaybeSR.getValue(); bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = - cast<IntegerType>(MainLoopStructure.IndVarNext->getType()); + cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); Instruction *InsertPt = OriginalPreheader->getTerminator(); @@ -1345,7 +1468,7 @@ bool LoopConstrainer::run() { if (Increasing) ExitPreLoopAtSCEV = *SR.LowLimit; else { - if (CanBeSMin(SE, *SR.HighLimit)) { + if (CanBeMin(SE, *SR.HighLimit, IsSignedPredicate)) { DEBUG(dbgs() << 
"irce: could not prove no-overflow when computing " << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) << "\n"); @@ -1354,6 +1477,13 @@ bool LoopConstrainer::run() { ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); } + if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { + DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " preloop exit limit " << *ExitPreLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() << "\n"); + return false; + } + ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); ExitPreLoopAt->setName("exit.preloop.at"); } @@ -1364,7 +1494,7 @@ bool LoopConstrainer::run() { if (Increasing) ExitMainLoopAtSCEV = *SR.HighLimit; else { - if (CanBeSMin(SE, *SR.LowLimit)) { + if (CanBeMin(SE, *SR.LowLimit, IsSignedPredicate)) { DEBUG(dbgs() << "irce: could not prove no-overflow when computing " << "mainloop exit limit. LowLimit = " << *(*SR.LowLimit) << "\n"); @@ -1373,6 +1503,13 @@ bool LoopConstrainer::run() { ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); } + if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { + DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " main loop exit limit " << *ExitMainLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() << "\n"); + return false; + } + ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); ExitMainLoopAt->setName("exit.mainloop.at"); } @@ -1463,34 +1600,27 @@ bool LoopConstrainer::run() { /// range, returns None. 
Optional<InductiveRangeCheck::Range> InductiveRangeCheck::computeSafeIterationSpace( - ScalarEvolution &SE, const SCEVAddRecExpr *IndVar) const { + ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const { // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is - // getOffset() and "D" is getScale()). We rewrite the value being range + // getBegin() and "D" is getStep()). We rewrite the value being range // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". - // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code - // can be generalized as needed. // // The actual inequalities we solve are of the form // // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. N == 1) // - // The inequality is satisfied by -M <= IndVar < (L - M) [^1]. All additions - // and subtractions are twos-complement wrapping and comparisons are signed. - // - // Proof: - // - // If there exists IndVar such that -M <= IndVar < (L - M) then it follows - // that -M <= (-M + L) [== Eq. 1]. Since L >= 0, if (-M + L) sign-overflows - // then (-M + L) < (-M). Hence by [Eq. 1], (-M + L) could not have - // overflown. - // - // This means IndVar = t + (-M) for t in [0, L). Hence (IndVar + M) = t. - // Hence 0 <= (IndVar + M) < L - - // [^1]: Note that the solution does _not_ apply if L < 0; consider values M = - // 127, IndVar = 126 and L = -2 in an i8 world. + // Here L stands for upper limit of the safe iteration space. + // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid + // overflows when calculating (0 - M) and (L - M) we, depending on type of + // IV's iteration space, limit the calculations by borders of the iteration + // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0. 
+ // If we figured out that "anything greater than (-M) is safe", we strengthen + // this to "everything greater than 0 is safe", assuming that values between + // -M and 0 just do not exist in unsigned iteration space, and we don't want + // to deal with overflown values. if (!IndVar->isAffine()) return None; @@ -1499,42 +1629,89 @@ InductiveRangeCheck::computeSafeIterationSpace( const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); if (!B) return None; + assert(!B->isZero() && "Recurrence with zero step?"); - const SCEV *C = getOffset(); - const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale()); + const SCEV *C = getBegin(); + const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep()); if (D != B) return None; - ConstantInt *ConstD = D->getValue(); - if (!(ConstD->isMinusOne() || ConstD->isOne())) - return None; + assert(!D->getValue()->isZero() && "Recurrence with zero step?"); + unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); + const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + // Subtract Y from X so that it does not go through border of the IV + // iteration space. Mathematically, it is equivalent to: + // + // ClampedSubstract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] + // + // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to + // any width of bit grid). But after we take min/max, the result is + // guaranteed to be within [INT_MIN, INT_MAX]. + // + // In [1], INT_MAX and INT_MIN are either the signed or the unsigned max/min + // values, depending on type of latch condition that defines IV iteration + // space. + auto ClampedSubstract = [&](const SCEV *X, const SCEV *Y) { + assert(SE.isKnownNonNegative(X) && + "We can only substract from values in [0; SINT_MAX]!"); + if (IsLatchSigned) { + // X is a number from signed range, Y is interpreted as signed. + // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. 
So the only + // thing we should care about is that we didn't cross SINT_MAX. + // So, if Y is positive, we subtract Y safely. + // Rule 1: Y > 0 ---> Y. + // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely. + // Rule 2: Y >=s (X - SINT_MAX) ---> Y. + // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX). + // Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX). + // It gives us smax(Y, X - SINT_MAX) to subtract in all cases. + const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax); + return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax), + SCEV::FlagNSW); + } else + // X is a number from unsigned range, Y is interpreted as signed. + // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only + // thing we should care about is that we didn't cross zero. + // So, if Y is negative, we subtract Y safely. + // Rule 1: Y <s 0 ---> Y. + // If 0 <= Y <= X, we subtract Y safely. + // Rule 2: Y <=s X ---> Y. + // If 0 <= X < Y, we should stop at 0 and can only subtract X. + // Rule 3: Y >s X ---> X. + // It gives us smin(X, Y) to subtract in all cases. + return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW); + }; const SCEV *M = SE.getMinusSCEV(C, A); - - const SCEV *Begin = SE.getNegativeSCEV(M); - const SCEV *UpperLimit = nullptr; + const SCEV *Zero = SE.getZero(M->getType()); + const SCEV *Begin = ClampedSubstract(Zero, M); + const SCEV *L = nullptr; // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". // We can potentially do much better here. 
- if (Value *V = getLength()) { - UpperLimit = SE.getSCEV(V); - } else { + if (const SCEV *EndLimit = getEnd()) + L = EndLimit; + else { assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); - unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); - UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + L = SIntMax; } - - const SCEV *End = SE.getMinusSCEV(UpperLimit, M); + const SCEV *End = ClampedSubstract(L, M); return InductiveRangeCheck::Range(Begin, End); } static Optional<InductiveRangeCheck::Range> -IntersectRange(ScalarEvolution &SE, - const Optional<InductiveRangeCheck::Range> &R1, - const InductiveRangeCheck::Range &R2) { +IntersectSignedRange(ScalarEvolution &SE, + const Optional<InductiveRangeCheck::Range> &R1, + const InductiveRangeCheck::Range &R2) { + if (R2.isEmpty(SE, /* IsSigned */ true)) + return None; if (!R1.hasValue()) return R2; auto &R1Value = R1.getValue(); + // We never return empty ranges from this function, and R1 is supposed to be + // a result of intersection. Thus, R1 is never empty. + assert(!R1Value.isEmpty(SE, /* IsSigned */ true) && + "We should never have empty R1!"); // TODO: we could widen the smaller range and have this work; but for now we // bail out to keep things simple. @@ -1544,7 +1721,40 @@ IntersectRange(ScalarEvolution &SE, const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); - return InductiveRangeCheck::Range(NewBegin, NewEnd); + // If the resulting range is empty, just return None. 
+ auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); + if (Ret.isEmpty(SE, /* IsSigned */ true)) + return None; + return Ret; +} + +static Optional<InductiveRangeCheck::Range> +IntersectUnsignedRange(ScalarEvolution &SE, + const Optional<InductiveRangeCheck::Range> &R1, + const InductiveRangeCheck::Range &R2) { + if (R2.isEmpty(SE, /* IsSigned */ false)) + return None; + if (!R1.hasValue()) + return R2; + auto &R1Value = R1.getValue(); + // We never return empty ranges from this function, and R1 is supposed to be + // a result of intersection. Thus, R1 is never empty. + assert(!R1Value.isEmpty(SE, /* IsSigned */ false) && + "We should never have empty R1!"); + + // TODO: we could widen the smaller range and have this work; but for now we + // bail out to keep things simple. + if (R1Value.getType() != R2.getType()) + return None; + + const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin()); + const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd()); + + // If the resulting range is empty, just return None. + auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); + if (Ret.isEmpty(SE, /* IsSigned */ false)) + return None; + return Ret; } bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { @@ -1598,24 +1808,31 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { return false; } LoopStructure LS = MaybeLoopStructure.getValue(); - bool Increasing = LS.IndVarIncreasing; - const SCEV *MinusOne = - SE.getConstant(LS.IndVarNext->getType(), Increasing ? 
-1 : 1, true); const SCEVAddRecExpr *IndVar = - cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne)); + cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep))); Optional<InductiveRangeCheck::Range> SafeIterRange; Instruction *ExprInsertPt = Preheader->getTerminator(); SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; + // Based on the type of the latch predicate, we interpret the IV iteration range + // as signed or unsigned range. We use different min/max functions (signed or + // unsigned) when intersecting this range with safe iteration ranges implied + // by range checks. + auto IntersectRange = + LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange; IRBuilder<> B(ExprInsertPt); for (InductiveRangeCheck &IRC : RangeChecks) { - auto Result = IRC.computeSafeIterationSpace(SE, IndVar); + auto Result = IRC.computeSafeIterationSpace(SE, IndVar, + LS.IsSignedPredicate); if (Result.hasValue()) { auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, Result.getValue()); if (MaybeSafeIterRange.hasValue()) { + assert( + !MaybeSafeIterRange.getValue().isEmpty(SE, LS.IsSignedPredicate) && + "We should never return empty ranges!"); RangeChecksToEliminate.push_back(IRC); SafeIterRange = MaybeSafeIterRange.getValue(); } |