diff options
Diffstat (limited to 'lib/Transforms/Scalar/LoopUnswitch.cpp')
-rw-r--r-- | lib/Transforms/Scalar/LoopUnswitch.cpp | 301 |
1 files changed, 211 insertions, 90 deletions
diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index 76fe91884c7b8..a99c9999c6191 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -33,6 +33,7 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -47,6 +48,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/MDBuilder.h" #include "llvm/Support/CommandLine.h" @@ -77,19 +79,6 @@ static cl::opt<unsigned> Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), cl::init(100), cl::Hidden); -static cl::opt<bool> -LoopUnswitchWithBlockFrequency("loop-unswitch-with-block-frequency", - cl::init(false), cl::Hidden, - cl::desc("Enable the use of the block frequency analysis to access PGO " - "heuristics to minimize code growth in cold regions.")); - -static cl::opt<unsigned> -ColdnessThreshold("loop-unswitch-coldness-threshold", cl::init(1), cl::Hidden, - cl::desc("Coldness threshold in percentage. The loop header frequency " - "(relative to the entry frequency) is compared with this " - "threshold to determine if non-trivial unswitching should be " - "enabled.")); - namespace { class LUAnalysisCache { @@ -174,13 +163,6 @@ namespace { LUAnalysisCache BranchesInfo; - bool EnabledPGO; - - // BFI and ColdEntryFreq are only used when PGO and - // LoopUnswitchWithBlockFrequency are enabled. - BlockFrequencyInfo BFI; - BlockFrequency ColdEntryFreq; - bool OptimizeForSize; bool redoLoop; @@ -199,12 +181,14 @@ namespace { // NewBlocks contained cloned copy of basic blocks from LoopBlocks. std::vector<BasicBlock*> NewBlocks; + bool hasBranchDivergence; + public: static char ID; // Pass ID, replacement for typeid - explicit LoopUnswitch(bool Os = false) : + explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) : LoopPass(ID), OptimizeForSize(Os), redoLoop(false), currentLoop(nullptr), DT(nullptr), loopHeader(nullptr), - loopPreheader(nullptr) { + loopPreheader(nullptr), hasBranchDivergence(hasBranchDivergence) { initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); } @@ -217,6 +201,8 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + if (hasBranchDivergence) + AU.addRequired<DivergenceAnalysis>(); getLoopAnalysisUsage(AU); } @@ -255,6 +241,11 @@ namespace { TerminatorInst *TI); void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); + + /// Given that the Invariant is not equal to Val. Simplify instructions + /// in the loop. + Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, + Constant *Val); }; } @@ -381,16 +372,35 @@ INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) -Pass *llvm::createLoopUnswitchPass(bool Os) { - return new LoopUnswitch(Os); +Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) { + return new LoopUnswitch(Os, hasBranchDivergence); } +/// Operator chain lattice. +enum OperatorChain { + OC_OpChainNone, ///< There is no operator. + OC_OpChainOr, ///< There are only ORs. + OC_OpChainAnd, ///< There are only ANDs. + OC_OpChainMixed ///< There are ANDs and ORs. +}; + /// Cond is a condition that occurs in L. If it is invariant in the loop, or has /// an invariant piece, return the invariant. Otherwise, return null. +// +/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a +/// mixed operator chain, as we can not reliably find a value which will simplify +/// the operator chain. If the chain is AND-only or OR-only, we can use 0 or ~0 +/// to simplify the chain. +/// +/// NOTE: In case a partial LIV and a mixed operator chain, we may be able to +/// simplify the condition itself to a loop variant condition, but at the +/// cost of creating an entirely new loop. static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, + OperatorChain &ParentChain, DenseMap<Value *, Value *> &Cache) { auto CacheIt = Cache.find(Cond); if (CacheIt != Cache.end()) @@ -414,21 +424,53 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return Cond; } + // Walk up the operator chain to find partial invariant conditions. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond)) if (BO->getOpcode() == Instruction::And || BO->getOpcode() == Instruction::Or) { - // If either the left or right side is invariant, we can unswitch on this, - // which will cause the branch to go away in one loop and the condition to - // simplify in the other one. - if (Value *LHS = - FindLIVLoopCondition(BO->getOperand(0), L, Changed, Cache)) { - Cache[Cond] = LHS; - return LHS; + // Given the previous operator, compute the current operator chain status. + OperatorChain NewChain; + switch (ParentChain) { + case OC_OpChainNone: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainOr; + break; + case OC_OpChainOr: + NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr : + OC_OpChainMixed; + break; + case OC_OpChainAnd: + NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd : + OC_OpChainMixed; + break; + case OC_OpChainMixed: + NewChain = OC_OpChainMixed; + break; } - if (Value *RHS = - FindLIVLoopCondition(BO->getOperand(1), L, Changed, Cache)) { - Cache[Cond] = RHS; - return RHS; + + // If we reach a Mixed state, we do not want to keep walking up as we can not + // reliably find a value that will simplify the chain. With this check, we + // will return null on the first sight of mixed chain and the caller will + // either backtrack to find partial LIV in other operand or return null. + if (NewChain != OC_OpChainMixed) { + // Update the current operator chain type before we search up the chain. + ParentChain = NewChain; + // If either the left or right side is invariant, we can unswitch on this, + // which will cause the branch to go away in one loop and the condition to + // simplify in the other one. + if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = LHS; + return LHS; + } + // We did not manage to find a partial LIV in operand(0). Backtrack and try + // operand(1). + ParentChain = NewChain; + if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed, + ParentChain, Cache)) { + Cache[Cond] = RHS; + return RHS; + } } } @@ -436,9 +478,21 @@ static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed, return nullptr; } -static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed) { +/// Cond is a condition that occurs in L. If it is invariant in the loop, or has +/// an invariant piece, return the invariant along with the operator chain type. +/// Otherwise, return null. +static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond, + Loop *L, + bool &Changed) { DenseMap<Value *, Value *> Cache; - return FindLIVLoopCondition(Cond, L, Changed, Cache); + OperatorChain OpChain = OC_OpChainNone; + Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache); + + // In case we do find a LIV, it can not be obtained by walking up a mixed + // operator chain. + assert((!FCond || OpChain != OC_OpChainMixed) && + "Do not expect a partial LIV with mixed operator chain"); + return {FCond, OpChain}; } bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { @@ -457,19 +511,6 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { if (SanitizeMemory) computeLoopSafetyInfo(&SafetyInfo, L); - EnabledPGO = F->getEntryCount().hasValue(); - - if (LoopUnswitchWithBlockFrequency && EnabledPGO) { - BranchProbabilityInfo BPI(*F, *LI); - BFI.calculate(*L->getHeader()->getParent(), BPI, *LI); - - // Use BranchProbability to compute a minimum frequency based on - // function entry baseline frequency. Loops with headers below this - // frequency are considered as cold. - const BranchProbability ColdProb(ColdnessThreshold, 100); - ColdEntryFreq = BlockFrequency(BFI.getEntryFreq()) * ColdProb; - } - bool Changed = false; do { assert(currentLoop->isLCSSAForm(*DT)); @@ -581,19 +622,9 @@ bool LoopUnswitch::processCurrentLoop() { loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize)) return false; - if (LoopUnswitchWithBlockFrequency && EnabledPGO) { - // Compute the weighted frequency of the hottest block in the - // loop (loopHeader in this case since inner loops should be - // processed before outer loop). If it is less than ColdFrequency, - // we should not unswitch. - BlockFrequency LoopEntryFreq = BFI.getBlockFreq(loopHeader); - if (LoopEntryFreq < ColdEntryFreq) - return false; - } - for (IntrinsicInst *Guard : Guards) { Value *LoopCond = - FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed); + FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { // NB! Unswitching (if successful) could have erased some of the @@ -634,7 +665,7 @@ bool LoopUnswitch::processCurrentLoop() { // See if this, or some part of it, is loop invariant. If so, we can // unswitch on it if we desire. Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; @@ -642,24 +673,48 @@ bool LoopUnswitch::processCurrentLoop() { } } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { - Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + Value *SC = SI->getCondition(); + Value *LoopCond; + OperatorChain OpChain; + std::tie(LoopCond, OpChain) = + FindLIVLoopCondition(SC, currentLoop, Changed); + unsigned NumCases = SI->getNumCases(); if (LoopCond && NumCases) { // Find a value to unswitch on: // FIXME: this should chose the most expensive case! // FIXME: scan for a case with a non-critical edge? Constant *UnswitchVal = nullptr; - - // Do not process same value again and again. - // At this point we have some cases already unswitched and - // some not yet unswitched. Let's find the first not yet unswitched one. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { - Constant *UnswitchValCandidate = i.getCaseValue(); - if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { - UnswitchVal = UnswitchValCandidate; - break; + // Find a case value such that at least one case value is unswitched + // out. + if (OpChain == OC_OpChainAnd) { + // If the chain only has ANDs and the switch has a case value of 0. + // Dropping in a 0 to the chain will unswitch out the 0-casevalue. + auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllZero)) + continue; + // We are unswitching 0 out. + UnswitchVal = AllZero; + } else if (OpChain == OC_OpChainOr) { + // If the chain only has ORs and the switch has a case value of ~0. + // Dropping in a ~0 to the chain will unswitch out the ~0-casevalue. + auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType())); + if (BranchesInfo.isUnswitched(SI, AllOne)) + continue; + // We are unswitching ~0 out. + UnswitchVal = AllOne; + } else { + assert(OpChain == OC_OpChainNone && + "Expect to unswitch on trivial chain"); + // Do not process same value again and again. + // At this point we have some cases already unswitched and + // some not yet unswitched. Let's find the first not yet unswitched one. + for (auto Case : SI->cases()) { + Constant *UnswitchValCandidate = Case.getCaseValue(); + if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) { + UnswitchVal = UnswitchValCandidate; + break; + } } } @@ -668,6 +723,11 @@ bool LoopUnswitch::processCurrentLoop() { if (UnswitchIfProfitable(LoopCond, UnswitchVal)) { ++NumSwitches; + // In case of a full LIV, UnswitchVal is the value we unswitched out. + // In case of a partial LIV, we only unswitch when its an AND-chain + // or OR-chain. In both cases switch input value simplifies to + // UnswitchVal. + BranchesInfo.setUnswitched(SI, UnswitchVal); return true; } } @@ -678,7 +738,7 @@ bool LoopUnswitch::processCurrentLoop() { BBI != E; ++BBI) if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) { Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; if (LoopCond && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) { ++NumSelects; @@ -753,6 +813,15 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, << ". Cost too high.\n"); return false; } + if (hasBranchDivergence && + getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) { + DEBUG(dbgs() << "NOT unswitching loop %" + << currentLoop->getHeader()->getName() + << " at non-trivial condition '" << *Val + << "' == " << *LoopCond << "\n" + << ". Condition is divergent.\n"); + return false; + } UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI); return true; @@ -899,7 +968,6 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (I.mayHaveSideEffects()) return false; - // FIXME: add check for constant foldable switch instructions. if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) { if (BI->isUnconditional()) { CurrentBB = BI->getSuccessor(0); @@ -911,7 +979,16 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // Found a trivial condition candidate: non-foldable conditional branch. break; } + } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { + // At this point, any constant-foldable instructions should have probably + // been folded. + ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition()); + if (!Cond) + break; + // Find the target block we are definitely going to. + CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor(); } else { + // We do not understand these terminator instructions. break; } @@ -929,7 +1006,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { return false; Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -960,7 +1037,7 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // If this isn't switching on an invariant condition, we can't unswitch it. Value *LoopCond = FindLIVLoopCondition(SI->getCondition(), - currentLoop, Changed); + currentLoop, Changed).first; // Unswitch only if the trivial condition itself is an LIV (not // partial LIV which could occur in and/or) @@ -973,13 +1050,12 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { // this. // Note that we can't trivially unswitch on the default case or // on already unswitched cases. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { + for (auto Case : SI->cases()) { BasicBlock *LoopExitCandidate; - if ((LoopExitCandidate = isTrivialLoopExitBlock(currentLoop, - i.getCaseSuccessor()))) { + if ((LoopExitCandidate = + isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) { // Okay, we found a trivial case, remember the value that is trivial. - ConstantInt *CaseVal = i.getCaseValue(); + ConstantInt *CaseVal = Case.getCaseValue(); // Check that it was not unswitched before, since already unswitched // trivial vals are looks trivial too. @@ -998,6 +1074,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, nullptr); + + // We are only unswitching full LIV. + BranchesInfo.setUnswitched(SI, CondVal); ++NumSwitches; return true; } @@ -1253,18 +1332,38 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, if (!UI || !L->contains(UI)) continue; - Worklist.push_back(UI); + // At this point, we know LIC is definitely not Val. Try to use some simple + // logic to simplify the user w.r.t. to the context. + if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) { + if (LI->replacementPreservesLCSSAForm(UI, Replacement)) { + // This in-loop instruction has been simplified w.r.t. its context, + // i.e. LIC != Val, make sure we propagate its replacement value to + // all its users. + // + // We can not yet delete UI, the LIC user, yet, because that would invalidate + // the LIC->users() iterator !. However, we can make this instruction + // dead by replacing all its users and push it onto the worklist so that + // it can be properly deleted and its operands simplified. + UI->replaceAllUsesWith(Replacement); + } + } - // TODO: We could do other simplifications, for example, turning - // 'icmp eq LIC, Val' -> false. + // This is a LIC user, push it into the worklist so that SimplifyCode can + // attempt to simplify it. + Worklist.push_back(UI); // If we know that LIC is not Val, use this info to simplify code. SwitchInst *SI = dyn_cast<SwitchInst>(UI); if (!SI || !isa<ConstantInt>(Val)) continue; - SwitchInst::CaseIt DeadCase = SI->findCaseValue(cast<ConstantInt>(Val)); + // NOTE: if a case value for the switch is unswitched out, we record it + // after the unswitch finishes. We can not record it here as the switch + // is not a direct user of the partial LIV. + SwitchInst::CaseHandle DeadCase = + *SI->findCaseValue(cast<ConstantInt>(Val)); // Default case is live for multiple values. - if (DeadCase == SI->case_default()) continue; + if (DeadCase == *SI->case_default()) + continue; // Found a dead case value. Don't remove PHI nodes in the // successor if they become single-entry, those PHI nodes may @@ -1274,8 +1373,6 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, BasicBlock *SISucc = DeadCase.getCaseSuccessor(); BasicBlock *Latch = L->getLoopLatch(); - BranchesInfo.setUnswitched(SI, Val); - if (!SI->findCaseDest(SISucc)) continue; // Edge is critical. // If the DeadCase successor dominates the loop latch, then the // transformation isn't safe since it will delete the sole predecessor edge @@ -1397,3 +1494,27 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { } } } + +/// Simple simplifications we can do given the information that Cond is +/// definitely not equal to Val. +Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst, + Value *Invariant, + Constant *Val) { + // icmp eq cond, val -> false + ICmpInst *CI = dyn_cast<ICmpInst>(Inst); + if (CI && CI->isEquality()) { + Value *Op0 = CI->getOperand(0); + Value *Op1 = CI->getOperand(1); + if ((Op0 == Invariant && Op1 == Val) || (Op0 == Val && Op1 == Invariant)) { + LLVMContext &Ctx = Inst->getContext(); + if (CI->getPredicate() == CmpInst::ICMP_EQ) + return ConstantInt::getFalse(Ctx); + else + return ConstantInt::getTrue(Ctx); + } + } + + // FIXME: there may be other opportunities, e.g. comparison with floating + // point, or Invariant - Val != 0, etc. + return nullptr; +} |