diff options
Diffstat (limited to 'lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp')
-rw-r--r-- | lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp | 101 |
1 files changed, 59 insertions, 42 deletions
diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 1c701bbee185..997d68838152 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -1,9 +1,8 @@ //===- InductiveRangeCheckElimination.cpp - -------------------------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -116,6 +115,11 @@ static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch", cl::Hidden, cl::init(true)); +static cl::opt<bool> AllowNarrowLatchCondition( + "irce-allow-narrow-latch", cl::Hidden, cl::init(true), + cl::desc("If set to true, IRCE may eliminate wide range checks in loops " + "with narrow latch condition.")); + static const char *ClonedLoopTag = "irce.loop.clone"; #define DEBUG_TYPE "irce" @@ -532,12 +536,6 @@ class LoopConstrainer { Optional<const SCEV *> HighLimit; }; - // A utility function that does a `replaceUsesOfWith' on the incoming block - // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's - // incoming block list with `ReplaceBy'. - static void replacePHIBlock(PHINode *PN, BasicBlock *Block, - BasicBlock *ReplaceBy); - // Compute a safe set of limits for the main loop to run in -- effectively the // intersection of `Range' and the iteration space of the original loop. // Return None if unable to compute the set of subranges. @@ -639,13 +637,6 @@ public: } // end anonymous namespace -void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, - BasicBlock *ReplaceBy) { - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (PN->getIncomingBlock(i) == Block) - PN->setIncomingBlock(i, ReplaceBy); -} - /// Given a loop with an deccreasing induction variable, is it possible to /// safely calculate the bounds of a new loop using the given Predicate. static bool isSafeDecreasingBound(const SCEV *Start, @@ -868,7 +859,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, assert(!StepCI->isZero() && "Zero step?"); bool IsIncreasing = !StepCI->isNegative(); - bool IsSignedPredicate = ICmpInst::isSigned(Pred); + bool IsSignedPredicate; const SCEV *StartNext = IndVarBase->getStart(); const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); @@ -1045,11 +1036,23 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, return Result; } +/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return +/// signed or unsigned extension of \p S to type \p Ty. +static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE, + bool Signed) { + return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty); +} + Optional<LoopConstrainer::SubRanges> LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); - if (Range.getType() != Ty) + auto *RTy = cast<IntegerType>(Range.getType()); + + // We only support wide range checks and narrow latches. + if (!AllowNarrowLatchCondition && RTy != Ty) + return None; + if (RTy->getBitWidth() < Ty->getBitWidth()) return None; LoopConstrainer::SubRanges Result; @@ -1057,8 +1060,10 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { // I think we can be more aggressive here and make this nuw / nsw if the // addition that feeds into the icmp for the latch's terminating branch is nuw // / nsw. In any case, a wrapping 2's complement addition is safe. - const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart); - const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt); + const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart), + RTy, SE, IsSignedPredicate); + const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy, + SE, IsSignedPredicate); bool Increasing = MainLoopStructure.IndVarIncreasing; @@ -1068,7 +1073,7 @@ LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr; - const SCEV *One = SE.getOne(Ty); + const SCEV *One = SE.getOne(RTy); if (Increasing) { Smallest = Start; Greatest = End; @@ -1257,6 +1262,13 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( bool IsSignedPredicate = LS.IsSignedPredicate; IRBuilder<> B(PreheaderJump); + auto *RangeTy = Range.getBegin()->getType(); + auto NoopOrExt = [&](Value *V) { + if (V->getType() == RangeTy) + return V; + return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName()) + : B.CreateZExt(V, RangeTy, "wide." + V->getName()); + }; // EnterLoopCond - is it okay to start executing this `LS'? Value *EnterLoopCond = nullptr; @@ -1264,15 +1276,16 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( Increasing ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); - EnterLoopCond = B.CreateICmp(Pred, LS.IndVarStart, ExitSubloopAt); + Value *IndVarStart = NoopOrExt(LS.IndVarStart); + EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt); B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); PreheaderJump->eraseFromParent(); LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); - Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, LS.IndVarBase, - ExitSubloopAt); + Value *IndVarBase = NoopOrExt(LS.IndVarBase); + Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt); Value *CondForBranch = LS.LatchBrExitIdx == 1 ? TakeBackedgeLoopCond @@ -1285,7 +1298,8 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. - Value *IterationsLeft = B.CreateICmp(Pred, LS.IndVarBase, LS.LoopExitAt); + Value *LoopExitAt = NoopOrExt(LS.LoopExitAt); + Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); BranchInst *BranchToContinuation = @@ -1304,15 +1318,14 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RRI.PHIValuesAtPseudoExit.push_back(NewPHI); } - RRI.IndVarEnd = PHINode::Create(LS.IndVarBase->getType(), 2, "indvar.end", + RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end", BranchToContinuation); - RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); - RRI.IndVarEnd->addIncoming(LS.IndVarBase, RRI.ExitSelector); + RRI.IndVarEnd->addIncoming(IndVarStart, Preheader); + RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector); // The latch exit now has a branch from `RRI.ExitSelector' instead of // `LS.Latch'. The PHI nodes need to be updated to reflect that. - for (PHINode &PN : LS.LatchExit->phis()) - replacePHIBlock(&PN, LS.Latch, RRI.ExitSelector); + LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector); return RRI; } @@ -1322,9 +1335,8 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( const LoopConstrainer::RewrittenRangeInfo &RRI) const { unsigned PHIIndex = 0; for (PHINode &PN : LS.Header->phis()) - for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) - if (PN.getIncomingBlock(i) == ContinuationBlock) - PN.setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); + PN.setIncomingValueForBlock(ContinuationBlock, + RRI.PHIValuesAtPseudoExit[PHIIndex++]); LS.IndVarStart = RRI.IndVarEnd; } @@ -1335,9 +1347,7 @@ BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); BranchInst::Create(LS.Header, Preheader); - for (PHINode &PN : LS.Header->phis()) - for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) - replacePHIBlock(&PN, OldPreheader, Preheader); + LS.Header->replacePhiUsesWith(OldPreheader, Preheader); return Preheader; } @@ -1393,7 +1403,7 @@ bool LoopConstrainer::run() { SubRanges SR = MaybeSR.getValue(); bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = - cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); + cast<IntegerType>(Range.getBegin()->getType()); SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); Instruction *InsertPt = OriginalPreheader->getTerminator(); @@ -1534,7 +1544,7 @@ bool LoopConstrainer::run() { // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) { formLCSSARecursively(*L, DT, &LI, &SE); - simplifyLoop(L, &DT, &LI, &SE, nullptr, true); + simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true); // Pre/post loops are slow paths, we do not need to perform any loop // optimizations on them. if (!IsOriginalLoop) @@ -1556,6 +1566,12 @@ Optional<InductiveRangeCheck::Range> InductiveRangeCheck::computeSafeIterationSpace( ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, bool IsLatchSigned) const { + // We can deal when types of latch check and range checks don't match in case + // if latch check is more narrow. + auto *IVType = cast<IntegerType>(IndVar->getType()); + auto *RCType = cast<IntegerType>(getBegin()->getType()); + if (IVType->getBitWidth() > RCType->getBitWidth()) + return None; // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is @@ -1579,8 +1595,9 @@ InductiveRangeCheck::computeSafeIterationSpace( if (!IndVar->isAffine()) return None; - const SCEV *A = IndVar->getStart(); - const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); + const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned); + const SCEVConstant *B = dyn_cast<SCEVConstant>( + NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned)); if (!B) return None; assert(!B->isZero() && "Recurrence with zero step?"); @@ -1591,7 +1608,7 @@ InductiveRangeCheck::computeSafeIterationSpace( return None; assert(!D->getValue()->isZero() && "Recurrence with zero step?"); - unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); + unsigned BitWidth = RCType->getBitWidth(); const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); // Subtract Y from X so that it does not go through border of the IV |