Diffstat (limited to 'lib/Transforms')
179 files changed, 17285 insertions, 7932 deletions
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp index b622d018478a..c795866ec0f2 100644 --- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp @@ -16,7 +16,7 @@ #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" #include "AggressiveInstCombineInternal.h" #include "llvm-c/Initialization.h" -#include "llvm-c/Transforms/Scalar.h" +#include "llvm-c/Transforms/AggressiveInstCombine.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" @@ -59,6 +59,99 @@ public: }; } // namespace +/// Match a pattern for a bitwise rotate operation that partially guards +/// against undefined behavior by branching around the rotation when the shift +/// amount is 0. +static bool foldGuardedRotateToFunnelShift(Instruction &I) { + if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2) + return false; + + // As with the one-use checks below, this is not strictly necessary, but we + // are being cautious to avoid potential perf regressions on targets that + // do not actually have a rotate instruction (where the funnel shift would be + // expanded back into math/shift/logic ops). + if (!isPowerOf2_32(I.getType()->getScalarSizeInBits())) + return false; + + // Match V to funnel shift left/right and capture the source operand and + // shift amount in X and Y. + auto matchRotate = [](Value *V, Value *&X, Value *&Y) { + Value *L0, *L1, *R0, *R1; + unsigned Width = V->getType()->getScalarSizeInBits(); + auto Sub = m_Sub(m_SpecificInt(Width), m_Value(R1)); + + // rotate_left(X, Y) == (X << Y) | (X >> (Width - Y)) + auto RotL = m_OneUse( + m_c_Or(m_Shl(m_Value(L0), m_Value(L1)), m_LShr(m_Value(R0), Sub))); + if (RotL.match(V) && L0 == R0 && L1 == R1) { + X = L0; + Y = L1; + return Intrinsic::fshl; + } + + // rotate_right(X, Y) == (X >> Y) | (X << (Width - Y)) + auto RotR = m_OneUse( + m_c_Or(m_LShr(m_Value(L0), m_Value(L1)), m_Shl(m_Value(R0), Sub))); + if (RotR.match(V) && L0 == R0 && L1 == R1) { + X = L0; + Y = L1; + return Intrinsic::fshr; + } + + return Intrinsic::not_intrinsic; + }; + + // One phi operand must be a rotate operation, and the other phi operand must + // be the source value of that rotate operation: + // phi [ rotate(RotSrc, RotAmt), RotBB ], [ RotSrc, GuardBB ] + PHINode &Phi = cast<PHINode>(I); + Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1); + Value *RotSrc, *RotAmt; + Intrinsic::ID IID = matchRotate(P0, RotSrc, RotAmt); + if (IID == Intrinsic::not_intrinsic || RotSrc != P1) { + IID = matchRotate(P1, RotSrc, RotAmt); + if (IID == Intrinsic::not_intrinsic || RotSrc != P0) + return false; + assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && + "Pattern must match funnel shift left or right"); + } + + // The incoming block with our source operand must be the "guard" block. + // That must contain a cmp+branch to avoid the rotate when the shift amount + // is equal to 0. The other incoming block is the block with the rotate. 
+ BasicBlock *GuardBB = Phi.getIncomingBlock(RotSrc == P1); + BasicBlock *RotBB = Phi.getIncomingBlock(RotSrc != P1); + Instruction *TermI = GuardBB->getTerminator(); + BasicBlock *TrueBB, *FalseBB; + ICmpInst::Predicate Pred; + if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(RotAmt), m_ZeroInt()), TrueBB, + FalseBB))) + return false; + + BasicBlock *PhiBB = Phi.getParent(); + if (Pred != CmpInst::ICMP_EQ || TrueBB != PhiBB || FalseBB != RotBB) + return false; + + // We matched a variation of this IR pattern: + // GuardBB: + // %cmp = icmp eq i32 %RotAmt, 0 + // br i1 %cmp, label %PhiBB, label %RotBB + // RotBB: + // %sub = sub i32 32, %RotAmt + // %shr = lshr i32 %X, %sub + // %shl = shl i32 %X, %RotAmt + // %rot = or i32 %shr, %shl + // br label %PhiBB + // PhiBB: + // %cond = phi i32 [ %rot, %RotBB ], [ %X, %GuardBB ] + // --> + // llvm.fshl.i32(i32 %X, i32 %X, i32 %RotAmt) + IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt()); + Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType()); + Phi.replaceAllUsesWith(Builder.CreateCall(F, {RotSrc, RotSrc, RotAmt})); + return true; +} + /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain /// of 'and' ops, then we also need to capture the fact that we saw an @@ -69,9 +162,9 @@ struct MaskOps { bool MatchAndChain; bool FoundAnd1; - MaskOps(unsigned BitWidth, bool MatchAnds) : - Root(nullptr), Mask(APInt::getNullValue(BitWidth)), - MatchAndChain(MatchAnds), FoundAnd1(false) {} + MaskOps(unsigned BitWidth, bool MatchAnds) + : Root(nullptr), Mask(APInt::getNullValue(BitWidth)), + MatchAndChain(MatchAnds), FoundAnd1(false) {} }; /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a @@ -152,8 +245,8 @@ static bool foldAnyOrAllBitsSet(Instruction &I) { IRBuilder<> Builder(&I); Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask); Value *And = Builder.CreateAnd(MOps.Root, Mask); - Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask) : - Builder.CreateIsNotNull(And); + Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask) + : Builder.CreateIsNotNull(And); Value *Zext = Builder.CreateZExt(Cmp, I.getType()); I.replaceAllUsesWith(Zext); return true; @@ -174,8 +267,10 @@ static bool foldUnusualPatterns(Function &F, DominatorTree &DT) { // Also, we want to avoid matching partial patterns. // TODO: It would be more efficient if we removed dead instructions // iteratively in this loop rather than waiting until the end. - for (Instruction &I : make_range(BB.rbegin(), BB.rend())) + for (Instruction &I : make_range(BB.rbegin(), BB.rend())) { MadeChange |= foldAnyOrAllBitsSet(I); + MadeChange |= foldGuardedRotateToFunnelShift(I); + } } // We're done with transforms, so remove dead instructions.
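For readers of the foldGuardedRotateToFunnelShift() change above, the C++ sketch below shows the kind of source-level rotate idiom the new fold targets: a rotate written with an explicit branch around the shifts when the amount is zero, to avoid shifting by the full bit width (undefined behavior). This is an illustrative example rather than code from the patch; the rotl32 name and the assumption that unsigned is 32 bits wide are ours.

unsigned rotl32(unsigned X, unsigned RotAmt) {
  // GuardBB: compare the shift amount against zero and branch around the
  // rotate, since X >> (32 - 0) would shift by the full width.
  if (RotAmt == 0)
    return X;                                   // phi edge carrying the unrotated source
  // RotBB: shl + lshr + or, the rotate-left expansion matched by RotL above.
  return (X << RotAmt) | (X >> (32 - RotAmt));
}

After the fold, the phi and the guarded shift/or sequence collapse into a single llvm.fshl.i32(%X, %X, %RotAmt) call; per the patch comments, the power-of-two bit-width and one-use checks keep the transform away from targets where the funnel shift would only be expanded back into shift/logic ops.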
diff --git a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h index 199374cdabf3..f3c8bde9f8ff 100644 --- a/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h +++ b/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h @@ -13,6 +13,9 @@ // //===----------------------------------------------------------------------===// +#ifndef LLVM_LIB_TRANSFORMS_AGGRESSIVEINSTCOMBINE_COMBINEINTERNAL_H +#define LLVM_LIB_TRANSFORMS_AGGRESSIVEINSTCOMBINE_COMBINEINTERNAL_H + #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AliasAnalysis.h" @@ -119,3 +122,5 @@ private: void ReduceExpressionDag(Type *SclTy); }; } // end namespace llvm. + +#endif diff --git a/lib/Transforms/Coroutines/CoroElide.cpp b/lib/Transforms/Coroutines/CoroElide.cpp index dfe05c4b2a5e..58f952b54f3a 100644 --- a/lib/Transforms/Coroutines/CoroElide.cpp +++ b/lib/Transforms/Coroutines/CoroElide.cpp @@ -157,7 +157,7 @@ bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { SmallPtrSet<Instruction *, 8> Terminators; for (BasicBlock &B : *F) { auto *TI = B.getTerminator(); - if (TI->getNumSuccessors() == 0 && !TI->isExceptional() && + if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() && !isa<UnreachableInst>(TI)) Terminators.insert(TI); } diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp index cf63b678b618..4cb0a52961cc 100644 --- a/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/lib/Transforms/Coroutines/CoroFrame.cpp @@ -49,7 +49,7 @@ public: BlockToIndexMapping(Function &F) { for (BasicBlock &BB : F) V.push_back(&BB); - llvm::sort(V.begin(), V.end()); + llvm::sort(V); } size_t blockToIndex(BasicBlock *BB) const { @@ -546,7 +546,8 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { } else { // For all other values, the spill is placed immediately after // the definition. - assert(!isa<TerminatorInst>(E.def()) && "unexpected terminator"); + assert(!cast<Instruction>(E.def())->isTerminator() && + "unexpected terminator"); InsertPt = cast<Instruction>(E.def())->getNextNode(); } @@ -600,7 +601,7 @@ static Instruction *insertSpills(SpillInfo &Spills, coro::Shape &Shape) { } // Sets the unwind edge of an instruction to a particular successor. 
-static void setUnwindEdgeTo(TerminatorInst *TI, BasicBlock *Succ) { +static void setUnwindEdgeTo(Instruction *TI, BasicBlock *Succ) { if (auto *II = dyn_cast<InvokeInst>(TI)) II->setUnwindDest(Succ); else if (auto *CS = dyn_cast<CatchSwitchInst>(TI)) diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp index 49acc5e93a39..9eeceb217ba8 100644 --- a/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/lib/Transforms/Coroutines/CoroSplit.cpp @@ -459,7 +459,7 @@ static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { DenseMap<Value *, Value *> ResolvedValues; Instruction *I = InitialInst; - while (isa<TerminatorInst>(I)) { + while (I->isTerminator()) { if (isa<ReturnInst>(I)) { if (I != InitialInst) ReplaceInstWithInst(InitialInst, I->clone()); @@ -538,43 +538,92 @@ static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) { CoroBegin->eraseFromParent(); } -// look for a very simple pattern -// coro.save -// no other calls -// resume or destroy call -// coro.suspend -// -// If there are other calls between coro.save and coro.suspend, they can -// potentially resume or destroy the coroutine, so it is unsafe to eliminate a -// suspend point. -static bool simplifySuspendPoint(CoroSuspendInst *Suspend, - CoroBeginInst *CoroBegin) { - auto *Save = Suspend->getCoroSave(); - auto *BB = Suspend->getParent(); - if (BB != Save->getParent()) - return false; +// SimplifySuspendPoint needs to check that there are no calls between +// coro_save and coro_suspend, since any of the calls may potentially resume +// the coroutine, and if that is the case we cannot eliminate the suspend point. +static bool hasCallsInBlockBetween(Instruction *From, Instruction *To) { + for (Instruction *I = From; I != To; I = I->getNextNode()) { + // Assume that no intrinsic can resume the coroutine. + if (isa<IntrinsicInst>(I)) + continue; - CallSite SingleCallSite; + if (CallSite(I)) + return true; + } + return false; +} - // Check that we have only one CallSite. - for (Instruction *I = Save->getNextNode(); I != Suspend; - I = I->getNextNode()) { - if (isa<CoroFrameInst>(I)) - continue; - if (isa<CoroSubFnInst>(I)) - continue; - if (CallSite CS = CallSite(I)) { - if (SingleCallSite) - return false; - else - SingleCallSite = CS; - } +static bool hasCallsInBlocksBetween(BasicBlock *SaveBB, BasicBlock *ResDesBB) { + SmallPtrSet<BasicBlock *, 8> Set; + SmallVector<BasicBlock *, 8> Worklist; + + Set.insert(SaveBB); + Worklist.push_back(ResDesBB); + + // Accumulate all blocks between SaveBB and ResDesBB. Because CoroSaveIntr + // returns a token consumed by the suspend instruction, all blocks in between + // will have to eventually hit SaveBB when going backwards from ResDesBB. + while (!Worklist.empty()) { + auto *BB = Worklist.pop_back_val(); + Set.insert(BB); + for (auto *Pred : predecessors(BB)) + if (Set.count(Pred) == 0) + Worklist.push_back(Pred); } - auto *CallInstr = SingleCallSite.getInstruction(); - if (!CallInstr) + + // SaveBB and ResDesBB are checked separately in hasCallsBetween.
+ Set.erase(SaveBB); + Set.erase(ResDesBB); + + for (auto *BB : Set) + if (hasCallsInBlockBetween(BB->getFirstNonPHI(), nullptr)) + return true; + + return false; +} + +static bool hasCallsBetween(Instruction *Save, Instruction *ResumeOrDestroy) { + auto *SaveBB = Save->getParent(); + auto *ResumeOrDestroyBB = ResumeOrDestroy->getParent(); + + if (SaveBB == ResumeOrDestroyBB) + return hasCallsInBlockBetween(Save->getNextNode(), ResumeOrDestroy); + + // Any calls from Save to the end of the block? + if (hasCallsInBlockBetween(Save->getNextNode(), nullptr)) + return true; + + // Any calls from the beginning of the block up to ResumeOrDestroy? + if (hasCallsInBlockBetween(ResumeOrDestroyBB->getFirstNonPHI(), + ResumeOrDestroy)) + return true; + + // Any calls in all of the blocks between SaveBB and ResumeOrDestroyBB? + if (hasCallsInBlocksBetween(SaveBB, ResumeOrDestroyBB)) + return true; + + return false; +} + +// If a SuspendIntrin is preceded by Resume or Destroy, we can eliminate the +// suspend point and replace it with normal control flow. +static bool simplifySuspendPoint(CoroSuspendInst *Suspend, + CoroBeginInst *CoroBegin) { + Instruction *Prev = Suspend->getPrevNode(); + if (!Prev) { + auto *Pred = Suspend->getParent()->getSinglePredecessor(); + if (!Pred) + return false; + Prev = Pred->getTerminator(); + } + + CallSite CS{Prev}; + if (!CS) return false; - auto *Callee = SingleCallSite.getCalledValue()->stripPointerCasts(); + auto *CallInstr = CS.getInstruction(); + + auto *Callee = CS.getCalledValue()->stripPointerCasts(); // See if the callsite is for resumption or destruction of the coroutine. auto *SubFn = dyn_cast<CoroSubFnInst>(Callee); @@ -585,6 +634,13 @@ static bool simplifySuspendPoint(CoroSuspendInst *Suspend, if (SubFn->getFrame() != CoroBegin) return false; + // See if the transformation is safe. Specifically, see if there are any + // calls in between Save and CallInstr. They can potentially resume the + // coroutine, rendering this optimization unsafe. + auto *Save = Suspend->getCoroSave(); + if (hasCallsBetween(Save, CallInstr)) + return false; + // Replace llvm.coro.suspend with the value that results in resumption over // the resume or cleanup path. Suspend->replaceAllUsesWith(SubFn->getRawIndex()); @@ -592,8 +648,20 @@ static bool simplifySuspendPoint(CoroSuspendInst *Suspend, Save->eraseFromParent(); // No longer need a call to coro.resume or coro.destroy. + if (auto *Invoke = dyn_cast<InvokeInst>(CallInstr)) { + BranchInst::Create(Invoke->getNormalDest(), Invoke); + } + + // Grab the CalledValue from CS before erasing the CallInstr. + auto *CalledValue = CS.getCalledValue(); CallInstr->eraseFromParent(); + // If there are no more users, remove it. Usually it is a bitcast of SubFn. + if (CalledValue != SubFn && CalledValue->user_empty()) + if (auto *I = dyn_cast<Instruction>(CalledValue)) + I->eraseFromParent(); + + // Now we are good to remove SubFn.
if (SubFn->user_empty()) SubFn->eraseFromParent(); diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp index 731faeb5dce4..cf84f916e24b 100644 --- a/lib/Transforms/Coroutines/Coroutines.cpp +++ b/lib/Transforms/Coroutines/Coroutines.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Coroutines.h" +#include "llvm-c/Transforms/Coroutines.h" #include "CoroInstr.h" #include "CoroInternal.h" #include "llvm/ADT/SmallVector.h" @@ -344,3 +345,19 @@ void coro::Shape::buildFrom(Function &F) { for (CoroSaveInst *CoroSave : UnusedCoroSaves) CoroSave->eraseFromParent(); } + +void LLVMAddCoroEarlyPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCoroEarlyPass()); +} + +void LLVMAddCoroSplitPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCoroSplitPass()); +} + +void LLVMAddCoroElidePass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCoroElidePass()); +} + +void LLVMAddCoroCleanupPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCoroCleanupPass()); +} diff --git a/lib/Transforms/Hello/CMakeLists.txt b/lib/Transforms/Hello/CMakeLists.txt index 4a55dd9c04b8..c4f10247c1a6 100644 --- a/lib/Transforms/Hello/CMakeLists.txt +++ b/lib/Transforms/Hello/CMakeLists.txt @@ -10,7 +10,7 @@ if(WIN32 OR CYGWIN) set(LLVM_LINK_COMPONENTS Core Support) endif() -add_llvm_loadable_module( LLVMHello +add_llvm_library( LLVMHello MODULE BUILDTREE_ONLY Hello.cpp DEPENDS diff --git a/lib/Transforms/IPO/AlwaysInliner.cpp b/lib/Transforms/IPO/AlwaysInliner.cpp index 3b735ddd192e..07138718ce2c 100644 --- a/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/lib/Transforms/IPO/AlwaysInliner.cpp @@ -150,7 +150,7 @@ InlineCost AlwaysInlinerLegacyPass::getInlineCost(CallSite CS) { // declarations. if (Callee && !Callee->isDeclaration() && CS.hasFnAttr(Attribute::AlwaysInline) && isInlineViable(*Callee)) - return InlineCost::getAlways(); + return InlineCost::getAlways("always inliner"); - return InlineCost::getNever(); + return InlineCost::getNever("always inliner"); } diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index f2c2b55b1c5b..4663de0b049e 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -49,6 +49,7 @@ #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -213,7 +214,8 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg()); // Create the new function body and insert it into the module. - Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName()); + Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace(), + F->getName()); NF->copyAttributesFrom(F); // Patch the pointer to LLVM function in debug info descriptor. 
@@ -808,6 +810,21 @@ static bool canPaddingBeAccessed(Argument *arg) { return false; } +static bool areFunctionArgsABICompatible( + const Function &F, const TargetTransformInfo &TTI, + SmallPtrSetImpl<Argument *> &ArgsToPromote, + SmallPtrSetImpl<Argument *> &ByValArgsToTransform) { + for (const Use &U : F.uses()) { + CallSite CS(U.getUser()); + const Function *Caller = CS.getCaller(); + const Function *Callee = CS.getCalledFunction(); + if (!TTI.areFunctionArgsABICompatible(Caller, Callee, ArgsToPromote) || + !TTI.areFunctionArgsABICompatible(Caller, Callee, ByValArgsToTransform)) + return false; + } + return true; +} + /// PromoteArguments - This method checks the specified function to see if there /// are any promotable arguments and if it is safe to promote the function (for /// example, all callers are direct). If safe to promote some arguments, it @@ -816,7 +833,8 @@ static Function * promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, unsigned MaxElements, Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>> - ReplaceCallSite) { + ReplaceCallSite, + const TargetTransformInfo &TTI) { // Don't perform argument promotion for naked functions; otherwise we can end // up removing parameters that are seemingly 'not used' as they are referred // to in the assembly. @@ -845,7 +863,7 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, // Second check: make sure that all callers are direct callers. We can't // transform functions that have indirect callers. Also see if the function - // is self-recursive. + // is self-recursive and check that target features are compatible. bool isSelfRecursive = false; for (Use &U : F->uses()) { CallSite CS(U.getUser()); @@ -954,6 +972,10 @@ promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, if (ArgsToPromote.empty() && ByValArgsToTransform.empty()) return nullptr; + if (!areFunctionArgsABICompatible(*F, TTI, ArgsToPromote, + ByValArgsToTransform)) + return nullptr; + return doPromotion(F, ArgsToPromote, ByValArgsToTransform, ReplaceCallSite); } @@ -979,7 +1001,9 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, return FAM.getResult<AAManager>(F); }; - Function *NewF = promoteArguments(&OldF, AARGetter, MaxElements, None); + const TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(OldF); + Function *NewF = + promoteArguments(&OldF, AARGetter, MaxElements, None, TTI); if (!NewF) continue; LocalChange = true; @@ -1017,6 +1041,7 @@ struct ArgPromotion : public CallGraphSCCPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); getAAResultsAnalysisUsage(AU); CallGraphSCCPass::getAnalysisUsage(AU); } @@ -1042,6 +1067,7 @@ INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(ArgPromotion, "argpromotion", "Promote 'by reference' arguments to scalars", false, false) @@ -1078,8 +1104,10 @@ bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) { CallerNode->replaceCallEdge(OldCS, NewCS, NewCalleeNode); }; + const TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(*OldF); if (Function *NewF = promoteArguments(OldF, AARGetter, MaxElements, - 
{ReplaceCallSite})) { + {ReplaceCallSite}, TTI)) { LocalChange = true; // Update the call graph for the newly promoted function. diff --git a/lib/Transforms/IPO/CMakeLists.txt b/lib/Transforms/IPO/CMakeLists.txt index 4772baf5976c..7e2bca0f8f80 100644 --- a/lib/Transforms/IPO/CMakeLists.txt +++ b/lib/Transforms/IPO/CMakeLists.txt @@ -15,6 +15,7 @@ add_llvm_library(LLVMipo GlobalDCE.cpp GlobalOpt.cpp GlobalSplit.cpp + HotColdSplitting.cpp IPConstantPropagation.cpp IPO.cpp InferFunctionAttrs.cpp diff --git a/lib/Transforms/IPO/CalledValuePropagation.cpp b/lib/Transforms/IPO/CalledValuePropagation.cpp index d642445b35de..de62cfc0c1db 100644 --- a/lib/Transforms/IPO/CalledValuePropagation.cpp +++ b/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -345,6 +345,9 @@ private: void visitInst(Instruction &I, DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + // Simply bail if this instruction has no user. + if (I.use_empty()) + return; auto RegI = CVPLatticeKey(&I, IPOGrouping::Register); ChangedValues[RegI] = getOverdefinedVal(); } diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index e0b1037053f0..81f3634eaf28 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -40,7 +40,7 @@ using namespace llvm; #define DEBUG_TYPE "constmerge" -STATISTIC(NumMerged, "Number of global constants merged"); +STATISTIC(NumIdenticalMerged, "Number of identical global constants merged"); /// Find values that are marked as llvm.used. static void FindUsedValues(GlobalVariable *LLVMUsed, @@ -91,6 +91,37 @@ static unsigned getAlignment(GlobalVariable *GV) { return GV->getParent()->getDataLayout().getPreferredAlignment(GV); } +enum class CanMerge { No, Yes }; +static CanMerge makeMergeable(GlobalVariable *Old, GlobalVariable *New) { + if (!Old->hasGlobalUnnamedAddr() && !New->hasGlobalUnnamedAddr()) + return CanMerge::No; + if (hasMetadataOtherThanDebugLoc(Old)) + return CanMerge::No; + assert(!hasMetadataOtherThanDebugLoc(New)); + if (!Old->hasGlobalUnnamedAddr()) + New->setUnnamedAddr(GlobalValue::UnnamedAddr::None); + return CanMerge::Yes; +} + +static void replace(Module &M, GlobalVariable *Old, GlobalVariable *New) { + Constant *NewConstant = New; + + LLVM_DEBUG(dbgs() << "Replacing global: @" << Old->getName() << " -> @" + << New->getName() << "\n"); + + // Bump the alignment if necessary. + if (Old->getAlignment() || New->getAlignment()) + New->setAlignment(std::max(getAlignment(Old), getAlignment(New))); + + copyDebugLocMetadata(Old, New); + Old->replaceAllUsesWith(NewConstant); + + // Delete the global value from the module. + assert(Old->hasLocalLinkage() && + "Refusing to delete an externally visible global variable."); + Old->eraseFromParent(); +} + static bool mergeConstants(Module &M) { // Find all the globals that are marked "used". These cannot be merged. SmallPtrSet<const GlobalValue*, 8> UsedGlobals; @@ -100,17 +131,18 @@ static bool mergeConstants(Module &M) { // Map unique constants to globals. DenseMap<Constant *, GlobalVariable *> CMap; - // Replacements - This vector contains a list of replacements to perform. - SmallVector<std::pair<GlobalVariable*, GlobalVariable*>, 32> Replacements; + SmallVector<std::pair<GlobalVariable *, GlobalVariable *>, 32> + SameContentReplacements; - bool MadeChange = false; + size_t ChangesMade = 0; + size_t OldChangesMade = 0; // Iterate constant merging while we are still making progress. 
Merging two // constants together may allow us to merge other constants together if the // second level constants have initializers which point to the globals that // were just merged. while (true) { - // First: Find the canonical constants others will be merged with. + // Find the canonical constants others will be merged with. for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); GVI != E; ) { GlobalVariable *GV = &*GVI++; @@ -119,6 +151,7 @@ static bool mergeConstants(Module &M) { GV->removeDeadConstantUsers(); if (GV->use_empty() && GV->hasLocalLinkage()) { GV->eraseFromParent(); + ++ChangesMade; continue; } @@ -148,12 +181,16 @@ static bool mergeConstants(Module &M) { // If this is the first constant we find or if the old one is local, // replace with the current one. If the current is externally visible // it cannot be replace, but can be the canonical constant we merge with. - if (!Slot || IsBetterCanonical(*GV, *Slot)) + bool FirstConstantFound = !Slot; + if (FirstConstantFound || IsBetterCanonical(*GV, *Slot)) { Slot = GV; + LLVM_DEBUG(dbgs() << "Cmap[" << *Init << "] = " << GV->getName() + << (FirstConstantFound ? "\n" : " (updated)\n")); + } } - // Second: identify all globals that can be merged together, filling in - // the Replacements vector. We cannot do the replacement in this pass + // Identify all globals that can be merged together, filling in the + // SameContentReplacements vector. We cannot do the replacement in this pass // because doing so may cause initializers of other globals to be rewritten, // invalidating the Constant* pointers in CMap. for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); @@ -174,54 +211,43 @@ static bool mergeConstants(Module &M) { Constant *Init = GV->getInitializer(); // Check to see if the initializer is already known. - GlobalVariable *Slot = CMap[Init]; - - if (!Slot || Slot == GV) + auto Found = CMap.find(Init); + if (Found == CMap.end()) continue; - if (!Slot->hasGlobalUnnamedAddr() && !GV->hasGlobalUnnamedAddr()) + GlobalVariable *Slot = Found->second; + if (Slot == GV) continue; - if (hasMetadataOtherThanDebugLoc(GV)) + if (makeMergeable(GV, Slot) == CanMerge::No) continue; - if (!GV->hasGlobalUnnamedAddr()) - Slot->setUnnamedAddr(GlobalValue::UnnamedAddr::None); - // Make all uses of the duplicate constant use the canonical version. - Replacements.push_back(std::make_pair(GV, Slot)); + LLVM_DEBUG(dbgs() << "Will replace: @" << GV->getName() << " -> @" + << Slot->getName() << "\n"); + SameContentReplacements.push_back(std::make_pair(GV, Slot)); } - if (Replacements.empty()) - return MadeChange; - CMap.clear(); - // Now that we have figured out which replacements must be made, do them all // now. This avoid invalidating the pointers in CMap, which are unneeded // now. - for (unsigned i = 0, e = Replacements.size(); i != e; ++i) { - // Bump the alignment if necessary. - if (Replacements[i].first->getAlignment() || - Replacements[i].second->getAlignment()) { - Replacements[i].second->setAlignment( - std::max(getAlignment(Replacements[i].first), - getAlignment(Replacements[i].second))); - } - - copyDebugLocMetadata(Replacements[i].first, Replacements[i].second); - - // Eliminate any uses of the dead global. - Replacements[i].first->replaceAllUsesWith(Replacements[i].second); - - // Delete the global value from the module. 
- assert(Replacements[i].first->hasLocalLinkage() && - "Refusing to delete an externally visible global variable."); - Replacements[i].first->eraseFromParent(); + for (unsigned i = 0, e = SameContentReplacements.size(); i != e; ++i) { + GlobalVariable *Old = SameContentReplacements[i].first; + GlobalVariable *New = SameContentReplacements[i].second; + replace(M, Old, New); + ++ChangesMade; + ++NumIdenticalMerged; } - NumMerged += Replacements.size(); - Replacements.clear(); + if (ChangesMade == OldChangesMade) + break; + OldChangesMade = ChangesMade; + + SameContentReplacements.clear(); + CMap.clear(); } + + return ChangesMade; } PreservedAnalyses ConstantMergePass::run(Module &M, ModuleAnalysisManager &) { diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index cd2bd734eb26..cb30e8f46a54 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -24,7 +24,6 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" -#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -165,7 +164,7 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { unsigned NumArgs = Params.size(); // Create the new function body and insert it into the module... - Function *NF = Function::Create(NFTy, Fn.getLinkage()); + Function *NF = Function::Create(NFTy, Fn.getLinkage(), Fn.getAddressSpace()); NF->copyAttributesFrom(&Fn); NF->setComdat(Fn.getComdat()); Fn.getParent()->getFunctionList().insert(Fn.getIterator(), NF); @@ -289,16 +288,21 @@ bool DeadArgumentEliminationPass::RemoveDeadArgumentsFromCallers(Function &Fn) { return false; SmallVector<unsigned, 8> UnusedArgs; + bool Changed = false; + for (Argument &Arg : Fn.args()) { - if (!Arg.hasSwiftErrorAttr() && Arg.use_empty() && !Arg.hasByValOrInAllocaAttr()) + if (!Arg.hasSwiftErrorAttr() && Arg.use_empty() && !Arg.hasByValOrInAllocaAttr()) { + if (Arg.isUsedByMetadata()) { + Arg.replaceAllUsesWith(UndefValue::get(Arg.getType())); + Changed = true; + } UnusedArgs.push_back(Arg.getArgNo()); + } } if (UnusedArgs.empty()) return false; - bool Changed = false; - for (Use &U : Fn.uses()) { CallSite CS(U.getUser()); if (!CS || !CS.isCallee(&U)) @@ -859,7 +863,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { return false; // Create the new function body and insert it into the module... - Function *NF = Function::Create(NFTy, F->getLinkage()); + Function *NF = Function::Create(NFTy, F->getLinkage(), F->getAddressSpace()); NF->copyAttributesFrom(F); NF->setComdat(F->getComdat()); NF->setAttributes(NewPAL); @@ -949,16 +953,16 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { ArgAttrVec.clear(); Instruction *New = NewCS.getInstruction(); - if (!Call->use_empty()) { + if (!Call->use_empty() || Call->isUsedByMetadata()) { if (New->getType() == Call->getType()) { // Return type not changed? Just replace users then. Call->replaceAllUsesWith(New); New->takeName(Call); } else if (New->getType()->isVoidTy()) { - // Our return value has uses, but they will get removed later on. - // Replace by null for now. + // If the return value is dead, replace any uses of it with undef + // (any non-debug value uses will get removed later on). 
if (!Call->getType()->isX86_MMXTy()) - Call->replaceAllUsesWith(Constant::getNullValue(Call->getType())); + Call->replaceAllUsesWith(UndefValue::get(Call->getType())); } else { assert((RetTy->isStructTy() || RetTy->isArrayTy()) && "Return type changed, but not into a void. The old return type" @@ -1018,10 +1022,10 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { I2->takeName(&*I); ++I2; } else { - // If this argument is dead, replace any uses of it with null constants - // (these are guaranteed to become unused later on). + // If this argument is dead, replace any uses of it with undef + // (any non-debug value uses will get removed later on). if (!I->getType()->isX86_MMXTy()) - I->replaceAllUsesWith(Constant::getNullValue(I->getType())); + I->replaceAllUsesWith(UndefValue::get(I->getType())); } // If we change the return value of the function we must rewrite any return diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index d45a88323910..a744d7f2d2d9 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -135,6 +135,7 @@ namespace { llvm::Value *Declaration; if (FunctionType *FTy = dyn_cast<FunctionType>(Ty)) { Declaration = Function::Create(FTy, GlobalValue::ExternalLinkage, + CurI->getAddressSpace(), CurI->getName(), &M); } else { diff --git a/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 37273f975417..4dc1529ddbf5 100644 --- a/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -58,6 +58,7 @@ static Attribute::AttrKind parseAttrKind(StringRef Kind) { .Case("sanitize_hwaddress", Attribute::SanitizeHWAddress) .Case("sanitize_memory", Attribute::SanitizeMemory) .Case("sanitize_thread", Attribute::SanitizeThread) + .Case("speculative_load_hardening", Attribute::SpeculativeLoadHardening) .Case("ssp", Attribute::StackProtect) .Case("sspreq", Attribute::StackProtectReq) .Case("sspstrong", Attribute::StackProtectStrong) diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 010b0a29807d..4e2a82b56eec 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -41,6 +41,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -66,6 +67,7 @@ using namespace llvm; STATISTIC(NumReadNone, "Number of functions marked readnone"); STATISTIC(NumReadOnly, "Number of functions marked readonly"); +STATISTIC(NumWriteOnly, "Number of functions marked writeonly"); STATISTIC(NumNoCapture, "Number of arguments marked nocapture"); STATISTIC(NumReturned, "Number of arguments marked returned"); STATISTIC(NumReadNoneArg, "Number of arguments marked readnone"); @@ -113,27 +115,30 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, if (AliasAnalysis::onlyReadsMemory(MRB)) return MAK_ReadOnly; - // Conservatively assume it writes to memory. + if (AliasAnalysis::doesNotReadMemory(MRB)) + return MAK_WriteOnly; + + // Conservatively assume it reads and writes to memory. return MAK_MayWrite; } // Scan the function body for instructions that may read or write memory. 
bool ReadsMemory = false; + bool WritesMemory = false; for (inst_iterator II = inst_begin(F), E = inst_end(F); II != E; ++II) { Instruction *I = &*II; // Some instructions can be ignored even if they read or write memory. // Detect these now, skipping to the next instruction if one is found. - CallSite CS(cast<Value>(I)); - if (CS) { + if (auto *Call = dyn_cast<CallBase>(I)) { // Ignore calls to functions in the same SCC, as long as the call sites // don't have operand bundles. Calls with operand bundles are allowed to // have memory effects not described by the memory effects of the call // target. - if (!CS.hasOperandBundles() && CS.getCalledFunction() && - SCCNodes.count(CS.getCalledFunction())) + if (!Call->hasOperandBundles() && Call->getCalledFunction() && + SCCNodes.count(Call->getCalledFunction())) continue; - FunctionModRefBehavior MRB = AAR.getModRefBehavior(CS); + FunctionModRefBehavior MRB = AAR.getModRefBehavior(Call); ModRefInfo MRI = createModRefInfo(MRB); // If the call doesn't access memory, we're done. @@ -141,9 +146,9 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, continue; if (!AliasAnalysis::onlyAccessesArgPointees(MRB)) { - // The call could access any memory. If that includes writes, give up. + // The call could access any memory. If that includes writes, note it. if (isModSet(MRI)) - return MAK_MayWrite; + WritesMemory = true; // If it reads, note it. if (isRefSet(MRI)) ReadsMemory = true; @@ -152,7 +157,7 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, // Check whether all pointer arguments point to local memory, and // ignore calls that only access local memory. - for (CallSite::arg_iterator CI = CS.arg_begin(), CE = CS.arg_end(); + for (CallSite::arg_iterator CI = Call->arg_begin(), CE = Call->arg_end(); CI != CE; ++CI) { Value *Arg = *CI; if (!Arg->getType()->isPtrOrPtrVectorTy()) @@ -160,7 +165,7 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, AAMDNodes AAInfo; I->getAAMetadata(AAInfo); - MemoryLocation Loc(Arg, MemoryLocation::UnknownSize, AAInfo); + MemoryLocation Loc(Arg, LocationSize::unknown(), AAInfo); // Skip accesses to local or constant memory as they don't impact the // externally visible mod/ref behavior. @@ -168,8 +173,8 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, continue; if (isModSet(MRI)) - // Writes non-local memory. Give up. - return MAK_MayWrite; + // Writes non-local memory. + WritesMemory = true; if (isRefSet(MRI)) // Ok, it reads non-local memory. ReadsMemory = true; @@ -198,14 +203,21 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, // Any remaining instructions need to be taken seriously! Check if they // read or write memory. - if (I->mayWriteToMemory()) - // Writes memory. Just give up. - return MAK_MayWrite; + // + // Writes memory, remember that. + WritesMemory |= I->mayWriteToMemory(); // If this instruction may read memory, remember that. ReadsMemory |= I->mayReadFromMemory(); } + if (WritesMemory) { + if (!ReadsMemory) + return MAK_WriteOnly; + else + return MAK_MayWrite; + } + return ReadsMemory ? MAK_ReadOnly : MAK_ReadNone; } @@ -220,6 +232,7 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { // Check if any of the functions in the SCC read or write memory. If they // write memory then they can't be marked readnone or readonly. 
bool ReadsMemory = false; + bool WritesMemory = false; for (Function *F : SCCNodes) { // Call the callable parameter to look up AA results for this function. AAResults &AAR = AARGetter(*F); @@ -234,6 +247,9 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { case MAK_ReadOnly: ReadsMemory = true; break; + case MAK_WriteOnly: + WritesMemory = true; + break; case MAK_ReadNone: // Nothing to do! break; @@ -243,6 +259,9 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { // Success! Functions in this SCC do not access memory, or only read memory. // Give them the appropriate attribute. bool MadeChange = false; + + assert(!(ReadsMemory && WritesMemory) && + "Function marked read-only and write-only"); for (Function *F : SCCNodes) { if (F->doesNotAccessMemory()) // Already perfect! @@ -252,16 +271,32 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { // No change. continue; + if (F->doesNotReadMemory() && WritesMemory) + continue; + MadeChange = true; // Clear out any existing attributes. F->removeFnAttr(Attribute::ReadOnly); F->removeFnAttr(Attribute::ReadNone); + F->removeFnAttr(Attribute::WriteOnly); + + if (!WritesMemory && !ReadsMemory) { + // Clear out any "access range attributes" if readnone was deduced. + F->removeFnAttr(Attribute::ArgMemOnly); + F->removeFnAttr(Attribute::InaccessibleMemOnly); + F->removeFnAttr(Attribute::InaccessibleMemOrArgMemOnly); + } // Add in the new attribute. - F->addFnAttr(ReadsMemory ? Attribute::ReadOnly : Attribute::ReadNone); + if (WritesMemory && !ReadsMemory) + F->addFnAttr(Attribute::WriteOnly); + else + F->addFnAttr(ReadsMemory ? Attribute::ReadOnly : Attribute::ReadNone); - if (ReadsMemory) + if (WritesMemory && !ReadsMemory) + ++NumWriteOnly; + else if (ReadsMemory) ++NumReadOnly; else ++NumReadNone; @@ -1272,13 +1307,14 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { // If all of the calls in F are identifiable and are to norecurse functions, F // is norecurse. This check also detects self-recursion as F is not currently // marked norecurse, so any called from F to F will not be marked norecurse. - for (Instruction &I : instructions(*F)) - if (auto CS = CallSite(&I)) { - Function *Callee = CS.getCalledFunction(); - if (!Callee || Callee == F || !Callee->doesNotRecurse()) - // Function calls a potentially recursive function. - return false; - } + for (auto &BB : *F) + for (auto &I : BB.instructionsWithoutDebug()) + if (auto CS = CallSite(&I)) { + Function *Callee = CS.getCalledFunction(); + if (!Callee || Callee == F || !Callee->doesNotRecurse()) + // Function calls a potentially recursive function. + return false; + } // Every call was to a non-recursive function other than this function, and // we have no indirect recursion as the SCC size is one. This function cannot @@ -1286,6 +1322,31 @@ static bool addNoRecurseAttrs(const SCCNodeSet &SCCNodes) { return setDoesNotRecurse(*F); } +template <typename AARGetterT> +static bool deriveAttrsInPostOrder(SCCNodeSet &SCCNodes, AARGetterT &&AARGetter, + bool HasUnknownCall) { + bool Changed = false; + + // Bail if the SCC only contains optnone functions. + if (SCCNodes.empty()) + return Changed; + + Changed |= addArgumentReturnedAttrs(SCCNodes); + Changed |= addReadAttrs(SCCNodes, AARGetter); + Changed |= addArgumentAttrs(SCCNodes); + + // If we have no external nodes participating in the SCC, we can deduce some + // more precise attributes as well. 
+ if (!HasUnknownCall) { + Changed |= addNoAliasAttrs(SCCNodes); + Changed |= addNonNullAttrs(SCCNodes); + Changed |= inferAttrsFromFunctionBodies(SCCNodes); + Changed |= addNoRecurseAttrs(SCCNodes); + } + + return Changed; +} + PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, @@ -1328,21 +1389,10 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, SCCNodes.insert(&F); } - bool Changed = false; - Changed |= addArgumentReturnedAttrs(SCCNodes); - Changed |= addReadAttrs(SCCNodes, AARGetter); - Changed |= addArgumentAttrs(SCCNodes); - - // If we have no external nodes participating in the SCC, we can deduce some - // more precise attributes as well. - if (!HasUnknownCall) { - Changed |= addNoAliasAttrs(SCCNodes); - Changed |= addNonNullAttrs(SCCNodes); - Changed |= inferAttrsFromFunctionBodies(SCCNodes); - Changed |= addNoRecurseAttrs(SCCNodes); - } + if (deriveAttrsInPostOrder(SCCNodes, AARGetter, HasUnknownCall)) + return PreservedAnalyses::none(); - return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); + return PreservedAnalyses::all(); } namespace { @@ -1382,7 +1432,6 @@ Pass *llvm::createPostOrderFunctionAttrsLegacyPass() { template <typename AARGetterT> static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { - bool Changed = false; // Fill SCCNodes with the elements of the SCC. Used for quickly looking up // whether a given CallGraphNode is in this SCC. Also track whether there are @@ -1403,24 +1452,7 @@ static bool runImpl(CallGraphSCC &SCC, AARGetterT AARGetter) { SCCNodes.insert(F); } - // Skip it if the SCC only contains optnone functions. - if (SCCNodes.empty()) - return Changed; - - Changed |= addArgumentReturnedAttrs(SCCNodes); - Changed |= addReadAttrs(SCCNodes, AARGetter); - Changed |= addArgumentAttrs(SCCNodes); - - // If we have no external nodes participating in the SCC, we can deduce some - // more precise attributes as well. 
- if (!ExternalNode) { - Changed |= addNoAliasAttrs(SCCNodes); - Changed |= addNonNullAttrs(SCCNodes); - Changed |= inferAttrsFromFunctionBodies(SCCNodes); - Changed |= addNoRecurseAttrs(SCCNodes); - } - - return Changed; + return deriveAttrsInPostOrder(SCCNodes, AARGetter, ExternalNode); } bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index 15808a073894..1223a23512ed 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -60,8 +60,17 @@ using namespace llvm; #define DEBUG_TYPE "function-import" -STATISTIC(NumImportedFunctions, "Number of functions imported"); -STATISTIC(NumImportedGlobalVars, "Number of global variables imported"); +STATISTIC(NumImportedFunctionsThinLink, + "Number of functions thin link decided to import"); +STATISTIC(NumImportedHotFunctionsThinLink, + "Number of hot functions thin link decided to import"); +STATISTIC(NumImportedCriticalFunctionsThinLink, + "Number of critical functions thin link decided to import"); +STATISTIC(NumImportedGlobalVarsThinLink, + "Number of global variables thin link decided to import"); +STATISTIC(NumImportedFunctions, "Number of functions imported in backend"); +STATISTIC(NumImportedGlobalVars, + "Number of global variables imported in backend"); STATISTIC(NumImportedModules, "Number of modules imported from"); STATISTIC(NumDeadSymbols, "Number of dead stripped symbols in index"); STATISTIC(NumLiveSymbols, "Number of live symbols in index"); @@ -107,6 +116,10 @@ static cl::opt<float> ImportColdMultiplier( static cl::opt<bool> PrintImports("print-imports", cl::init(false), cl::Hidden, cl::desc("Print imported functions")); +static cl::opt<bool> PrintImportFailures( + "print-import-failures", cl::init(false), cl::Hidden, + cl::desc("Print information for functions rejected for importing")); + static cl::opt<bool> ComputeDead("compute-dead", cl::init(true), cl::Hidden, cl::desc("Compute dead symbols")); @@ -163,13 +176,18 @@ static std::unique_ptr<Module> loadFile(const std::string &FileName, static const GlobalValueSummary * selectCallee(const ModuleSummaryIndex &Index, ArrayRef<std::unique_ptr<GlobalValueSummary>> CalleeSummaryList, - unsigned Threshold, StringRef CallerModulePath) { + unsigned Threshold, StringRef CallerModulePath, + FunctionImporter::ImportFailureReason &Reason, + GlobalValue::GUID GUID) { + Reason = FunctionImporter::ImportFailureReason::None; auto It = llvm::find_if( CalleeSummaryList, [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { auto *GVSummary = SummaryPtr.get(); - if (!Index.isGlobalValueLive(GVSummary)) + if (!Index.isGlobalValueLive(GVSummary)) { + Reason = FunctionImporter::ImportFailureReason::NotLive; return false; + } // For SamplePGO, in computeImportForFunction the OriginalId // may have been used to locate the callee summary list (See @@ -184,11 +202,15 @@ selectCallee(const ModuleSummaryIndex &Index, // When this happens, the logic for SamplePGO kicks in and // the static variable in 2) will be found, which needs to be // filtered out. 
- if (GVSummary->getSummaryKind() == GlobalValueSummary::GlobalVarKind) + if (GVSummary->getSummaryKind() == GlobalValueSummary::GlobalVarKind) { + Reason = FunctionImporter::ImportFailureReason::GlobalVar; return false; - if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) + } + if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) { + Reason = FunctionImporter::ImportFailureReason::InterposableLinkage; // There is no point in importing these, we can't inline them return false; + } auto *Summary = cast<FunctionSummary>(GVSummary->getBaseObject()); @@ -204,14 +226,29 @@ selectCallee(const ModuleSummaryIndex &Index, // a local in another module. if (GlobalValue::isLocalLinkage(Summary->linkage()) && CalleeSummaryList.size() > 1 && - Summary->modulePath() != CallerModulePath) + Summary->modulePath() != CallerModulePath) { + Reason = + FunctionImporter::ImportFailureReason::LocalLinkageNotInModule; return false; + } - if (Summary->instCount() > Threshold) + if (Summary->instCount() > Threshold) { + Reason = FunctionImporter::ImportFailureReason::TooLarge; return false; + } - if (Summary->notEligibleToImport()) + // Skip if it isn't legal to import (e.g. may reference unpromotable + // locals). + if (Summary->notEligibleToImport()) { + Reason = FunctionImporter::ImportFailureReason::NotEligible; return false; + } + + // Don't bother importing if we can't inline it anyway. + if (Summary->fflags().NoInline) { + Reason = FunctionImporter::ImportFailureReason::NoInline; + return false; + } return true; }); @@ -256,13 +293,25 @@ static void computeImportForReferencedGlobals( LLVM_DEBUG(dbgs() << " ref -> " << VI << "\n"); + // If this is a local variable, make sure we import the copy + // in the caller's module. The only time a local variable can + // share an entry in the index is if there is a local with the same name + // in another module that had the same source file name (in a different + // directory), where each was compiled in their own directory so there + // was no distinguishing path. + auto LocalNotInModule = [&](const GlobalValueSummary *RefSummary) -> bool { + return GlobalValue::isLocalLinkage(RefSummary->linkage()) && + RefSummary->modulePath() != Summary.modulePath(); + }; + for (auto &RefSummary : VI.getSummaryList()) - if (RefSummary->getSummaryKind() == GlobalValueSummary::GlobalVarKind && - // Don't try to import regular LTO summaries added to dummy module. - !RefSummary->modulePath().empty() && - !GlobalValue::isInterposableLinkage(RefSummary->linkage()) && - RefSummary->refs().empty()) { - ImportList[RefSummary->modulePath()].insert(VI.getGUID()); + if (isa<GlobalVarSummary>(RefSummary.get()) && + canImportGlobalVar(RefSummary.get()) && + !LocalNotInModule(RefSummary.get())) { + auto ILI = ImportList[RefSummary->modulePath()].insert(VI.getGUID()); + // Only update stat if we haven't already imported this variable.
+ if (ILI.second) + NumImportedGlobalVarsThinLink++; if (ExportLists) (*ExportLists)[RefSummary->modulePath()].insert(VI.getGUID()); break; @@ -270,6 +319,29 @@ static void computeImportForReferencedGlobals( } } +static const char * +getFailureName(FunctionImporter::ImportFailureReason Reason) { + switch (Reason) { + case FunctionImporter::ImportFailureReason::None: + return "None"; + case FunctionImporter::ImportFailureReason::GlobalVar: + return "GlobalVar"; + case FunctionImporter::ImportFailureReason::NotLive: + return "NotLive"; + case FunctionImporter::ImportFailureReason::TooLarge: + return "TooLarge"; + case FunctionImporter::ImportFailureReason::InterposableLinkage: + return "InterposableLinkage"; + case FunctionImporter::ImportFailureReason::LocalLinkageNotInModule: + return "LocalLinkageNotInModule"; + case FunctionImporter::ImportFailureReason::NotEligible: + return "NotEligible"; + case FunctionImporter::ImportFailureReason::NoInline: + return "NoInline"; + } + llvm_unreachable("invalid reason"); +} + /// Compute the list of functions to import for a given caller. Mark these /// imported functions and the symbols they reference in their source module as /// exported from their source module. @@ -316,11 +388,17 @@ static void computeImportForFunction( const auto NewThreshold = Threshold * GetBonusMultiplier(Edge.second.getHotness()); - auto IT = ImportThresholds.insert( - std::make_pair(VI.getGUID(), std::make_pair(NewThreshold, nullptr))); + auto IT = ImportThresholds.insert(std::make_pair( + VI.getGUID(), std::make_tuple(NewThreshold, nullptr, nullptr))); bool PreviouslyVisited = !IT.second; - auto &ProcessedThreshold = IT.first->second.first; - auto &CalleeSummary = IT.first->second.second; + auto &ProcessedThreshold = std::get<0>(IT.first->second); + auto &CalleeSummary = std::get<1>(IT.first->second); + auto &FailureInfo = std::get<2>(IT.first->second); + + bool IsHotCallsite = + Edge.second.getHotness() == CalleeInfo::HotnessType::Hot; + bool IsCriticalCallsite = + Edge.second.getHotness() == CalleeInfo::HotnessType::Critical; const FunctionSummary *ResolvedCalleeSummary = nullptr; if (CalleeSummary) { @@ -345,16 +423,37 @@ static void computeImportForFunction( LLVM_DEBUG( dbgs() << "ignored! Target was already rejected with Threshold " << ProcessedThreshold << "\n"); + if (PrintImportFailures) { + assert(FailureInfo && + "Expected FailureInfo for previously rejected candidate"); + FailureInfo->Attempts++; + } continue; } + FunctionImporter::ImportFailureReason Reason; CalleeSummary = selectCallee(Index, VI.getSummaryList(), NewThreshold, - Summary.modulePath()); + Summary.modulePath(), Reason, VI.getGUID()); if (!CalleeSummary) { // Update with new larger threshold if this was a retry (otherwise - // we would have already inserted with NewThreshold above). - if (PreviouslyVisited) + // we would have already inserted with NewThreshold above). Also + // update failure info if requested. 
+ if (PreviouslyVisited) { ProcessedThreshold = NewThreshold; + if (PrintImportFailures) { + assert(FailureInfo && + "Expected FailureInfo for previously rejected candidate"); + FailureInfo->Reason = Reason; + FailureInfo->Attempts++; + FailureInfo->MaxHotness = + std::max(FailureInfo->MaxHotness, Edge.second.getHotness()); + } + } else if (PrintImportFailures) { + assert(!FailureInfo && + "Expected no FailureInfo for newly rejected candidate"); + FailureInfo = llvm::make_unique<FunctionImporter::ImportFailureInfo>( + VI, Edge.second.getHotness(), Reason, 1); + } LLVM_DEBUG( dbgs() << "ignored! No qualifying callee with summary found.\n"); continue; @@ -372,6 +471,13 @@ static void computeImportForFunction( // We previously decided to import this GUID definition if it was already // inserted in the set of imports from the exporting module. bool PreviouslyImported = !ILI.second; + if (!PreviouslyImported) { + NumImportedFunctionsThinLink++; + if (IsHotCallsite) + NumImportedHotFunctionsThinLink++; + if (IsCriticalCallsite) + NumImportedCriticalFunctionsThinLink++; + } // Make exports in the source module. if (ExportLists) { @@ -405,8 +511,6 @@ static void computeImportForFunction( return Threshold * ImportInstrFactor; }; - bool IsHotCallsite = - Edge.second.getHotness() == CalleeInfo::HotnessType::Hot; const auto AdjThreshold = GetAdjustedThreshold(Threshold, IsHotCallsite); ImportCount++; @@ -421,7 +525,7 @@ static void computeImportForFunction( /// another module (that may require promotion). static void ComputeImportForModule( const GVSummaryMapTy &DefinedGVSummaries, const ModuleSummaryIndex &Index, - FunctionImporter::ImportMapTy &ImportList, + StringRef ModName, FunctionImporter::ImportMapTy &ImportList, StringMap<FunctionImporter::ExportSetTy> *ExportLists = nullptr) { // Worklist contains the list of function imported in this module, for which // we will analyse the callees and may import further down the callgraph. @@ -461,6 +565,30 @@ static void ComputeImportForModule( Worklist, ImportList, ExportLists, ImportThresholds); } + + // Print stats about functions considered but rejected for importing + // when requested. + if (PrintImportFailures) { + dbgs() << "Missed imports into module " << ModName << "\n"; + for (auto &I : ImportThresholds) { + auto &ProcessedThreshold = std::get<0>(I.second); + auto &CalleeSummary = std::get<1>(I.second); + auto &FailureInfo = std::get<2>(I.second); + if (CalleeSummary) + continue; // We are going to import. + assert(FailureInfo); + FunctionSummary *FS = nullptr; + if (!FailureInfo->VI.getSummaryList().empty()) + FS = dyn_cast<FunctionSummary>( + FailureInfo->VI.getSummaryList()[0]->getBaseObject()); + dbgs() << FailureInfo->VI + << ": Reason = " << getFailureName(FailureInfo->Reason) + << ", Threshold = " << ProcessedThreshold + << ", Size = " << (FS ? (int)FS->instCount() : -1) + << ", MaxHotness = " << getHotnessName(FailureInfo->MaxHotness) + << ", Attempts = " << FailureInfo->Attempts << "\n"; + } + } } #ifndef NDEBUG @@ -498,7 +626,8 @@ void llvm::ComputeCrossModuleImport( auto &ImportList = ImportLists[DefinedGVSummaries.first()]; LLVM_DEBUG(dbgs() << "Computing import for Module '" << DefinedGVSummaries.first() << "'\n"); - ComputeImportForModule(DefinedGVSummaries.second, Index, ImportList, + ComputeImportForModule(DefinedGVSummaries.second, Index, + DefinedGVSummaries.first(), ImportList, &ExportLists); } @@ -569,7 +698,7 @@ void llvm::ComputeCrossModuleImportForModule( // Compute the import list for this module. 
LLVM_DEBUG(dbgs() << "Computing import for Module '" << ModulePath << "'\n"); - ComputeImportForModule(FunctionSummaryMap, Index, ImportList); + ComputeImportForModule(FunctionSummaryMap, Index, ModulePath, ImportList); #ifndef NDEBUG dumpImportListForModule(Index, ModulePath, ImportList); @@ -648,29 +777,38 @@ void llvm::computeDeadSymbols( VI = updateValueInfoForIndirectCalls(Index, VI); if (!VI) return; - for (auto &S : VI.getSummaryList()) - if (S->isLive()) - return; + + // We need to make sure all variants of the symbol are scanned, alias can + // make one (but not all) alive. + if (llvm::all_of(VI.getSummaryList(), + [](const std::unique_ptr<llvm::GlobalValueSummary> &S) { + return S->isLive(); + })) + return; // We only keep live symbols that are known to be non-prevailing if any are - // available_externally. Those symbols are discarded later in the - // EliminateAvailableExternally pass and setting them to not-live breaks - // downstreams users of liveness information (PR36483). + // available_externally, linkonceodr, weakodr. Those symbols are discarded + // later in the EliminateAvailableExternally pass and setting them to + // not-live could break downstreams users of liveness information (PR36483) + // or limit optimization opportunities. if (isPrevailing(VI.getGUID()) == PrevailingType::No) { - bool AvailableExternally = false; + bool KeepAliveLinkage = false; bool Interposable = false; for (auto &S : VI.getSummaryList()) { - if (S->linkage() == GlobalValue::AvailableExternallyLinkage) - AvailableExternally = true; + if (S->linkage() == GlobalValue::AvailableExternallyLinkage || + S->linkage() == GlobalValue::WeakODRLinkage || + S->linkage() == GlobalValue::LinkOnceODRLinkage) + KeepAliveLinkage = true; else if (GlobalValue::isInterposableLinkage(S->linkage())) Interposable = true; } - if (!AvailableExternally) + if (!KeepAliveLinkage) return; if (Interposable) - report_fatal_error("Interposable and available_externally symbol"); + report_fatal_error( + "Interposable and available_externally/linkonce_odr/weak_odr symbol"); } for (auto &S : VI.getSummaryList()) @@ -701,6 +839,25 @@ void llvm::computeDeadSymbols( NumLiveSymbols += LiveSymbols; } +// Compute dead symbols and propagate constants in combined index. +void llvm::computeDeadSymbolsWithConstProp( + ModuleSummaryIndex &Index, + const DenseSet<GlobalValue::GUID> &GUIDPreservedSymbols, + function_ref<PrevailingType(GlobalValue::GUID)> isPrevailing, + bool ImportEnabled) { + computeDeadSymbols(Index, GUIDPreservedSymbols, isPrevailing); + if (ImportEnabled) { + Index.propagateConstants(GUIDPreservedSymbols); + } else { + // If import is disabled we should drop read-only attribute + // from all summaries to prevent internalization. + for (auto &P : Index) + for (auto &S : P.second.SummaryList) + if (auto *GVS = dyn_cast<GlobalVarSummary>(S.get())) + GVS->setReadOnly(false); + } +} + /// Compute the set of summaries needed for a ThinLTO backend compilation of /// \p ModulePath. 
void llvm::gatherImportedSummariesForModule( @@ -759,7 +916,8 @@ bool llvm::convertToDeclaration(GlobalValue &GV) { if (GV.getValueType()->isFunctionTy()) NewGV = Function::Create(cast<FunctionType>(GV.getValueType()), - GlobalValue::ExternalLinkage, "", GV.getParent()); + GlobalValue::ExternalLinkage, GV.getAddressSpace(), + "", GV.getParent()); else NewGV = new GlobalVariable(*GV.getParent(), GV.getValueType(), @@ -774,8 +932,8 @@ bool llvm::convertToDeclaration(GlobalValue &GV) { return true; } -/// Fixup WeakForLinker linkages in \p TheModule based on summary analysis. -void llvm::thinLTOResolveWeakForLinkerModule( +/// Fixup prevailing symbol linkages in \p TheModule based on summary analysis. +void llvm::thinLTOResolvePrevailingInModule( Module &TheModule, const GVSummaryMapTy &DefinedGlobals) { auto updateLinkage = [&](GlobalValue &GV) { // See if the global summary analysis computed a new resolved linkage. @@ -792,13 +950,15 @@ void llvm::thinLTOResolveWeakForLinkerModule( // as we need access to the resolution vectors for each input file in // order to find which symbols have been redefined. // We may consider reorganizing this code and moving the linkage recording - // somewhere else, e.g. in thinLTOResolveWeakForLinkerInIndex. + // somewhere else, e.g. in thinLTOResolvePrevailingInIndex. if (NewLinkage == GlobalValue::WeakAnyLinkage) { GV.setLinkage(NewLinkage); return; } - if (!GlobalValue::isWeakForLinker(GV.getLinkage())) + if (GlobalValue::isLocalLinkage(GV.getLinkage()) || + // In case it was dead and already converted to declaration. + GV.isDeclaration()) return; // Check for a non-prevailing def that has interposable linkage // (e.g. non-odr weak or linkonce). In that case we can't simply @@ -809,7 +969,7 @@ void llvm::thinLTOResolveWeakForLinkerModule( GlobalValue::isInterposableLinkage(GV.getLinkage())) { if (!convertToDeclaration(GV)) // FIXME: Change this to collect replaced GVs and later erase - // them from the parent module once thinLTOResolveWeakForLinkerGUID is + // them from the parent module once thinLTOResolvePrevailingGUID is // changed to enable this for aliases. llvm_unreachable("Expected GV to be converted"); } else { @@ -895,6 +1055,18 @@ static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { return NewFn; } +// Internalize values that we marked with specific attribute +// in processGlobalForThinLTO. +static void internalizeImmutableGVs(Module &M) { + for (auto &GV : M.globals()) + // Skip GVs which have been converted to declarations + // by dropDeadSymbols. + if (!GV.isDeclaration() && GV.hasAttribute("thinlto-internalize")) { + GV.setLinkage(GlobalValue::InternalLinkage); + GV.setVisibility(GlobalValue::DefaultVisibility); + } +} + // Automatically import functions in Module \p DestModule based on the summaries // index. 
Expected<bool> FunctionImporter::importFunctions( @@ -1018,6 +1190,8 @@ Expected<bool> FunctionImporter::importFunctions( NumImportedModules++; } + internalizeImmutableGVs(DestModule); + NumImportedFunctions += (ImportedCount - ImportedGVCount); NumImportedGlobalVars += ImportedGVCount; diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index ada9eb80e680..34de87433367 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" @@ -75,13 +76,17 @@ ModulePass *llvm::createGlobalDCEPass() { return new GlobalDCELegacyPass(); } -/// Returns true if F contains only a single "ret" instruction. +/// Returns true if F is effectively empty. static bool isEmptyFunction(Function *F) { BasicBlock &Entry = F->getEntryBlock(); - if (Entry.size() != 1 || !isa<ReturnInst>(Entry.front())) - return false; - ReturnInst &RI = cast<ReturnInst>(Entry.front()); - return RI.getReturnValue() == nullptr; + for (auto &I : Entry) { + if (isa<DbgInfoIntrinsic>(I)) + continue; + if (auto *RI = dyn_cast<ReturnInst>(&I)) + return !RI->getReturnValue(); + break; + } + return false; } /// Compute the set of GlobalValue that depends from V. @@ -165,7 +170,7 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { // Functions with external linkage are needed if they have a body. // Externally visible & appending globals are needed, if they have an // initializer. - if (!GO.isDeclaration() && !GO.hasAvailableExternallyLinkage()) + if (!GO.isDeclaration()) if (!GO.isDiscardableIfUnused()) MarkLive(GO); diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index 1761d7faff57..3005aafd06b1 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -1710,19 +1710,25 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { assert(isa<LoadInst>(StoreVal) && "Not a load of NewGV!"); } } - new StoreInst(StoreVal, NewGV, false, 0, - SI->getOrdering(), SI->getSyncScopeID(), SI); + StoreInst *NSI = + new StoreInst(StoreVal, NewGV, false, 0, SI->getOrdering(), + SI->getSyncScopeID(), SI); + NSI->setDebugLoc(SI->getDebugLoc()); } else { // Change the load into a load of bool then a select. LoadInst *LI = cast<LoadInst>(UI); LoadInst *NLI = new LoadInst(NewGV, LI->getName()+".b", false, 0, LI->getOrdering(), LI->getSyncScopeID(), LI); - Value *NSI; + Instruction *NSI; if (IsOneZero) NSI = new ZExtInst(NLI, LI->getType(), "", LI); else NSI = SelectInst::Create(NLI, OtherVal, InitVal, "", LI); NSI->takeName(LI); + // Since LI is split into two instructions, NLI and NSI both inherit the + // same DebugLoc + NLI->setDebugLoc(LI->getDebugLoc()); + NSI->setDebugLoc(LI->getDebugLoc()); LI->replaceAllUsesWith(NSI); } UI->eraseFromParent(); @@ -2107,6 +2113,13 @@ static bool hasChangeableCC(Function *F) { if (CC != CallingConv::C && CC != CallingConv::X86_ThisCall) return false; + // Don't break the invariant that the inalloca parameter is the only parameter + // passed in memory. + // FIXME: GlobalOpt should remove inalloca when possible and hoist the dynamic + // alloca it uses to the entry block if possible. 
+ if (F->getAttributes().hasAttrSomewhere(Attribute::InAlloca)) + return false; + // FIXME: Change CC for the whole chain of musttail calls when possible. // // Can't change CC of the function that either has musttail calls, or is a diff --git a/lib/Transforms/IPO/HotColdSplitting.cpp b/lib/Transforms/IPO/HotColdSplitting.cpp new file mode 100644 index 000000000000..924a7d5fbd9c --- /dev/null +++ b/lib/Transforms/IPO/HotColdSplitting.cpp @@ -0,0 +1,643 @@ +//===- HotColdSplitting.cpp -- Outline Cold Regions -------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Outline cold regions to a separate function. +// TODO: Update BFI and BPI +// TODO: Add all the outlined functions to a separate section. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/IPO/HotColdSplitting.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/CodeExtractor.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> + +#define DEBUG_TYPE "hotcoldsplit" + +STATISTIC(NumColdRegionsFound, "Number of cold regions found."); +STATISTIC(NumColdRegionsOutlined, "Number of cold regions outlined."); + +using namespace llvm; + +static cl::opt<bool> EnableStaticAnalyis("hot-cold-static-analysis", + cl::init(true), cl::Hidden); + +static cl::opt<int> + MinOutliningThreshold("min-outlining-thresh", cl::init(3), cl::Hidden, + cl::desc("Code size threshold for outlining within a " + "single BB (as a multiple of TCC_Basic)")); + +namespace { + +struct PostDomTree : PostDomTreeBase<BasicBlock> { + PostDomTree(Function &F) { recalculate(F); } +}; + +/// A sequence of basic blocks. +/// +/// A 0-sized SmallVector is slightly cheaper to move than a std::vector. +using BlockSequence = SmallVector<BasicBlock *, 0>; + +// Same as blockEndsInUnreachable in CodeGen/BranchFolding.cpp. 
Do not modify +// this function unless you modify the MBB version as well. +// +/// A no successor, non-return block probably ends in unreachable and is cold. +/// Also consider a block that ends in an indirect branch to be a return block, +/// since many targets use plain indirect branches to return. +bool blockEndsInUnreachable(const BasicBlock &BB) { + if (!succ_empty(&BB)) + return false; + if (BB.empty()) + return true; + const Instruction *I = BB.getTerminator(); + return !(isa<ReturnInst>(I) || isa<IndirectBrInst>(I)); +} + +bool unlikelyExecuted(BasicBlock &BB) { + // Exception handling blocks are unlikely executed. + if (BB.isEHPad()) + return true; + + // The block is cold if it calls/invokes a cold function. + for (Instruction &I : BB) + if (auto CS = CallSite(&I)) + if (CS.hasFnAttr(Attribute::Cold)) + return true; + + // The block is cold if it has an unreachable terminator, unless it's + // preceded by a call to a (possibly warm) noreturn call (e.g. longjmp). + if (blockEndsInUnreachable(BB)) { + if (auto *CI = + dyn_cast_or_null<CallInst>(BB.getTerminator()->getPrevNode())) + if (CI->hasFnAttr(Attribute::NoReturn)) + return false; + return true; + } + + return false; +} + +/// Check whether it's safe to outline \p BB. +static bool mayExtractBlock(const BasicBlock &BB) { + return !BB.hasAddressTaken() && !BB.isEHPad(); +} + +/// Check whether \p Region is profitable to outline. +static bool isProfitableToOutline(const BlockSequence &Region, + TargetTransformInfo &TTI) { + if (Region.size() > 1) + return true; + + int Cost = 0; + const BasicBlock &BB = *Region[0]; + for (const Instruction &I : BB) { + if (isa<DbgInfoIntrinsic>(&I) || &I == BB.getTerminator()) + continue; + + Cost += TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); + + if (Cost >= (MinOutliningThreshold * TargetTransformInfo::TCC_Basic)) + return true; + } + return false; +} + +/// Mark \p F cold. Return true if it's changed. +static bool markEntireFunctionCold(Function &F) { + assert(!F.hasFnAttribute(Attribute::OptimizeNone) && "Can't mark this cold"); + bool Changed = false; + if (!F.hasFnAttribute(Attribute::MinSize)) { + F.addFnAttr(Attribute::MinSize); + Changed = true; + } + // TODO: Move this function into a cold section. 
+ return Changed; +} + +class HotColdSplitting { +public: + HotColdSplitting(ProfileSummaryInfo *ProfSI, + function_ref<BlockFrequencyInfo *(Function &)> GBFI, + function_ref<TargetTransformInfo &(Function &)> GTTI, + std::function<OptimizationRemarkEmitter &(Function &)> *GORE) + : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {} + bool run(Module &M); + +private: + bool shouldOutlineFrom(const Function &F) const; + bool outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, + BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, + DominatorTree &DT, PostDomTree &PDT, + OptimizationRemarkEmitter &ORE); + Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT, + BlockFrequencyInfo *BFI, TargetTransformInfo &TTI, + OptimizationRemarkEmitter &ORE, unsigned Count); + SmallPtrSet<const Function *, 2> OutlinedFunctions; + ProfileSummaryInfo *PSI; + function_ref<BlockFrequencyInfo *(Function &)> GetBFI; + function_ref<TargetTransformInfo &(Function &)> GetTTI; + std::function<OptimizationRemarkEmitter &(Function &)> *GetORE; +}; + +class HotColdSplittingLegacyPass : public ModulePass { +public: + static char ID; + HotColdSplittingLegacyPass() : ModulePass(ID) { + initializeHotColdSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<BlockFrequencyInfoWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + bool runOnModule(Module &M) override; +}; + +} // end anonymous namespace + +// Returns false if the function should not be considered for hot-cold split +// optimization. +bool HotColdSplitting::shouldOutlineFrom(const Function &F) const { + // Do not try to outline again from an already outlined cold function. + if (OutlinedFunctions.count(&F)) + return false; + + if (F.size() <= 2) + return false; + + // TODO: Consider only skipping functions marked `optnone` or `cold`. + + if (F.hasAddressTaken()) + return false; + + if (F.hasFnAttribute(Attribute::AlwaysInline)) + return false; + + if (F.hasFnAttribute(Attribute::NoInline)) + return false; + + if (F.getCallingConv() == CallingConv::Cold) + return false; + + if (PSI->isFunctionEntryCold(&F)) + return false; + return true; +} + +Function *HotColdSplitting::extractColdRegion(const BlockSequence &Region, + DominatorTree &DT, + BlockFrequencyInfo *BFI, + TargetTransformInfo &TTI, + OptimizationRemarkEmitter &ORE, + unsigned Count) { + assert(!Region.empty()); + + // TODO: Pass BFI and BPI to update profile information. + CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr, + /* BPI */ nullptr, /* AllowVarArgs */ false, + /* AllowAlloca */ false, + /* Suffix */ "cold." + std::to_string(Count)); + + SetVector<Value *> Inputs, Outputs, Sinks; + CE.findInputsOutputs(Inputs, Outputs, Sinks); + + // Do not extract regions that have live exit variables. + if (Outputs.size() > 0) { + LLVM_DEBUG(llvm::dbgs() << "Not outlining; live outputs\n"); + return nullptr; + } + + // TODO: Run MergeBasicBlockIntoOnlyPred on the outlined function. 
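For concreteness, a minimal source-level sketch (invented names, not taken from this patch) of the kind of code whose cold path this pass extracts; whether such a small region is actually outlined still depends on the isProfitableToOutline() threshold above:

  __attribute__((cold, noreturn)) void fatal_error(const char *Msg); // assumed diagnostic helper

  int scale(int X) {
    if (X < 0)
      fatal_error("negative input"); // block calls a cold function, so unlikelyExecuted() marks it cold
    return X * 2;                    // hot path remains in scale()
  }

  // Conceptually, the error block is pulled into an outlined function such as
  // scale.cold.1; the call to it is marked noinline and markEntireFunctionCold()
  // adds MinSize to the outlined body.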
+ Function *OrigF = Region[0]->getParent(); + if (Function *OutF = CE.extractCodeRegion()) { + User *U = *OutF->user_begin(); + CallInst *CI = cast<CallInst>(U); + CallSite CS(CI); + NumColdRegionsOutlined++; + if (TTI.useColdCCForColdCall(*OutF)) { + OutF->setCallingConv(CallingConv::Cold); + CS.setCallingConv(CallingConv::Cold); + } + CI->setIsNoInline(); + + // Try to make the outlined code as small as possible on the assumption + // that it's cold. + markEntireFunctionCold(*OutF); + + LLVM_DEBUG(llvm::dbgs() << "Outlined Region: " << *OutF); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "HotColdSplit", + &*Region[0]->begin()) + << ore::NV("Original", OrigF) << " split cold code into " + << ore::NV("Split", OutF); + }); + return OutF; + } + + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ExtractFailed", + &*Region[0]->begin()) + << "Failed to extract region at block " + << ore::NV("Block", Region.front()); + }); + return nullptr; +} + +/// A pair of (basic block, score). +using BlockTy = std::pair<BasicBlock *, unsigned>; + +namespace { +/// A maximal outlining region. This contains all blocks post-dominated by a +/// sink block, the sink block itself, and all blocks dominated by the sink. +class OutliningRegion { + /// A list of (block, score) pairs. A block's score is non-zero iff it's a + /// viable sub-region entry point. Blocks with higher scores are better entry + /// points (i.e. they are more distant ancestors of the sink block). + SmallVector<BlockTy, 0> Blocks = {}; + + /// The suggested entry point into the region. If the region has multiple + /// entry points, all blocks within the region may not be reachable from this + /// entry point. + BasicBlock *SuggestedEntryPoint = nullptr; + + /// Whether the entire function is cold. + bool EntireFunctionCold = false; + + /// Whether or not \p BB could be the entry point of an extracted region. + static bool isViableEntryPoint(BasicBlock &BB) { return !BB.isEHPad(); } + + /// If \p BB is a viable entry point, return \p Score. Return 0 otherwise. + static unsigned getEntryPointScore(BasicBlock &BB, unsigned Score) { + return isViableEntryPoint(BB) ? Score : 0; + } + + /// These scores should be lower than the score for predecessor blocks, + /// because regions starting at predecessor blocks are typically larger. + static constexpr unsigned ScoreForSuccBlock = 1; + static constexpr unsigned ScoreForSinkBlock = 1; + + OutliningRegion(const OutliningRegion &) = delete; + OutliningRegion &operator=(const OutliningRegion &) = delete; + +public: + OutliningRegion() = default; + OutliningRegion(OutliningRegion &&) = default; + OutliningRegion &operator=(OutliningRegion &&) = default; + + static OutliningRegion create(BasicBlock &SinkBB, const DominatorTree &DT, + const PostDomTree &PDT) { + OutliningRegion ColdRegion; + + SmallPtrSet<BasicBlock *, 4> RegionBlocks; + + auto addBlockToRegion = [&](BasicBlock *BB, unsigned Score) { + RegionBlocks.insert(BB); + ColdRegion.Blocks.emplace_back(BB, Score); + assert(RegionBlocks.size() == ColdRegion.Blocks.size() && "Duplicate BB"); + }; + + // The ancestor farthest-away from SinkBB, and also post-dominated by it. + unsigned SinkScore = getEntryPointScore(SinkBB, ScoreForSinkBlock); + ColdRegion.SuggestedEntryPoint = (SinkScore > 0) ? &SinkBB : nullptr; + unsigned BestScore = SinkScore; + + // Visit SinkBB's ancestors using inverse DFS. 
+ auto PredIt = ++idf_begin(&SinkBB); + auto PredEnd = idf_end(&SinkBB); + while (PredIt != PredEnd) { + BasicBlock &PredBB = **PredIt; + bool SinkPostDom = PDT.dominates(&SinkBB, &PredBB); + + // If the predecessor is cold and has no predecessors, the entire + // function must be cold. + if (SinkPostDom && pred_empty(&PredBB)) { + ColdRegion.EntireFunctionCold = true; + return ColdRegion; + } + + // If SinkBB does not post-dominate a predecessor, do not mark the + // predecessor (or any of its predecessors) cold. + if (!SinkPostDom || !mayExtractBlock(PredBB)) { + PredIt.skipChildren(); + continue; + } + + // Keep track of the post-dominated ancestor farthest away from the sink. + // The path length is always >= 2, ensuring that predecessor blocks are + // considered as entry points before the sink block. + unsigned PredScore = getEntryPointScore(PredBB, PredIt.getPathLength()); + if (PredScore > BestScore) { + ColdRegion.SuggestedEntryPoint = &PredBB; + BestScore = PredScore; + } + + addBlockToRegion(&PredBB, PredScore); + ++PredIt; + } + + // Add SinkBB to the cold region. It's considered as an entry point before + // any sink-successor blocks. + addBlockToRegion(&SinkBB, SinkScore); + + // Find all successors of SinkBB dominated by SinkBB using DFS. + auto SuccIt = ++df_begin(&SinkBB); + auto SuccEnd = df_end(&SinkBB); + while (SuccIt != SuccEnd) { + BasicBlock &SuccBB = **SuccIt; + bool SinkDom = DT.dominates(&SinkBB, &SuccBB); + + // Don't allow the backwards & forwards DFSes to mark the same block. + bool DuplicateBlock = RegionBlocks.count(&SuccBB); + + // If SinkBB does not dominate a successor, do not mark the successor (or + // any of its successors) cold. + if (DuplicateBlock || !SinkDom || !mayExtractBlock(SuccBB)) { + SuccIt.skipChildren(); + continue; + } + + unsigned SuccScore = getEntryPointScore(SuccBB, ScoreForSuccBlock); + if (SuccScore > BestScore) { + ColdRegion.SuggestedEntryPoint = &SuccBB; + BestScore = SuccScore; + } + + addBlockToRegion(&SuccBB, SuccScore); + ++SuccIt; + } + + return ColdRegion; + } + + /// Whether this region has nothing to extract. + bool empty() const { return !SuggestedEntryPoint; } + + /// The blocks in this region. + ArrayRef<std::pair<BasicBlock *, unsigned>> blocks() const { return Blocks; } + + /// Whether the entire function containing this region is cold. + bool isEntireFunctionCold() const { return EntireFunctionCold; } + + /// Remove a sub-region from this region and return it as a block sequence. + BlockSequence takeSingleEntrySubRegion(DominatorTree &DT) { + assert(!empty() && !isEntireFunctionCold() && "Nothing to extract"); + + // Remove blocks dominated by the suggested entry point from this region. + // During the removal, identify the next best entry point into the region. + // Ensure that the first extracted block is the suggested entry point. + BlockSequence SubRegion = {SuggestedEntryPoint}; + BasicBlock *NextEntryPoint = nullptr; + unsigned NextScore = 0; + auto RegionEndIt = Blocks.end(); + auto RegionStartIt = remove_if(Blocks, [&](const BlockTy &Block) { + BasicBlock *BB = Block.first; + unsigned Score = Block.second; + bool InSubRegion = + BB == SuggestedEntryPoint || DT.dominates(SuggestedEntryPoint, BB); + if (!InSubRegion && Score > NextScore) { + NextEntryPoint = BB; + NextScore = Score; + } + if (InSubRegion && BB != SuggestedEntryPoint) + SubRegion.push_back(BB); + return InSubRegion; + }); + Blocks.erase(RegionStartIt, RegionEndIt); + + // Update the suggested entry point. 
+ SuggestedEntryPoint = NextEntryPoint; + + return SubRegion; + } +}; +} // namespace + +bool HotColdSplitting::outlineColdRegions(Function &F, ProfileSummaryInfo &PSI, + BlockFrequencyInfo *BFI, + TargetTransformInfo &TTI, + DominatorTree &DT, PostDomTree &PDT, + OptimizationRemarkEmitter &ORE) { + bool Changed = false; + + // The set of cold blocks. + SmallPtrSet<BasicBlock *, 4> ColdBlocks; + + // The worklist of non-intersecting regions left to outline. + SmallVector<OutliningRegion, 2> OutliningWorklist; + + // Set up an RPO traversal. Experimentally, this performs better (outlines + // more) than a PO traversal, because we prevent region overlap by keeping + // the first region to contain a block. + ReversePostOrderTraversal<Function *> RPOT(&F); + + // Find all cold regions. + for (BasicBlock *BB : RPOT) { + // Skip blocks which can't be outlined. + if (!mayExtractBlock(*BB)) + continue; + + // This block is already part of some outlining region. + if (ColdBlocks.count(BB)) + continue; + + bool Cold = PSI.isColdBlock(BB, BFI) || + (EnableStaticAnalyis && unlikelyExecuted(*BB)); + if (!Cold) + continue; + + LLVM_DEBUG({ + dbgs() << "Found a cold block:\n"; + BB->dump(); + }); + + auto Region = OutliningRegion::create(*BB, DT, PDT); + if (Region.empty()) + continue; + + if (Region.isEntireFunctionCold()) { + LLVM_DEBUG(dbgs() << "Entire function is cold\n"); + return markEntireFunctionCold(F); + } + + // If this outlining region intersects with another, drop the new region. + // + // TODO: It's theoretically possible to outline more by only keeping the + // largest region which contains a block, but the extra bookkeeping to do + // this is tricky/expensive. + bool RegionsOverlap = any_of(Region.blocks(), [&](const BlockTy &Block) { + return !ColdBlocks.insert(Block.first).second; + }); + if (RegionsOverlap) + continue; + + OutliningWorklist.emplace_back(std::move(Region)); + ++NumColdRegionsFound; + } + + // Outline single-entry cold regions, splitting up larger regions as needed. 
+ unsigned OutlinedFunctionID = 1; + while (!OutliningWorklist.empty()) { + OutliningRegion Region = OutliningWorklist.pop_back_val(); + assert(!Region.empty() && "Empty outlining region in worklist"); + do { + BlockSequence SubRegion = Region.takeSingleEntrySubRegion(DT); + if (!isProfitableToOutline(SubRegion, TTI)) { + LLVM_DEBUG({ + dbgs() << "Skipping outlining; not profitable to outline\n"; + SubRegion[0]->dump(); + }); + continue; + } + + LLVM_DEBUG({ + dbgs() << "Hot/cold splitting attempting to outline these blocks:\n"; + for (BasicBlock *BB : SubRegion) + BB->dump(); + }); + + Function *Outlined = + extractColdRegion(SubRegion, DT, BFI, TTI, ORE, OutlinedFunctionID); + if (Outlined) { + ++OutlinedFunctionID; + OutlinedFunctions.insert(Outlined); + Changed = true; + } + } while (!Region.empty()); + } + + return Changed; +} + +bool HotColdSplitting::run(Module &M) { + bool Changed = false; + OutlinedFunctions.clear(); + for (auto &F : M) { + if (!shouldOutlineFrom(F)) { + LLVM_DEBUG(llvm::dbgs() << "Skipping " << F.getName() << "\n"); + continue; + } + LLVM_DEBUG(llvm::dbgs() << "Outlining in " << F.getName() << "\n"); + DominatorTree DT(F); + PostDomTree PDT(F); + PDT.recalculate(F); + BlockFrequencyInfo *BFI = GetBFI(F); + TargetTransformInfo &TTI = GetTTI(F); + OptimizationRemarkEmitter &ORE = (*GetORE)(F); + Changed |= outlineColdRegions(F, *PSI, BFI, TTI, DT, PDT, ORE); + } + return Changed; +} + +bool HotColdSplittingLegacyPass::runOnModule(Module &M) { + if (skipModule(M)) + return false; + ProfileSummaryInfo *PSI = + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + auto GTTI = [this](Function &F) -> TargetTransformInfo & { + return this->getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + }; + auto GBFI = [this](Function &F) { + return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); + }; + std::unique_ptr<OptimizationRemarkEmitter> ORE; + std::function<OptimizationRemarkEmitter &(Function &)> GetORE = + [&ORE](Function &F) -> OptimizationRemarkEmitter & { + ORE.reset(new OptimizationRemarkEmitter(&F)); + return *ORE.get(); + }; + + return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M); +} + +PreservedAnalyses +HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) { + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + std::function<AssumptionCache &(Function &)> GetAssumptionCache = + [&FAM](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + + auto GBFI = [&FAM](Function &F) { + return &FAM.getResult<BlockFrequencyAnalysis>(F); + }; + + std::function<TargetTransformInfo &(Function &)> GTTI = + [&FAM](Function &F) -> TargetTransformInfo & { + return FAM.getResult<TargetIRAnalysis>(F); + }; + + std::unique_ptr<OptimizationRemarkEmitter> ORE; + std::function<OptimizationRemarkEmitter &(Function &)> GetORE = + [&ORE](Function &F) -> OptimizationRemarkEmitter & { + ORE.reset(new OptimizationRemarkEmitter(&F)); + return *ORE.get(); + }; + + ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); + + if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +char HotColdSplittingLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(HotColdSplittingLegacyPass, "hotcoldsplit", + "Hot Cold Splitting", false, false) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_END(HotColdSplittingLegacyPass, "hotcoldsplit", + 
"Hot Cold Splitting", false, false) + +ModulePass *llvm::createHotColdSplittingPass() { + return new HotColdSplittingLegacyPass(); +} diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp index dce9ee076bc5..973382e2b097 100644 --- a/lib/Transforms/IPO/IPO.cpp +++ b/lib/Transforms/IPO/IPO.cpp @@ -34,6 +34,7 @@ void llvm::initializeIPO(PassRegistry &Registry) { initializeGlobalDCELegacyPassPass(Registry); initializeGlobalOptLegacyPassPass(Registry); initializeGlobalSplitPass(Registry); + initializeHotColdSplittingLegacyPassPass(Registry); initializeIPCPPass(Registry); initializeAlwaysInlinerLegacyPassPass(Registry); initializeSimpleInlinerPass(Registry); diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 3da0c2e83eb8..66a6f80f31e4 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -64,6 +64,7 @@ #include <algorithm> #include <cassert> #include <functional> +#include <sstream> #include <tuple> #include <utility> #include <vector> @@ -112,6 +113,14 @@ static cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats( "printing of statistics for each inlined function")), cl::Hidden, cl::desc("Enable inliner stats for imported functions")); +/// Flag to add inline messages as callsite attributes 'inline-remark'. +static cl::opt<bool> + InlineRemarkAttribute("inline-remark-attribute", cl::init(false), + cl::Hidden, + cl::desc("Enable adding inline-remark attribute to" + " callsites processed by inliner but decided" + " to be not inlined")); + LegacyInlinerBase::LegacyInlinerBase(char &ID) : CallGraphSCCPass(ID) {} LegacyInlinerBase::LegacyInlinerBase(char &ID, bool InsertLifetime) @@ -263,7 +272,7 @@ static void mergeInlinedArrayAllocas( /// available from other functions inlined into the caller. If we are able to /// inline this call site we attempt to reuse already available allocas or add /// any new allocas to the set if not possible. -static bool InlineCallIfPossible( +static InlineResult InlineCallIfPossible( CallSite CS, InlineFunctionInfo &IFI, InlinedArrayAllocasTy &InlinedArrayAllocas, int InlineHistory, bool InsertLifetime, function_ref<AAResults &(Function &)> &AARGetter, @@ -275,8 +284,9 @@ static bool InlineCallIfPossible( // Try to inline the function. Get the list of static allocas that were // inlined. - if (!InlineFunction(CS, IFI, &AAR, InsertLifetime)) - return false; + InlineResult IR = InlineFunction(CS, IFI, &AAR, InsertLifetime); + if (!IR) + return IR; if (InlinerFunctionImportStats != InlinerFunctionImportStatsOpts::No) ImportedFunctionsStats.recordInline(*Caller, *Callee); @@ -286,7 +296,7 @@ static bool InlineCallIfPossible( if (!DisableInlinedAllocaMerging) mergeInlinedArrayAllocas(Caller, IFI, InlinedArrayAllocas, InlineHistory); - return true; + return IR; // success } /// Return true if inlining of CS can block the caller from being @@ -301,6 +311,11 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, // For now we only handle local or inline functions. if (!Caller->hasLocalLinkage() && !Caller->hasLinkOnceODRLinkage()) return false; + // If the cost of inlining CS is non-positive, it is not going to prevent the + // caller from being inlined into its callers and hence we don't need to + // defer. 
+ if (IC.getCost() <= 0) + return false; // Try to detect the case where the current inlining candidate caller (call // it B) is a static or linkonce-ODR function and is an inlining candidate // elsewhere, and the current candidate callee (call it C) is large enough @@ -320,25 +335,31 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, TotalSecondaryCost = 0; // The candidate cost to be imposed upon the current function. int CandidateCost = IC.getCost() - 1; - // This bool tracks what happens if we do NOT inline C into B. - bool callerWillBeRemoved = Caller->hasLocalLinkage(); + // If the caller has local linkage and can be inlined to all its callers, we + // can apply a huge negative bonus to TotalSecondaryCost. + bool ApplyLastCallBonus = Caller->hasLocalLinkage() && !Caller->hasOneUse(); // This bool tracks what happens if we DO inline C into B. bool inliningPreventsSomeOuterInline = false; for (User *U : Caller->users()) { + // If the caller will not be removed (either because it does not have a + // local linkage or because the LastCallToStaticBonus has been already + // applied), then we can exit the loop early. + if (!ApplyLastCallBonus && TotalSecondaryCost >= IC.getCost()) + return false; CallSite CS2(U); // If this isn't a call to Caller (it could be some other sort // of reference) skip it. Such references will prevent the caller // from being removed. if (!CS2 || CS2.getCalledFunction() != Caller) { - callerWillBeRemoved = false; + ApplyLastCallBonus = false; continue; } InlineCost IC2 = GetInlineCost(CS2); ++NumCallerCallersAnalyzed; if (!IC2) { - callerWillBeRemoved = false; + ApplyLastCallBonus = false; continue; } if (IC2.isAlways()) @@ -356,7 +377,7 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, // one is set very low by getInlineCost, in anticipation that Caller will // be removed entirely. We did not account for this above unless there // is only one caller of Caller. - if (callerWillBeRemoved && !Caller->hasOneUse()) + if (ApplyLastCallBonus) TotalSecondaryCost -= InlineConstants::LastCallToStaticBonus; if (inliningPreventsSomeOuterInline && TotalSecondaryCost < IC.getCost()) @@ -365,6 +386,33 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, return false; } +static std::basic_ostream<char> &operator<<(std::basic_ostream<char> &R, + const ore::NV &Arg) { + return R << Arg.Val; +} + +template <class RemarkT> +RemarkT &operator<<(RemarkT &&R, const InlineCost &IC) { + using namespace ore; + if (IC.isAlways()) { + R << "(cost=always)"; + } else if (IC.isNever()) { + R << "(cost=never)"; + } else { + R << "(cost=" << ore::NV("Cost", IC.getCost()) + << ", threshold=" << ore::NV("Threshold", IC.getThreshold()) << ")"; + } + if (const char *Reason = IC.getReason()) + R << ": " << ore::NV("Reason", Reason); + return R; +} + +static std::string inlineCostStr(const InlineCost &IC) { + std::stringstream Remark; + Remark << IC; + return Remark.str(); +} + /// Return the cost only if the inliner should attempt to inline at the given /// CallSite. If we return the cost, we will emit an optimisation remark later /// using that cost, so we won't do so from this function. 
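To make the deferral reasoning in shouldBeDeferred() above concrete, a small sketch (invented names; the actual decision depends on the cost model's numbers): B is file-local with more than one caller, so inlining a sizeable C into B may cost more in lost B-into-A inlining than it gains now.

  static int C(int X) { /* moderately sized body */ return X * X + 1; }
  static int B(int X) { return C(X) + C(X + 1); }
  int A1(int X) { return B(X); }
  int A2(int X) { return B(X) + 1; }

  // Inlining C into B twice can grow B beyond what A1/A2 will absorb; the
  // TotalSecondaryCost computed above models that lost outer inlining, and the
  // new early exit skips the whole analysis when the candidate's cost is
  // already non-positive.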
@@ -379,35 +427,32 @@ shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, Function *Caller = CS.getCaller(); if (IC.isAlways()) { - LLVM_DEBUG(dbgs() << " Inlining: cost=always" + LLVM_DEBUG(dbgs() << " Inlining " << inlineCostStr(IC) << ", Call: " << *CS.getInstruction() << "\n"); return IC; } if (IC.isNever()) { - LLVM_DEBUG(dbgs() << " NOT Inlining: cost=never" + LLVM_DEBUG(dbgs() << " NOT Inlining " << inlineCostStr(IC) << ", Call: " << *CS.getInstruction() << "\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) << NV("Callee", Callee) << " not inlined into " - << NV("Caller", Caller) - << " because it should never be inlined (cost=never)"; + << NV("Caller", Caller) << " because it should never be inlined " + << IC; }); - return None; + return IC; } if (!IC) { - LLVM_DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost() - << ", thres=" << IC.getThreshold() + LLVM_DEBUG(dbgs() << " NOT Inlining " << inlineCostStr(IC) << ", Call: " << *CS.getInstruction() << "\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call) << NV("Callee", Callee) << " not inlined into " - << NV("Caller", Caller) << " because too costly to inline (cost=" - << NV("Cost", IC.getCost()) - << ", threshold=" << NV("Threshold", IC.getThreshold()) << ")"; + << NV("Caller", Caller) << " because too costly to inline " << IC; }); - return None; + return IC; } int TotalSecondaryCost = 0; @@ -428,8 +473,7 @@ shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, return None; } - LLVM_DEBUG(dbgs() << " Inlining: cost=" << IC.getCost() - << ", thres=" << IC.getThreshold() + LLVM_DEBUG(dbgs() << " Inlining " << inlineCostStr(IC) << ", Call: " << *CS.getInstruction() << '\n'); return IC; } @@ -461,6 +505,26 @@ bool LegacyInlinerBase::runOnSCC(CallGraphSCC &SCC) { return inlineCalls(SCC); } +static void emit_inlined_into(OptimizationRemarkEmitter &ORE, DebugLoc &DLoc, + const BasicBlock *Block, const Function &Callee, + const Function &Caller, const InlineCost &IC) { + ORE.emit([&]() { + bool AlwaysInline = IC.isAlways(); + StringRef RemarkName = AlwaysInline ? "AlwaysInline" : "Inlined"; + return OptimizationRemark(DEBUG_TYPE, RemarkName, DLoc, Block) + << ore::NV("Callee", &Callee) << " inlined into " + << ore::NV("Caller", &Caller) << " with " << IC; + }); +} + +static void setInlineRemark(CallSite &CS, StringRef message) { + if (!InlineRemarkAttribute) + return; + + Attribute attr = Attribute::get(CS->getContext(), "inline-remark", message); + CS.addAttribute(AttributeList::FunctionIndex, attr); +} + static bool inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, std::function<AssumptionCache &(Function &)> GetAssumptionCache, @@ -510,6 +574,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, if (Callee->isDeclaration()) { using namespace ore; + setInlineRemark(CS, "unavailable definition"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) << NV("Callee", Callee) << " will not be inlined into " @@ -573,8 +638,10 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, // infinitely inline. 
InlineHistoryID = CallSites[CSi].second; if (InlineHistoryID != -1 && - InlineHistoryIncludes(Callee, InlineHistoryID, InlineHistory)) + InlineHistoryIncludes(Callee, InlineHistoryID, InlineHistory)) { + setInlineRemark(CS, "recursive"); continue; + } } // FIXME for new PM: because of the old PM we currently generate ORE and @@ -585,8 +652,17 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, Optional<InlineCost> OIC = shouldInline(CS, GetInlineCost, ORE); // If the policy determines that we should inline this function, // delete the call instead. - if (!OIC) + if (!OIC.hasValue()) { + setInlineRemark(CS, "deferred"); + continue; + } + + if (!OIC.getValue()) { + // shouldInline() call returned a negative inline cost that explains + // why this callsite should not be inlined. + setInlineRemark(CS, inlineCostStr(*OIC)); continue; + } // If this call site is dead and it is to a readonly function, we should // just delete the call instead of trying to inline it, regardless of @@ -595,6 +671,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, if (IsTriviallyDead) { LLVM_DEBUG(dbgs() << " -> Deleting dead call: " << *Instr << "\n"); // Update the call graph by deleting the edge from Callee to Caller. + setInlineRemark(CS, "trivially dead"); CG[Caller]->removeCallEdgeFor(CS); Instr->eraseFromParent(); ++NumCallsDeleted; @@ -606,34 +683,22 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, // Attempt to inline the function. using namespace ore; - if (!InlineCallIfPossible(CS, InlineInfo, InlinedArrayAllocas, - InlineHistoryID, InsertLifetime, AARGetter, - ImportedFunctionsStats)) { + InlineResult IR = InlineCallIfPossible( + CS, InlineInfo, InlinedArrayAllocas, InlineHistoryID, + InsertLifetime, AARGetter, ImportedFunctionsStats); + if (!IR) { + setInlineRemark(CS, std::string(IR) + "; " + inlineCostStr(*OIC)); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) << NV("Callee", Callee) << " will not be inlined into " - << NV("Caller", Caller); + << NV("Caller", Caller) << ": " << NV("Reason", IR.message); }); continue; } ++NumInlined; - ORE.emit([&]() { - bool AlwaysInline = OIC->isAlways(); - StringRef RemarkName = AlwaysInline ? "AlwaysInline" : "Inlined"; - OptimizationRemark R(DEBUG_TYPE, RemarkName, DLoc, Block); - R << NV("Callee", Callee) << " inlined into "; - R << NV("Caller", Caller); - if (AlwaysInline) - R << " with cost=always"; - else { - R << " with cost=" << NV("Cost", OIC->getCost()); - R << " (threshold=" << NV("Threshold", OIC->getThreshold()); - R << ")"; - } - return R; - }); + emit_inlined_into(ORE, DLoc, Block, *Callee, *Caller, *OIC); // If inlining this function gave us any new call sites, throw them // onto our worklist to process. They are useful inline candidates. 
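Because the new remark strings are attached as plain string attributes on the call site, a later pass or out-of-tree tool could read them back roughly as follows (a sketch assuming -inline-remark-attribute was enabled and CS is the llvm::CallSite of a call the inliner rejected):

  #include "llvm/IR/Attributes.h"
  #include "llvm/IR/CallSite.h"
  #include "llvm/Support/raw_ostream.h"

  void printInlineRemark(llvm::CallSite CS) {
    llvm::AttributeList AL = CS.getAttributes();
    if (!AL.hasAttribute(llvm::AttributeList::FunctionIndex, "inline-remark"))
      return;
    llvm::StringRef Remark =
        AL.getAttribute(llvm::AttributeList::FunctionIndex, "inline-remark")
            .getValueAsString();
    // Remark holds strings such as "unavailable definition", "recursive", or
    // "(cost=..., threshold=...)" produced by inlineCostStr() above.
    llvm::errs() << "not inlined: " << Remark << "\n";
  }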
@@ -692,7 +757,7 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, bool LegacyInlinerBase::inlineCalls(CallGraphSCC &SCC) { CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); ACT = &getAnalysis<AssumptionCacheTracker>(); - PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); @@ -865,6 +930,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Calls.push_back({CS, -1}); else if (!isa<IntrinsicInst>(I)) { using namespace ore; + setInlineRemark(CS, "unavailable definition"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) << NV("Callee", Callee) << " will not be inlined into " @@ -908,8 +974,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, LazyCallGraph::Node &N = *CG.lookup(F); if (CG.lookupSCC(N) != C) continue; - if (F.hasFnAttribute(Attribute::OptimizeNone)) + if (F.hasFnAttribute(Attribute::OptimizeNone)) { + setInlineRemark(Calls[i].first, "optnone attribute"); continue; + } LLVM_DEBUG(dbgs() << "Inlining calls in: " << F.getName() << "\n"); @@ -953,8 +1021,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Function &Callee = *CS.getCalledFunction(); if (InlineHistoryID != -1 && - InlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) + InlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) { + setInlineRemark(CS, "recursive"); continue; + } // Check if this inlining may repeat breaking an SCC apart that has // already been split once before. In that case, inlining here may @@ -966,13 +1036,23 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, LLVM_DEBUG(dbgs() << "Skipping inlining internal SCC edge from a node " "previously split out of this SCC by inlining: " << F.getName() << " -> " << Callee.getName() << "\n"); + setInlineRemark(CS, "recursive SCC split"); continue; } Optional<InlineCost> OIC = shouldInline(CS, GetInlineCost, ORE); // Check whether we want to inline this callsite. - if (!OIC) + if (!OIC.hasValue()) { + setInlineRemark(CS, "deferred"); + continue; + } + + if (!OIC.getValue()) { + // shouldInline() call returned a negative inline cost that explains + // why this callsite should not be inlined. + setInlineRemark(CS, inlineCostStr(*OIC)); continue; + } // Setup the data structure used to plumb customization into the // `InlineFunction` routine. @@ -987,32 +1067,22 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, using namespace ore; - if (!InlineFunction(CS, IFI)) { + InlineResult IR = InlineFunction(CS, IFI); + if (!IR) { + setInlineRemark(CS, std::string(IR) + "; " + inlineCostStr(*OIC)); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) << NV("Callee", &Callee) << " will not be inlined into " - << NV("Caller", &F); + << NV("Caller", &F) << ": " << NV("Reason", IR.message); }); continue; } DidInline = true; InlinedCallees.insert(&Callee); - ORE.emit([&]() { - bool AlwaysInline = OIC->isAlways(); - StringRef RemarkName = AlwaysInline ? 
"AlwaysInline" : "Inlined"; - OptimizationRemark R(DEBUG_TYPE, RemarkName, DLoc, Block); - R << NV("Callee", &Callee) << " inlined into "; - R << NV("Caller", &F); - if (AlwaysInline) - R << " with cost=always"; - else { - R << " with cost=" << NV("Cost", OIC->getCost()); - R << " (threshold=" << NV("Threshold", OIC->getThreshold()); - R << ")"; - } - return R; - }); + ++NumInlined; + + emit_inlined_into(ORE, DLoc, Block, Callee, F, *OIC); // Add any new callsites to defined functions to the worklist. if (!IFI.InlinedCallSites.empty()) { @@ -1099,10 +1169,19 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // SCC splits and merges. To avoid this, we capture the originating caller // node and the SCC containing the call edge. This is a slight over // approximation of the possible inlining decisions that must be avoided, - // but is relatively efficient to store. + // but is relatively efficient to store. We use C != OldC to know when + // a new SCC is generated and the original SCC may be generated via merge + // in later iterations. + // + // It is also possible that even if no new SCC is generated + // (i.e., C == OldC), the original SCC could be split and then merged + // into the same one as itself. and the original SCC will be added into + // UR.CWorklist again, we want to catch such cases too. + // // FIXME: This seems like a very heavyweight way of retaining the inline // history, we should look for a more efficient way of tracking it. - if (C != OldC && llvm::any_of(InlinedCallees, [&](Function *Callee) { + if ((C != OldC || UR.CWorklist.count(OldC)) && + llvm::any_of(InlinedCallees, [&](Function *Callee) { return CG.lookupSCC(*CG.lookup(*Callee)) == OldC; })) { LLVM_DEBUG(dbgs() << "Inlined an internal call edge and split an SCC, " @@ -1138,6 +1217,7 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, // And delete the actual function from the module. M.getFunctionList().erase(DeadF); + ++NumDeleted; } if (!Changed) diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index 8c86f7cb806a..733235d45a09 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -104,8 +104,8 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { bool ShouldExtractLoop = false; // Extract the loop if the entry block doesn't branch to the loop header. - TerminatorInst *EntryTI = - L->getHeader()->getParent()->getEntryBlock().getTerminator(); + Instruction *EntryTI = + L->getHeader()->getParent()->getEntryBlock().getTerminator(); if (!isa<BranchInst>(EntryTI) || !cast<BranchInst>(EntryTI)->isUnconditional() || EntryTI->getSuccessor(0) != L->getHeader()) { diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index 4f7571884707..87c65db09517 100644 --- a/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -989,6 +989,7 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { if (F->isDSOLocal()) { Function *RealF = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, + F->getAddressSpace(), Name + ".cfi", &M); RealF->setVisibility(GlobalVariable::HiddenVisibility); replaceDirectCalls(F, RealF); @@ -1000,13 +1001,13 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { if (F->isDeclarationForLinker() && !isDefinition) { // Declaration of an external function. 
FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, - Name + ".cfi_jt", &M); + F->getAddressSpace(), Name + ".cfi_jt", &M); FDecl->setVisibility(GlobalValue::HiddenVisibility); } else if (isDefinition) { F->setName(Name + ".cfi"); F->setLinkage(GlobalValue::ExternalLinkage); FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage, - Name, &M); + F->getAddressSpace(), Name, &M); FDecl->setVisibility(Visibility); Visibility = GlobalValue::HiddenVisibility; @@ -1016,7 +1017,8 @@ void LowerTypeTestsModule::importFunction(Function *F, bool isDefinition) { for (auto &U : F->uses()) { if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) { Function *AliasDecl = Function::Create( - F->getFunctionType(), GlobalValue::ExternalLinkage, "", &M); + F->getFunctionType(), GlobalValue::ExternalLinkage, + F->getAddressSpace(), "", &M); AliasDecl->takeName(A); A->replaceAllUsesWith(AliasDecl); ToErase.push_back(A); @@ -1191,7 +1193,9 @@ void LowerTypeTestsModule::moveInitializerToModuleConstructor( WeakInitializerFn = Function::Create( FunctionType::get(Type::getVoidTy(M.getContext()), /* IsVarArg */ false), - GlobalValue::InternalLinkage, "__cfi_global_var_init", &M); + GlobalValue::InternalLinkage, + M.getDataLayout().getProgramAddressSpace(), + "__cfi_global_var_init", &M); BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", WeakInitializerFn); ReturnInst::Create(M.getContext(), BB); @@ -1234,7 +1238,8 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( // placeholder first. Function *PlaceholderFn = Function::Create(cast<FunctionType>(F->getValueType()), - GlobalValue::ExternalWeakLinkage, "", &M); + GlobalValue::ExternalWeakLinkage, + F->getAddressSpace(), "", &M); replaceCfiUses(F, PlaceholderFn, IsDefinition); Constant *Target = ConstantExpr::getSelect( @@ -1424,7 +1429,9 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( Function *JumpTableFn = Function::Create(FunctionType::get(Type::getVoidTy(M.getContext()), /* IsVarArg */ false), - GlobalValue::PrivateLinkage, ".cfi.jumptable", &M); + GlobalValue::PrivateLinkage, + M.getDataLayout().getProgramAddressSpace(), + ".cfi.jumptable", &M); ArrayType *JumpTableType = ArrayType::get(getJumpTableEntryType(), Functions.size()); auto JumpTable = @@ -1695,6 +1702,13 @@ bool LowerTypeTestsModule::lower() { !ExportSummary && !ImportSummary) return false; + // If only some of the modules were split, we cannot correctly handle + // code that contains type tests. + if (TypeTestFunc && !TypeTestFunc->use_empty() && + ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || + (ImportSummary && ImportSummary->partiallySplitLTOUnits()))) + report_fatal_error("inconsistent LTO Unit splitting with llvm.type.test"); + if (ImportSummary) { if (TypeTestFunc) { for (auto UI = TypeTestFunc->use_begin(), UE = TypeTestFunc->use_end(); @@ -1813,7 +1827,8 @@ bool LowerTypeTestsModule::lower() { if (!F) F = Function::Create( FunctionType::get(Type::getVoidTy(M.getContext()), false), - GlobalVariable::ExternalLinkage, FunctionName, &M); + GlobalVariable::ExternalLinkage, + M.getDataLayout().getProgramAddressSpace(), FunctionName, &M); // If the function is available_externally, remove its definition so // that it is handled the same way as a declaration. 
Later we will try @@ -1997,7 +2012,7 @@ bool LowerTypeTestsModule::lower() { } Sets.emplace_back(I, MaxUniqueId); } - llvm::sort(Sets.begin(), Sets.end(), + llvm::sort(Sets, [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1, const std::pair<GlobalClassesTy::iterator, unsigned> &S2) { return S1.second < S2.second; @@ -2022,12 +2037,12 @@ bool LowerTypeTestsModule::lower() { // Order type identifiers by unique ID for determinism. This ordering is // stable as there is a one-to-one mapping between metadata and unique IDs. - llvm::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) { + llvm::sort(TypeIds, [&](Metadata *M1, Metadata *M2) { return TypeIdInfo[M1].UniqueId < TypeIdInfo[M2].UniqueId; }); // Same for the branch funnels. - llvm::sort(ICallBranchFunnels.begin(), ICallBranchFunnels.end(), + llvm::sort(ICallBranchFunnels, [&](ICallBranchFunnel *F1, ICallBranchFunnel *F2) { return F1->UniqueId < F2->UniqueId; }); diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 3bebb96c6d35..11efe95b10d4 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -136,6 +136,7 @@ using namespace llvm; STATISTIC(NumFunctionsMerged, "Number of functions merged"); STATISTIC(NumThunksWritten, "Number of thunks generated"); +STATISTIC(NumAliasesWritten, "Number of aliases generated"); STATISTIC(NumDoubleWeak, "Number of new functions created"); static cl::opt<unsigned> NumFunctionsForSanityCheck( @@ -165,6 +166,11 @@ static cl::opt<bool> cl::desc("Preserve debug info in thunk when mergefunc " "transformations are made.")); +static cl::opt<bool> + MergeFunctionsAliases("mergefunc-use-aliases", cl::Hidden, + cl::init(false), + cl::desc("Allow mergefunc to create aliases")); + namespace { class FunctionNode { @@ -272,6 +278,13 @@ private: /// delete G. void writeThunk(Function *F, Function *G); + // Replace G with an alias to F (deleting function G) + void writeAlias(Function *F, Function *G); + + // Replace G with an alias to F if possible, or a thunk to F if + // profitable. Returns false if neither is the case. + bool writeThunkOrAlias(Function *F, Function *G); + /// Replace function F with function G in the function tree. void replaceFunctionInTree(const FunctionNode &FN, Function *G); @@ -284,7 +297,7 @@ private: // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid // dangling iterators into FnTree. The invariant that preserves this is that // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree. 
- ValueMap<Function*, FnTreeType::iterator> FNodesInTree; + DenseMap<AssertingVH<Function>, FnTreeType::iterator> FNodesInTree; }; } // end anonymous namespace @@ -425,6 +438,7 @@ bool MergeFunctions::runOnModule(Module &M) { } while (!Deferred.empty()); FnTree.clear(); + FNodesInTree.clear(); GlobalNumbers.clear(); return Changed; @@ -460,7 +474,7 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { NewPAL.getRetAttributes(), NewArgAttrs)); - remove(CS.getInstruction()->getParent()->getParent()); + remove(CS.getInstruction()->getFunction()); U->set(BitcastNew); } } @@ -608,7 +622,7 @@ void MergeFunctions::filterInstsUnrelatedToPDI( LLVM_DEBUG(BI->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); } - } else if (dyn_cast<TerminatorInst>(BI) == GEntryBlock->getTerminator()) { + } else if (BI->isTerminator() && &*BI == GEntryBlock->getTerminator()) { LLVM_DEBUG(dbgs() << " Will Include Terminator: "); LLVM_DEBUG(BI->print(dbgs())); LLVM_DEBUG(dbgs() << "\n"); @@ -679,8 +693,8 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { GEntryBlock->getTerminator()->eraseFromParent(); BB = GEntryBlock; } else { - NewG = Function::Create(G->getFunctionType(), G->getLinkage(), "", - G->getParent()); + NewG = Function::Create(G->getFunctionType(), G->getLinkage(), + G->getAddressSpace(), "", G->getParent()); BB = BasicBlock::Create(F->getContext(), "", NewG); } @@ -734,27 +748,76 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { ++NumThunksWritten; } +// Whether this function may be replaced by an alias +static bool canCreateAliasFor(Function *F) { + if (!MergeFunctionsAliases || !F->hasGlobalUnnamedAddr()) + return false; + + // We should only see linkages supported by aliases here + assert(F->hasLocalLinkage() || F->hasExternalLinkage() + || F->hasWeakLinkage() || F->hasLinkOnceLinkage()); + return true; +} + +// Replace G with an alias to F (deleting function G) +void MergeFunctions::writeAlias(Function *F, Function *G) { + Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); + PointerType *PtrType = G->getType(); + auto *GA = GlobalAlias::create( + PtrType->getElementType(), PtrType->getAddressSpace(), + G->getLinkage(), "", BitcastF, G->getParent()); + + F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); + GA->takeName(G); + GA->setVisibility(G->getVisibility()); + GA->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + + removeUsers(G); + G->replaceAllUsesWith(GA); + G->eraseFromParent(); + + LLVM_DEBUG(dbgs() << "writeAlias: " << GA->getName() << '\n'); + ++NumAliasesWritten; +} + +// Replace G with an alias to F if possible, or a thunk to F if +// profitable. Returns false if neither is the case. +bool MergeFunctions::writeThunkOrAlias(Function *F, Function *G) { + if (canCreateAliasFor(G)) { + writeAlias(F, G); + return true; + } + if (isThunkProfitable(F)) { + writeThunk(F, G); + return true; + } + return false; +} + // Merge two equivalent functions. Upon completion, Function G is deleted. void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { if (F->isInterposable()) { assert(G->isInterposable()); - if (!isThunkProfitable(F)) { + // Both writeThunkOrAlias() calls below must succeed, either because we can + // create aliases for G and NewF, or because a thunk for F is profitable. + // F here has the same signature as NewF below, so that's what we check. + if (!isThunkProfitable(F) && (!canCreateAliasFor(F) || !canCreateAliasFor(G))) { return; } // Make them both thunks to the same internal function. 
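For intuition about the writeThunkOrAlias()/canCreateAliasFor() path above, a small sketch (invented names; assumes the two functions compare equal under FunctionComparator and have unnamed_addr so an alias is allowed):

  // Two structurally identical functions.
  int sum3(int A, int B, int C) { return A + B + C; }
  int add3(int X, int Y, int Z) { return X + Y + Z; }
  // Without -mergefunc-use-aliases the losing copy is rewritten as a thunk that
  // tail-calls the surviving one; with the flag it becomes a GlobalAlias of the
  // survivor instead, saving the extra call.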
- Function *H = Function::Create(F->getFunctionType(), F->getLinkage(), "", - F->getParent()); - H->copyAttributesFrom(F); - H->takeName(F); + Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), + F->getAddressSpace(), "", F->getParent()); + NewF->copyAttributesFrom(F); + NewF->takeName(F); removeUsers(F); - F->replaceAllUsesWith(H); + F->replaceAllUsesWith(NewF); - unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment()); + unsigned MaxAlignment = std::max(G->getAlignment(), NewF->getAlignment()); - writeThunk(F, G); - writeThunk(F, H); + writeThunkOrAlias(F, G); + writeThunkOrAlias(F, NewF); F->setAlignment(MaxAlignment); F->setLinkage(GlobalValue::PrivateLinkage); @@ -770,6 +833,7 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { GlobalNumbers.erase(G); // If G's address is not significant, replace it entirely. Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); + removeUsers(G); G->replaceAllUsesWith(BitcastF); } else { // Redirect direct callers of G to F. (See note on MergeFunctionsPDI @@ -781,18 +845,15 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { // If G was internal then we may have replaced all uses of G with F. If so, // stop here and delete G. There's no need for a thunk. (See note on // MergeFunctionsPDI above). - if (G->hasLocalLinkage() && G->use_empty() && !MergeFunctionsPDI) { + if (G->isDiscardableIfUnused() && G->use_empty() && !MergeFunctionsPDI) { G->eraseFromParent(); ++NumFunctionsMerged; return; } - if (!isThunkProfitable(F)) { - return; + if (writeThunkOrAlias(F, G)) { + ++NumFunctionsMerged; } - - writeThunk(F, G); - ++NumFunctionsMerged; } } @@ -816,6 +877,24 @@ void MergeFunctions::replaceFunctionInTree(const FunctionNode &FN, FN.replaceBy(G); } +// Ordering for functions that are equal under FunctionComparator +static bool isFuncOrderCorrect(const Function *F, const Function *G) { + if (F->isInterposable() != G->isInterposable()) { + // Strong before weak, because the weak function may call the strong + // one, but not the other way around. + return !F->isInterposable(); + } + if (F->hasLocalLinkage() != G->hasLocalLinkage()) { + // External before local, because we definitely have to keep the external + // function, but may be able to drop the local one. + return !F->hasLocalLinkage(); + } + // Impose a total order (by name) on the replacement of functions. This is + // important when operating on more than one module independently to prevent + // cycles of thunks calling each other when the modules are linked together. + return F->getName() <= G->getName(); +} + // Insert a ComparableFunction into the FnTree, or merge it away if equal to one // that was already inserted. bool MergeFunctions::insert(Function *NewFunction) { @@ -832,14 +911,7 @@ bool MergeFunctions::insert(Function *NewFunction) { const FunctionNode &OldF = *Result.first; - // Impose a total order (by name) on the replacement of functions. This is - // important when operating on more than one module independently to prevent - // cycles of thunks calling each other when the modules are linked together. - // - // First of all, we process strong functions before weak functions. - if ((OldF.getFunc()->isInterposable() && !NewFunction->isInterposable()) || - (OldF.getFunc()->isInterposable() == NewFunction->isInterposable() && - OldF.getFunc()->getName() > NewFunction->getName())) { + if (!isFuncOrderCorrect(OldF.getFunc(), NewFunction)) { // Swap the two functions. 
Function *F = OldF.getFunc(); replaceFunctionInTree(*Result.first, NewFunction); @@ -882,7 +954,7 @@ void MergeFunctions::removeUsers(Value *V) { for (User *U : V->users()) { if (Instruction *I = dyn_cast<Instruction>(U)) { - remove(I->getParent()->getParent()); + remove(I->getFunction()); } else if (isa<GlobalValue>(U)) { // do nothing } else if (Constant *C = dyn_cast<Constant>(U)) { diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index 4907e4b30519..da214a1d3b44 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -359,7 +359,7 @@ struct PartialInlinerLegacyPass : public ModulePass { TargetTransformInfoWrapperPass *TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>(); ProfileSummaryInfo *PSI = - getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & { @@ -403,7 +403,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, auto IsSingleEntry = [](SmallVectorImpl<BasicBlock *> &BlockList) { BasicBlock *Dom = BlockList.front(); - return BlockList.size() > 1 && pred_size(Dom) == 1; + return BlockList.size() > 1 && Dom->hasNPredecessors(1); }; auto IsSingleExit = @@ -468,7 +468,7 @@ PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F, // Only consider regions with predecessor blocks that are considered // not-cold (default: part of the top 99.99% of all block counters) // AND greater than our minimum block execution count (default: 100). - if (PSI->isColdBB(thisBB, BFI) || + if (PSI->isColdBlock(thisBB, BFI) || BBProfileCount(thisBB) < MinBlockCounterExecution) continue; for (auto SI = succ_begin(thisBB); SI != succ_end(thisBB); ++SI) { @@ -556,7 +556,7 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { }; auto IsReturnBlock = [](BasicBlock *BB) { - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); return isa<ReturnInst>(TI); }; @@ -834,42 +834,37 @@ bool PartialInlinerImpl::shouldPartialInline( int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { int InlineCost = 0; const DataLayout &DL = BB->getParent()->getParent()->getDataLayout(); - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - if (isa<DbgInfoIntrinsic>(I)) - continue; - - switch (I->getOpcode()) { + for (Instruction &I : BB->instructionsWithoutDebug()) { + // Skip free instructions. 
+ switch (I.getOpcode()) { case Instruction::BitCast: case Instruction::PtrToInt: case Instruction::IntToPtr: case Instruction::Alloca: + case Instruction::PHI: continue; case Instruction::GetElementPtr: - if (cast<GetElementPtrInst>(I)->hasAllZeroIndices()) + if (cast<GetElementPtrInst>(&I)->hasAllZeroIndices()) continue; break; default: break; } - IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(I); - if (IntrInst) { - if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start || - IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) - continue; - } + if (I.isLifetimeStartOrEnd()) + continue; - if (CallInst *CI = dyn_cast<CallInst>(I)) { + if (CallInst *CI = dyn_cast<CallInst>(&I)) { InlineCost += getCallsiteCost(CallSite(CI), DL); continue; } - if (InvokeInst *II = dyn_cast<InvokeInst>(I)) { + if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) { InlineCost += getCallsiteCost(CallSite(II), DL); continue; } - if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { + if (SwitchInst *SI = dyn_cast<SwitchInst>(&I)) { InlineCost += (SI->getNumCases() + 1) * InlineConstants::InstrCost; continue; } @@ -1251,7 +1246,7 @@ std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { if (PSI->isFunctionEntryCold(F)) return {false, nullptr}; - if (F->user_begin() == F->user_end()) + if (empty(F->users())) return {false, nullptr}; OptimizationRemarkEmitter ORE(F); @@ -1357,7 +1352,7 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { return false; } - assert(Cloner.OrigFunc->user_begin() == Cloner.OrigFunc->user_end() && + assert(empty(Cloner.OrigFunc->users()) && "F's users should all be replaced!"); std::vector<User *> Users(Cloner.ClonedFunc->user_begin(), @@ -1461,9 +1456,7 @@ bool PartialInlinerImpl::run(Module &M) { std::pair<bool, Function * > Result = unswitchFunction(CurrFunc); if (Result.second) Worklist.push_back(Result.second); - if (Result.first) { - Changed = true; - } + Changed |= Result.first; } return Changed; diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 5ced6481996a..9764944dc332 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -104,6 +104,10 @@ static cl::opt<bool> EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, cl::desc("Enable preparation for ThinLTO.")); +cl::opt<bool> EnableHotColdSplit("hot-cold-split", cl::init(false), cl::Hidden, + cl::desc("Enable hot-cold splitting pass")); + + static cl::opt<bool> RunPGOInstrGen( "profile-generate", cl::init(false), cl::Hidden, cl::desc("Enable PGO instrumentation.")); @@ -152,6 +156,10 @@ static cl::opt<bool> EnableGVNSink( "enable-gvn-sink", cl::init(false), cl::Hidden, cl::desc("Enable the GVN sinking pass (default = off)")); +static cl::opt<bool> + EnableCHR("enable-chr", cl::init(true), cl::Hidden, + cl::desc("Enable control height reduction optimization (CHR)")); + PassManagerBuilder::PassManagerBuilder() { OptLevel = 2; SizeLevel = 0; @@ -367,13 +375,11 @@ void PassManagerBuilder::addFunctionSimplificationPasses( addExtensionsToPM(EP_LateLoopOptimizations, MPM); MPM.add(createLoopDeletionPass()); // Delete dead loops - if (EnableLoopInterchange) { - // FIXME: These are function passes and break the loop pass pipeline. 
+ if (EnableLoopInterchange) MPM.add(createLoopInterchangePass()); // Interchange loops - MPM.add(createCFGSimplificationPass()); - } - if (!DisableUnrollLoops) - MPM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops + + MPM.add(createSimpleLoopUnrollPass(OptLevel, + DisableUnrollLoops)); // Unroll small loops addExtensionsToPM(EP_LoopOptimizerEnd, MPM); // This ends the loop pass pipelines. @@ -411,6 +417,10 @@ void PassManagerBuilder::addFunctionSimplificationPasses( // Clean up after everything. addInstructionCombiningPass(MPM); addExtensionsToPM(EP_Peephole, MPM); + + if (EnableCHR && OptLevel >= 3 && + (!PGOInstrUse.empty() || !PGOSampleUse.empty())) + MPM.add(createControlHeightReductionLegacyPass()); } void PassManagerBuilder::populateModulePassManager( @@ -452,12 +462,14 @@ void PassManagerBuilder::populateModulePassManager( addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); - // Rename anon globals to be able to export them in the summary. - // This has to be done after we add the extensions to the pass manager - // as there could be passes (e.g. Adddress sanitizer) which introduce - // new unnamed globals. - if (PrepareForLTO || PrepareForThinLTO) + if (PrepareForLTO || PrepareForThinLTO) { + MPM.add(createCanonicalizeAliasesPass()); + // Rename anon globals to be able to export them in the summary. + // This has to be done after we add the extensions to the pass manager + // as there could be passes (e.g. Adddress sanitizer) which introduce + // new unnamed globals. MPM.add(createNameAnonGlobalPass()); + } return; } @@ -575,6 +587,7 @@ void PassManagerBuilder::populateModulePassManager( // Ensure we perform any last passes, but do so before renaming anonymous // globals in case the passes add any. addExtensionsToPM(EP_OptimizerLast, MPM); + MPM.add(createCanonicalizeAliasesPass()); // Rename anon globals to be able to export them in the summary. MPM.add(createNameAnonGlobalPass()); return; @@ -627,7 +640,7 @@ void PassManagerBuilder::populateModulePassManager( // llvm.loop.distribute=true or when -enable-loop-distribute is specified. MPM.add(createLoopDistributePass()); - MPM.add(createLoopVectorizePass(DisableUnrollLoops, LoopVectorize)); + MPM.add(createLoopVectorizePass(DisableUnrollLoops, !LoopVectorize)); // Eliminate loads by forwarding stores from the previous iteration to loads // of the current iteration. @@ -672,16 +685,17 @@ void PassManagerBuilder::populateModulePassManager( addExtensionsToPM(EP_Peephole, MPM); addInstructionCombiningPass(MPM); - if (!DisableUnrollLoops) { - if (EnableUnrollAndJam) { - // Unroll and Jam. We do this before unroll but need to be in a separate - // loop pass manager in order for the outer loop to be processed by - // unroll and jam before the inner loop is unrolled. - MPM.add(createLoopUnrollAndJamPass(OptLevel)); - } + if (EnableUnrollAndJam && !DisableUnrollLoops) { + // Unroll and Jam. We do this before unroll but need to be in a separate + // loop pass manager in order for the outer loop to be processed by + // unroll and jam before the inner loop is unrolled. + MPM.add(createLoopUnrollAndJamPass(OptLevel)); + } - MPM.add(createLoopUnrollPass(OptLevel)); // Unroll small loops + MPM.add(createLoopUnrollPass(OptLevel, + DisableUnrollLoops)); // Unroll small loops + if (!DisableUnrollLoops) { // LoopUnroll may generate some redundency to cleanup. addInstructionCombiningPass(MPM); @@ -690,7 +704,9 @@ void PassManagerBuilder::populateModulePassManager( // outer loop. 
LICM pass can help to promote the runtime check out if the // checked value is loop invariant. MPM.add(createLICMPass()); - } + } + + MPM.add(createWarnMissedTransformationsPass()); // After vectorization and unrolling, assume intrinsics may tell us more // about pointer alignments. @@ -722,18 +738,29 @@ void PassManagerBuilder::populateModulePassManager( // flattening of blocks. MPM.add(createDivRemPairsPass()); + if (EnableHotColdSplit) + MPM.add(createHotColdSplittingPass()); + // LoopSink (and other loop passes since the last simplifyCFG) might have // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. MPM.add(createCFGSimplificationPass()); addExtensionsToPM(EP_OptimizerLast, MPM); - // Rename anon globals to be able to handle them in the summary - if (PrepareForLTO) + if (PrepareForLTO) { + MPM.add(createCanonicalizeAliasesPass()); + // Rename anon globals to be able to handle them in the summary MPM.add(createNameAnonGlobalPass()); + } } void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { + // Load sample profile before running the LTO optimization pipeline. + if (!PGOSampleUse.empty()) { + PM.add(createPruneEHPass()); + PM.add(createSampleProfileLoaderPass(PGOSampleUse)); + } + // Remove unused virtual tables to improve the quality of code generated by // whole-program devirtualization and bitset lowering. PM.add(createGlobalDCEPass()); @@ -851,12 +878,13 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { if (EnableLoopInterchange) PM.add(createLoopInterchangePass()); - if (!DisableUnrollLoops) - PM.add(createSimpleLoopUnrollPass(OptLevel)); // Unroll small loops - PM.add(createLoopVectorizePass(true, LoopVectorize)); + PM.add(createSimpleLoopUnrollPass(OptLevel, + DisableUnrollLoops)); // Unroll small loops + PM.add(createLoopVectorizePass(true, !LoopVectorize)); // The vectorizer may have significantly shortened a loop body; unroll again. - if (!DisableUnrollLoops) - PM.add(createLoopUnrollPass(OptLevel)); + PM.add(createLoopUnrollPass(OptLevel, DisableUnrollLoops)); + + PM.add(createWarnMissedTransformationsPass()); // Now that we've optimized loops (in particular loop induction variables), // we may have exposed more scalar opportunities. Run parts of the scalar diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp index 2be654258aa8..ae586c017471 100644 --- a/lib/Transforms/IPO/PruneEH.cpp +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -107,7 +107,7 @@ static bool runImpl(CallGraphSCC &SCC, CallGraph &CG) { continue; for (const BasicBlock &BB : *F) { - const TerminatorInst *TI = BB.getTerminator(); + const Instruction *TI = BB.getTerminator(); if (CheckUnwind && TI->mayThrow()) { SCCMightUnwind = true; } else if (CheckReturn && isa<ReturnInst>(TI)) { @@ -255,7 +255,7 @@ static void DeleteBasicBlock(BasicBlock *BB, CallGraph &CG) { } if (TokenInst) { - if (!isa<TerminatorInst>(TokenInst)) + if (!TokenInst->isTerminator()) changeToUnreachable(TokenInst->getNextNode(), /*UseLLVMTrap=*/false); } else { // Get the list of successors of this block. 
diff --git a/lib/Transforms/IPO/SCCP.cpp b/lib/Transforms/IPO/SCCP.cpp index cc53c4b8c46f..d2c34abfc132 100644 --- a/lib/Transforms/IPO/SCCP.cpp +++ b/lib/Transforms/IPO/SCCP.cpp @@ -1,4 +1,6 @@ #include "llvm/Transforms/IPO/SCCP.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar/SCCP.h" @@ -8,9 +10,22 @@ using namespace llvm; PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { const DataLayout &DL = M.getDataLayout(); auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); - if (!runIPSCCP(M, DL, &TLI)) + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + auto getAnalysis = [&FAM](Function &F) -> AnalysisResultsForFn { + DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); + return { + make_unique<PredicateInfo>(F, DT, FAM.getResult<AssumptionAnalysis>(F)), + &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F)}; + }; + + if (!runIPSCCP(M, DL, &TLI, getAnalysis)) return PreservedAnalyses::all(); - return PreservedAnalyses::none(); + + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); + PA.preserve<FunctionAnalysisManagerModuleProxy>(); + return PA; } namespace { @@ -34,10 +49,25 @@ public: const DataLayout &DL = M.getDataLayout(); const TargetLibraryInfo *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return runIPSCCP(M, DL, TLI); + + auto getAnalysis = [this](Function &F) -> AnalysisResultsForFn { + DominatorTree &DT = + this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + return { + make_unique<PredicateInfo>( + F, DT, + this->getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + F)), + nullptr, // We cannot preserve the DT or PDT with the legacy pass + nullptr}; // manager, so set them to nullptr. + }; + + return runIPSCCP(M, DL, TLI, getAnalysis); } void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; @@ -49,6 +79,7 @@ char IPSCCPLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", "Interprocedural Sparse Conditional Constant Propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", "Interprocedural Sparse Conditional Constant Propagation", diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index dcd24595f7ea..9f123c2b875e 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -96,6 +96,13 @@ static cl::opt<std::string> SampleProfileFile( "sample-profile-file", cl::init(""), cl::value_desc("filename"), cl::desc("Profile file loaded by -sample-profile"), cl::Hidden); +// The named file contains a set of transformations that may have been applied +// to the symbol names between the program from which the sample data was +// collected and the current program's symbols. 
+static cl::opt<std::string> SampleProfileRemappingFile( + "sample-profile-remapping-file", cl::init(""), cl::value_desc("filename"), + cl::desc("Profile remapping file loaded by -sample-profile"), cl::Hidden); + static cl::opt<unsigned> SampleProfileMaxPropagateIterations( "sample-profile-max-propagate-iterations", cl::init(100), cl::desc("Maximum number of iterations to go through when propagating " @@ -116,6 +123,12 @@ static cl::opt<bool> NoWarnSampleUnused( cl::desc("Use this option to turn off/on warnings about function with " "samples but without debug information to use those samples. ")); +static cl::opt<bool> ProfileSampleAccurate( + "profile-sample-accurate", cl::Hidden, cl::init(false), + cl::desc("If the sample profile is accurate, we will mark all un-sampled " + "callsite and function as having 0 samples. Otherwise, treat " + "un-sampled callsites and functions conservatively as unknown. ")); + namespace { using BlockWeightMap = DenseMap<const BasicBlock *, uint64_t>; @@ -183,12 +196,12 @@ private: class SampleProfileLoader { public: SampleProfileLoader( - StringRef Name, bool IsThinLTOPreLink, + StringRef Name, StringRef RemapName, bool IsThinLTOPreLink, std::function<AssumptionCache &(Function &)> GetAssumptionCache, std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo) : GetAC(std::move(GetAssumptionCache)), GetTTI(std::move(GetTargetTransformInfo)), Filename(Name), - IsThinLTOPreLink(IsThinLTOPreLink) {} + RemappingFilename(RemapName), IsThinLTOPreLink(IsThinLTOPreLink) {} bool doInitialization(Module &M); bool runOnModule(Module &M, ModuleAnalysisManager *AM, @@ -205,6 +218,7 @@ protected: const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const; std::vector<const FunctionSamples *> findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const; + mutable DenseMap<const DILocation *, const FunctionSamples *> DILocation2SampleMap; const FunctionSamples *findFunctionSamples(const Instruction &I) const; bool inlineCallInstruction(Instruction *I); bool inlineHotFunctions(Function &F, @@ -282,6 +296,9 @@ protected: /// Name of the profile file to load. std::string Filename; + /// Name of the profile remapping file to load. + std::string RemappingFilename; + /// Flag indicating whether the profile input loaded successfully. bool ProfileIsValid = false; @@ -311,13 +328,14 @@ public: SampleProfileLoaderLegacyPass(StringRef Name = SampleProfileFile, bool IsThinLTOPreLink = false) - : ModulePass(ID), SampleLoader(Name, IsThinLTOPreLink, - [&](Function &F) -> AssumptionCache & { - return ACT->getAssumptionCache(F); - }, - [&](Function &F) -> TargetTransformInfo & { - return TTIWP->getTTI(F); - }) { + : ModulePass(ID), + SampleLoader(Name, SampleProfileRemappingFile, IsThinLTOPreLink, + [&](Function &F) -> AssumptionCache & { + return ACT->getAssumptionCache(F); + }, + [&](Function &F) -> TargetTransformInfo & { + return TTIWP->getTTI(F); + }) { initializeSampleProfileLoaderLegacyPassPass( *PassRegistry::getPassRegistry()); } @@ -527,10 +545,10 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { if (!FS) return std::error_code(); - // Ignore all intrinsics and branch instructions. - // Branch instruction usually contains debug info from sources outside of + // Ignore all intrinsics, phinodes and branch instructions. + // Branch and phinodes instruction usually contains debug info from sources outside of // the residing basic block, thus we ignore them during annotation. 
- if (isa<BranchInst>(Inst) || isa<IntrinsicInst>(Inst)) + if (isa<BranchInst>(Inst) || isa<IntrinsicInst>(Inst) || isa<PHINode>(Inst)) return std::error_code(); // If a direct call/invoke instruction is inlined in profile @@ -643,8 +661,6 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const { if (FS == nullptr) return nullptr; - std::string CalleeGUID; - CalleeName = getRepInFormat(CalleeName, Reader->getFormat(), CalleeGUID); return FS->findFunctionSamplesAt(LineLocation(FunctionSamples::getOffset(DIL), DIL->getBaseDiscriminator()), CalleeName); @@ -683,10 +699,12 @@ SampleProfileLoader::findIndirectCallFunctionSamples( Sum += NameFS.second.getEntrySamples(); R.push_back(&NameFS.second); } - llvm::sort(R.begin(), R.end(), - [](const FunctionSamples *L, const FunctionSamples *R) { - return L->getEntrySamples() > R->getEntrySamples(); - }); + llvm::sort(R, [](const FunctionSamples *L, const FunctionSamples *R) { + if (L->getEntrySamples() != R->getEntrySamples()) + return L->getEntrySamples() > R->getEntrySamples(); + return FunctionSamples::getGUID(L->getName()) < + FunctionSamples::getGUID(R->getName()); + }); } return R; } @@ -702,12 +720,14 @@ SampleProfileLoader::findIndirectCallFunctionSamples( /// \returns the FunctionSamples pointer to the inlined instance. const FunctionSamples * SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { - SmallVector<std::pair<LineLocation, StringRef>, 10> S; const DILocation *DIL = Inst.getDebugLoc(); if (!DIL) return Samples; - return Samples->findFunctionSamples(DIL); + auto it = DILocation2SampleMap.try_emplace(DIL,nullptr); + if (it.second) + it.first->second = Samples->findFunctionSamples(DIL); + return it.first->second; } bool SampleProfileLoader::inlineCallInstruction(Instruction *I) { @@ -760,7 +780,6 @@ bool SampleProfileLoader::inlineHotFunctions( Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) { DenseSet<Instruction *> PromotedInsns; bool Changed = false; - bool isCompact = (Reader->getFormat() == SPF_Compact_Binary); while (true) { bool LocalChanged = false; SmallVector<Instruction *, 10> CIS; @@ -792,19 +811,16 @@ bool SampleProfileLoader::inlineHotFunctions( for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { if (IsThinLTOPreLink) { FS->findInlinedFunctions(InlinedGUIDs, F.getParent(), - PSI->getOrCompHotCountThreshold(), - isCompact); + PSI->getOrCompHotCountThreshold()); continue; } - auto CalleeFunctionName = FS->getName(); + auto CalleeFunctionName = FS->getFuncNameInModule(F.getParent()); // If it is a recursive call, we do not inline it as it could bloat // the code exponentially. There is way to better handle this, e.g. // clone the caller first, and inline the cloned caller if it is // recursive. As llvm does not inline recursive calls, we will // simply ignore it instead of handling it explicitly. 
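The DILocation2SampleMap added to findFunctionSamples() above is plain memoization: try_emplace() inserts a placeholder and reports whether the key was new, and the comparatively expensive walk over inlined-at locations runs only in that case. The same idiom with the standard containers, using a hypothetical expensiveLookup() stand-in:

#include <cstdio>
#include <string>
#include <unordered_map>

// Stand-in for the costly query whose result is worth caching per key.
static int expensiveLookup(const std::string &Key) {
  std::printf("computing result for %s\n", Key.c_str());
  return static_cast<int>(Key.size());
}

static std::unordered_map<std::string, int> Cache;

static int cachedLookup(const std::string &Key) {
  // try_emplace only constructs the value when Key is absent; 'second' says
  // whether this call performed the insertion, i.e. whether we must compute.
  auto It = Cache.try_emplace(Key, 0);
  if (It.second)
    It.first->second = expensiveLookup(Key);
  return It.first->second;
}

int main() {
  cachedLookup("foo"); // prints "computing result for foo"
  cachedLookup("foo"); // cache hit, no recomputation
}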
- std::string FGUID; - auto Fname = getRepInFormat(F.getName(), Reader->getFormat(), FGUID); - if (CalleeFunctionName == Fname) + if (CalleeFunctionName == F.getName()) continue; const char *Reason = "Callee function not available"; @@ -834,8 +850,7 @@ bool SampleProfileLoader::inlineHotFunctions( LocalChanged = true; } else if (IsThinLTOPreLink) { findCalleeFunctionSamples(*I)->findInlinedFunctions( - InlinedGUIDs, F.getParent(), PSI->getOrCompHotCountThreshold(), - isCompact); + InlinedGUIDs, F.getParent(), PSI->getOrCompHotCountThreshold()); } } if (LocalChanged) { @@ -1177,14 +1192,13 @@ static SmallVector<InstrProfValueData, 2> SortCallTargets( const SampleRecord::CallTargetMap &M) { SmallVector<InstrProfValueData, 2> R; for (auto I = M.begin(); I != M.end(); ++I) - R.push_back({Function::getGUID(I->getKey()), I->getValue()}); - llvm::sort(R.begin(), R.end(), - [](const InstrProfValueData &L, const InstrProfValueData &R) { - if (L.Count == R.Count) - return L.Value > R.Value; - else - return L.Count > R.Count; - }); + R.push_back({FunctionSamples::getGUID(I->getKey()), I->getValue()}); + llvm::sort(R, [](const InstrProfValueData &L, const InstrProfValueData &R) { + if (L.Count == R.Count) + return L.Value > R.Value; + else + return L.Count > R.Count; + }); return R; } @@ -1292,7 +1306,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { } } } - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 1) continue; if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) @@ -1519,12 +1533,28 @@ bool SampleProfileLoader::doInitialization(Module &M) { return false; } Reader = std::move(ReaderOrErr.get()); + Reader->collectFuncsToUse(M); ProfileIsValid = (Reader->read() == sampleprof_error::success); + + if (!RemappingFilename.empty()) { + // Apply profile remappings to the loaded profile data if requested. + // For now, we only support remapping symbols encoded using the Itanium + // C++ ABI's name mangling scheme. + ReaderOrErr = SampleProfileReaderItaniumRemapper::create( + RemappingFilename, Ctx, std::move(Reader)); + if (std::error_code EC = ReaderOrErr.getError()) { + std::string Msg = "Could not open profile remapping file: " + EC.message(); + Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg)); + return false; + } + Reader = std::move(ReaderOrErr.get()); + ProfileIsValid = (Reader->read() == sampleprof_error::success); + } return true; } ModulePass *llvm::createSampleProfileLoaderPass() { - return new SampleProfileLoaderLegacyPass(SampleProfileFile); + return new SampleProfileLoaderLegacyPass(); } ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { @@ -1533,6 +1563,7 @@ ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM, ProfileSummaryInfo *_PSI) { + FunctionSamples::GUIDToFuncNameMapper Mapper(M); if (!ProfileIsValid) return false; @@ -1577,15 +1608,25 @@ bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { ACT = &getAnalysis<AssumptionCacheTracker>(); TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>(); ProfileSummaryInfo *PSI = - getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); return SampleLoader.runOnModule(M, nullptr, PSI); } bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) { - // Initialize the entry count to -1, which will be treated conservatively - // by getEntryCount as the same as unknown (None). 
If we have samples this - // will be overwritten in emitAnnotations. - F.setEntryCount(ProfileCount(-1, Function::PCT_Real)); + + DILocation2SampleMap.clear(); + // By default the entry count is initialized to -1, which will be treated + // conservatively by getEntryCount as the same as unknown (None). This is + // to avoid newly added code to be treated as cold. If we have samples + // this will be overwritten in emitAnnotations. + // If ProfileSampleAccurate is true or F has profile-sample-accurate + // attribute, initialize the entry count to 0 so callsites or functions + // unsampled will be treated as cold. + uint64_t initialEntryCount = + (ProfileSampleAccurate || F.hasFnAttribute("profile-sample-accurate")) + ? 0 + : -1; + F.setEntryCount(ProfileCount(initialEntryCount, Function::PCT_Real)); std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; if (AM) { auto &FAM = @@ -1616,6 +1657,8 @@ PreservedAnalyses SampleProfileLoaderPass::run(Module &M, SampleProfileLoader SampleLoader( ProfileFileName.empty() ? SampleProfileFile : ProfileFileName, + ProfileRemappingFileName.empty() ? SampleProfileRemappingFile + : ProfileRemappingFileName, IsThinLTOPreLink, GetAssumptionCache, GetTTI); SampleLoader.doInitialization(M); diff --git a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp index 3c5ad37bced1..ba4efb3ff60d 100644 --- a/lib/Transforms/IPO/SyntheticCountsPropagation.cpp +++ b/lib/Transforms/IPO/SyntheticCountsPropagation.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CallGraph.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/SyntheticCountsUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Function.h" @@ -46,7 +47,7 @@ using ProfileCount = Function::ProfileCount; #define DEBUG_TYPE "synthetic-counts-propagation" /// Initial synthetic count assigned to functions. -static cl::opt<int> +cl::opt<int> InitialSyntheticCount("initial-synthetic-count", cl::Hidden, cl::init(10), cl::ZeroOrMore, cl::desc("Initial value of synthetic entry count.")); @@ -98,13 +99,15 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, ModuleAnalysisManager &MAM) { FunctionAnalysisManager &FAM = MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); - DenseMap<Function *, uint64_t> Counts; + DenseMap<Function *, Scaled64> Counts; // Set initial entry counts. - initializeCounts(M, [&](Function *F, uint64_t Count) { Counts[F] = Count; }); + initializeCounts( + M, [&](Function *F, uint64_t Count) { Counts[F] = Scaled64(Count, 0); }); - // Compute the relative block frequency for a call edge. Use scaled numbers - // and not integers since the relative block frequency could be less than 1. - auto GetCallSiteRelFreq = [&](const CallGraphNode::CallRecord &Edge) { + // Edge includes information about the source. Hence ignore the first + // parameter. 
+ auto GetCallSiteProfCount = [&](const CallGraphNode *, + const CallGraphNode::CallRecord &Edge) { Optional<Scaled64> Res = None; if (!Edge.first) return Res; @@ -112,29 +115,33 @@ PreservedAnalyses SyntheticCountsPropagation::run(Module &M, CallSite CS(cast<Instruction>(Edge.first)); Function *Caller = CS.getCaller(); auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*Caller); + + // Now compute the callsite count from relative frequency and + // entry count: BasicBlock *CSBB = CS.getInstruction()->getParent(); Scaled64 EntryFreq(BFI.getEntryFreq(), 0); - Scaled64 BBFreq(BFI.getBlockFreq(CSBB).getFrequency(), 0); - BBFreq /= EntryFreq; - return Optional<Scaled64>(BBFreq); + Scaled64 BBCount(BFI.getBlockFreq(CSBB).getFrequency(), 0); + BBCount /= EntryFreq; + BBCount *= Counts[Caller]; + return Optional<Scaled64>(BBCount); }; CallGraph CG(M); // Propgate the entry counts on the callgraph. SyntheticCountsUtils<const CallGraph *>::propagate( - &CG, GetCallSiteRelFreq, - [&](const CallGraphNode *N) { return Counts[N->getFunction()]; }, - [&](const CallGraphNode *N, uint64_t New) { + &CG, GetCallSiteProfCount, [&](const CallGraphNode *N, Scaled64 New) { auto F = N->getFunction(); if (!F || F->isDeclaration()) return; + Counts[F] += New; }); // Set the counts as metadata. - for (auto Entry : Counts) - Entry.first->setEntryCount( - ProfileCount(Entry.second, Function::PCT_Synthetic)); + for (auto Entry : Counts) { + Entry.first->setEntryCount(ProfileCount( + Entry.second.template toInt<uint64_t>(), Function::PCT_Synthetic)); + } return PreservedAnalyses::all(); } diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 8fe7ae1282cc..510ecb516dc2 100644 --- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -154,7 +154,8 @@ void simplifyExternals(Module &M) { continue; Function *NewF = - Function::Create(EmptyFT, GlobalValue::ExternalLinkage, "", &M); + Function::Create(EmptyFT, GlobalValue::ExternalLinkage, + F.getAddressSpace(), "", &M); NewF->setVisibility(F.getVisibility()); NewF->takeName(&F); F.replaceAllUsesWith(ConstantExpr::getBitCast(NewF, F.getType())); @@ -237,7 +238,7 @@ void splitAndWriteThinLTOBitcode( // sound because the virtual constant propagation optimizations effectively // inline all implementations of the virtual function into each call site, // rather than using function attributes to perform local optimization. - std::set<const Function *> EligibleVirtualFns; + DenseSet<const Function *> EligibleVirtualFns; // If any member of a comdat lives in MergedM, put all members of that // comdat in MergedM to keep the comdat together. DenseSet<const Comdat *> MergedMComdats; @@ -417,8 +418,18 @@ void splitAndWriteThinLTOBitcode( } } -// Returns whether this module needs to be split because it uses type metadata. +// Returns whether this module needs to be split because splitting is +// enabled and it uses type metadata. bool requiresSplit(Module &M) { + // First check if the LTO Unit splitting has been enabled. + bool EnableSplitLTOUnit = false; + if (auto *MD = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("EnableSplitLTOUnit"))) + EnableSplitLTOUnit = MD->getZExtValue(); + if (!EnableSplitLTOUnit) + return false; + + // Module only needs to be split if it contains type metadata. 
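For the SyntheticCountsPropagation change above: GetCallSiteProfCount now produces an absolute call-site count, the caller's entry count times (call-site block frequency divided by entry frequency), and that ratio is routinely below 1, which is why the pass keeps counts in Scaled64 rather than plain integers. A worked version of the arithmetic with doubles standing in for Scaled64 and made-up frequencies:

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative inputs: the caller's synthetic entry count is 10, its entry
  // block has raw frequency 16, and the block containing the call has raw
  // frequency 4, i.e. the call runs on roughly a quarter of the entries.
  uint64_t CallerEntryCount = 10;
  uint64_t EntryFreq = 16;
  uint64_t CallBlockFreq = 4;

  // With integer math 4 / 16 == 0 and the contribution vanishes; a fractional
  // representation (Scaled64 in the pass, double here) preserves it.
  double RelFreq = static_cast<double>(CallBlockFreq) / EntryFreq; // 0.25
  double CallSiteCount = RelFreq * CallerEntryCount;               // 2.5

  std::printf("relative frequency %.2f -> call-site count %.2f\n", RelFreq,
              CallSiteCount);
}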
for (auto &GO : M.global_objects()) { if (GO.hasMetadata(LLVMContext::MD_type)) return true; @@ -430,7 +441,7 @@ bool requiresSplit(Module &M) { void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, function_ref<AAResults &(Function &)> AARGetter, Module &M, const ModuleSummaryIndex *Index) { - // See if this module has any type metadata. If so, we need to split it. + // Split module if splitting is enabled and it contains any type metadata. if (requiresSplit(M)) return splitAndWriteThinLTOBitcode(OS, ThinLinkOS, AARGetter, M); diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp index d65da2504db4..48bd0cda759d 100644 --- a/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -58,6 +58,7 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" @@ -406,6 +407,7 @@ void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS, struct DevirtModule { Module &M; function_ref<AAResults &(Function &)> AARGetter; + function_ref<DominatorTree &(Function &)> LookupDomTree; ModuleSummaryIndex *ExportSummary; const ModuleSummaryIndex *ImportSummary; @@ -433,10 +435,12 @@ struct DevirtModule { DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, + function_ref<DominatorTree &(Function &)> LookupDomTree, ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary) - : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), - ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())), + : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree), + ExportSummary(ExportSummary), ImportSummary(ImportSummary), + Int8Ty(Type::getInt8Ty(M.getContext())), Int8PtrTy(Type::getInt8PtrTy(M.getContext())), Int32Ty(Type::getInt32Ty(M.getContext())), Int64Ty(Type::getInt64Ty(M.getContext())), @@ -533,9 +537,10 @@ struct DevirtModule { // Lower the module using the action and summary passed as command line // arguments. For testing purposes only. 
- static bool runForTesting( - Module &M, function_ref<AAResults &(Function &)> AARGetter, - function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter); + static bool + runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, + function_ref<DominatorTree &(Function &)> LookupDomTree); }; struct WholeProgramDevirt : public ModulePass { @@ -572,17 +577,23 @@ struct WholeProgramDevirt : public ModulePass { return *ORE; }; + auto LookupDomTree = [this](Function &F) -> DominatorTree & { + return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + }; + if (UseCommandLine) - return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter); + return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter, + LookupDomTree); - return DevirtModule(M, LegacyAARGetter(*this), OREGetter, ExportSummary, - ImportSummary) + return DevirtModule(M, LegacyAARGetter(*this), OREGetter, LookupDomTree, + ExportSummary, ImportSummary) .run(); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); } }; @@ -592,6 +603,7 @@ INITIALIZE_PASS_BEGIN(WholeProgramDevirt, "wholeprogramdevirt", "Whole program devirtualization", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt", "Whole program devirtualization", false, false) char WholeProgramDevirt::ID = 0; @@ -611,7 +623,11 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); }; - if (!DevirtModule(M, AARGetter, OREGetter, ExportSummary, ImportSummary) + auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & { + return FAM.getResult<DominatorTreeAnalysis>(F); + }; + if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary, + ImportSummary) .run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -619,7 +635,8 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, bool DevirtModule::runForTesting( Module &M, function_ref<AAResults &(Function &)> AARGetter, - function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, + function_ref<DominatorTree &(Function &)> LookupDomTree) { ModuleSummaryIndex Summary(/*HaveGVs=*/false); // Handle the command-line summary arguments. This code is for testing @@ -637,7 +654,7 @@ bool DevirtModule::runForTesting( bool Changed = DevirtModule( - M, AARGetter, OREGetter, + M, AARGetter, OREGetter, LookupDomTree, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, ClSummaryAction == PassSummaryAction::Import ? 
&Summary : nullptr) .run(); @@ -665,7 +682,7 @@ void DevirtModule::buildTypeIdentifierMap( for (GlobalVariable &GV : M.globals()) { Types.clear(); GV.getMetadata(LLVMContext::MD_type, Types); - if (Types.empty()) + if (GV.isDeclaration() || Types.empty()) continue; VTableBits *&BitsPtr = GVToBits[&GV]; @@ -755,7 +772,8 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, auto Apply = [&](CallSiteInfo &CSInfo) { for (auto &&VCallSite : CSInfo.CallSites) { if (RemarksEnabled) - VCallSite.emitRemark("single-impl", TheFn->getName(), OREGetter); + VCallSite.emitRemark("single-impl", + TheFn->stripPointerCasts()->getName(), OREGetter); VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( TheFn, VCallSite.CS.getCalledValue()->getType())); // This use is no longer unsafe. @@ -846,10 +864,13 @@ void DevirtModule::tryICallBranchFunnel( Function *JT; if (isa<MDString>(Slot.TypeID)) { JT = Function::Create(FT, Function::ExternalLinkage, + M.getDataLayout().getProgramAddressSpace(), getGlobalName(Slot, {}, "branch_funnel"), &M); JT->setVisibility(GlobalValue::HiddenVisibility); } else { - JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M); + JT = Function::Create(FT, Function::InternalLinkage, + M.getDataLayout().getProgramAddressSpace(), + "branch_funnel", &M); } JT->addAttribute(1, Attribute::Nest); @@ -891,7 +912,8 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo, continue; if (RemarksEnabled) - VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter); + VCallSite.emitRemark("branch-funnel", + JT->stripPointerCasts()->getName(), OREGetter); // Pass the address of the vtable in the nest register, which is r10 on // x86_64. @@ -1323,15 +1345,14 @@ void DevirtModule::rebuildGlobal(VTableBits &B) { bool DevirtModule::areRemarksEnabled() { const auto &FL = M.getFunctionList(); - if (FL.empty()) - return false; - const Function &Fn = FL.front(); - - const auto &BBL = Fn.getBasicBlockList(); - if (BBL.empty()) - return false; - auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); - return DI.isEnabled(); + for (const Function &Fn : FL) { + const auto &BBL = Fn.getBasicBlockList(); + if (BBL.empty()) + continue; + auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBL.front()); + return DI.isEnabled(); + } + return false; } void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, @@ -1341,7 +1362,7 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, // points to a member of the type identifier %md. Group calls by (type ID, // offset) pair (effectively the identity of the virtual function) and store // to CallSlots. - DenseSet<Value *> SeenPtrs; + DenseSet<CallSite> SeenCallSites; for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end(); I != E;) { auto CI = dyn_cast<CallInst>(I->getUser()); @@ -1352,19 +1373,22 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, // Search for virtual calls based on %p and add them to DevirtCalls. SmallVector<DevirtCallSite, 1> DevirtCalls; SmallVector<CallInst *, 1> Assumes; - findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); + auto &DT = LookupDomTree(*CI->getFunction()); + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); - // If we found any, add them to CallSlots. Only do this if we haven't seen - // the vtable pointer before, as it may have been CSE'd with pointers from - // other call sites, and we don't want to process call sites multiple times. + // If we found any, add them to CallSlots. 
if (!Assumes.empty()) { Metadata *TypeId = cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata(); Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); - if (SeenPtrs.insert(Ptr).second) { - for (DevirtCallSite Call : DevirtCalls) { + for (DevirtCallSite Call : DevirtCalls) { + // Only add this CallSite if we haven't seen it before. The vtable + // pointer may have been CSE'd with pointers from other call sites, + // and we don't want to process call sites multiple times. We can't + // just skip the vtable Ptr if it has been seen before, however, since + // it may be shared by type tests that dominate different calls. + if (SeenCallSites.insert(Call.CS).second) CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr); - } } } @@ -1398,8 +1422,9 @@ void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) { SmallVector<Instruction *, 1> LoadedPtrs; SmallVector<Instruction *, 1> Preds; bool HasNonCallUses = false; + auto &DT = LookupDomTree(*CI->getFunction()); findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, - HasNonCallUses, CI); + HasNonCallUses, CI, DT); // Start by generating "pessimistic" code that explicitly loads the function // pointer from the vtable and performs the type check. If possible, we will @@ -1538,6 +1563,17 @@ bool DevirtModule::run() { M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load)); Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume)); + // If only some of the modules were split, we cannot correctly handle + // code that contains type tests or type checked loads. + if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) || + (ImportSummary && ImportSummary->partiallySplitLTOUnits())) { + if ((TypeTestFunc && !TypeTestFunc->use_empty()) || + (TypeCheckedLoadFunc && !TypeCheckedLoadFunc->use_empty())) + report_fatal_error("inconsistent LTO Unit splitting with llvm.type.test " + "or llvm.type.checked.load"); + return false; + } + // Normally if there are no users of the devirtualization intrinsics in the // module, this pass has nothing to do. But if we are exporting, we also need // to handle any users that appear only in the function summaries. diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 83054588a9aa..6e196bfdbd25 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -186,8 +186,6 @@ namespace { Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota); - Value *performFactorization(Instruction *I); - /// Convert given addend to a Value Value *createAddendVal(const FAddend &A, bool& NeedNeg); @@ -197,7 +195,6 @@ namespace { Value *createFSub(Value *Opnd0, Value *Opnd1); Value *createFAdd(Value *Opnd0, Value *Opnd1); Value *createFMul(Value *Opnd0, Value *Opnd1); - Value *createFDiv(Value *Opnd0, Value *Opnd1); Value *createFNeg(Value *V); Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota); void createInstPostProc(Instruction *NewInst, bool NoNumber = false); @@ -427,89 +424,6 @@ unsigned FAddend::drillAddendDownOneStep return BreakNum; } -// Try to perform following optimization on the input instruction I. Return the -// simplified expression if was successful; otherwise, return 0. 
-// -// Instruction "I" is Simplified into -// ------------------------------------------------------- -// (x * y) +/- (x * z) x * (y +/- z) -// (y / x) +/- (z / x) (y +/- z) / x -Value *FAddCombine::performFactorization(Instruction *I) { - assert((I->getOpcode() == Instruction::FAdd || - I->getOpcode() == Instruction::FSub) && "Expect add/sub"); - - Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0)); - Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1)); - - if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode()) - return nullptr; - - bool isMpy = false; - if (I0->getOpcode() == Instruction::FMul) - isMpy = true; - else if (I0->getOpcode() != Instruction::FDiv) - return nullptr; - - Value *Opnd0_0 = I0->getOperand(0); - Value *Opnd0_1 = I0->getOperand(1); - Value *Opnd1_0 = I1->getOperand(0); - Value *Opnd1_1 = I1->getOperand(1); - - // Input Instr I Factor AddSub0 AddSub1 - // ---------------------------------------------- - // (x*y) +/- (x*z) x y z - // (y/x) +/- (z/x) x y z - Value *Factor = nullptr; - Value *AddSub0 = nullptr, *AddSub1 = nullptr; - - if (isMpy) { - if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1) - Factor = Opnd0_0; - else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1) - Factor = Opnd0_1; - - if (Factor) { - AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0; - AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0; - } - } else if (Opnd0_1 == Opnd1_1) { - Factor = Opnd0_1; - AddSub0 = Opnd0_0; - AddSub1 = Opnd1_0; - } - - if (!Factor) - return nullptr; - - FastMathFlags Flags; - Flags.setFast(); - if (I0) Flags &= I->getFastMathFlags(); - if (I1) Flags &= I->getFastMathFlags(); - - // Create expression "NewAddSub = AddSub0 +/- AddsSub1" - Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ? - createFAdd(AddSub0, AddSub1) : - createFSub(AddSub0, AddSub1); - if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) { - const APFloat &F = CFP->getValueAPF(); - if (!F.isNormal()) - return nullptr; - } else if (Instruction *II = dyn_cast<Instruction>(NewAddSub)) - II->setFastMathFlags(Flags); - - if (isMpy) { - Value *RI = createFMul(Factor, NewAddSub); - if (Instruction *II = dyn_cast<Instruction>(RI)) - II->setFastMathFlags(Flags); - return RI; - } - - Value *RI = createFDiv(NewAddSub, Factor); - if (Instruction *II = dyn_cast<Instruction>(RI)) - II->setFastMathFlags(Flags); - return RI; -} - Value *FAddCombine::simplify(Instruction *I) { assert(I->hasAllowReassoc() && I->hasNoSignedZeros() && "Expected 'reassoc'+'nsz' instruction"); @@ -594,8 +508,7 @@ Value *FAddCombine::simplify(Instruction *I) { return R; } - // step 6: Try factorization as the last resort, - return performFactorization(I); + return nullptr; } Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { @@ -772,13 +685,6 @@ Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) { return V; } -Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) { - Value *V = Builder.CreateFDiv(Opnd0, Opnd1); - if (Instruction *I = dyn_cast<Instruction>(V)) - createInstPostProc(I); - return V; -} - void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) { NewInstr->setDebugLoc(Instr->getDebugLoc()); @@ -1135,7 +1041,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // (A*B)+(A*C) -> A*(B+C) etc @@ -1285,77 +1191,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - // Check for (add (sext x), 
y), see if we can merge this into an - // integer add followed by a sext. - if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) { - // (add (sext x), cst) --> (sext (add x, cst')) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { - if (LHSConv->hasOneUse()) { - Constant *CI = - ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (ConstantExpr::getSExt(CI, Ty) == RHSC && - willNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { - // Insert the new, smaller add. - Value *NewAdd = - Builder.CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); - return new SExtInst(NewAdd, Ty); - } - } - } - - // (add (sext x), (sext y)) --> (sext (add int x, y)) - if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) { - // Only do this if x/y have the same type, if at least one of them has a - // single use (so we don't increase the number of sexts), and if the - // integer add will not overflow. - if (LHSConv->getOperand(0)->getType() == - RHSConv->getOperand(0)->getType() && - (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - willNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0), I)) { - // Insert the new integer add. - Value *NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0), "addconv"); - return new SExtInst(NewAdd, Ty); - } - } - } - - // Check for (add (zext x), y), see if we can merge this into an - // integer add followed by a zext. - if (auto *LHSConv = dyn_cast<ZExtInst>(LHS)) { - // (add (zext x), cst) --> (zext (add x, cst')) - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { - if (LHSConv->hasOneUse()) { - Constant *CI = - ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (ConstantExpr::getZExt(CI, Ty) == RHSC && - willNotOverflowUnsignedAdd(LHSConv->getOperand(0), CI, I)) { - // Insert the new, smaller add. - Value *NewAdd = - Builder.CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); - return new ZExtInst(NewAdd, Ty); - } - } - } - - // (add (zext x), (zext y)) --> (zext (add int x, y)) - if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) { - // Only do this if x/y have the same type, if at least one of them has a - // single use (so we don't increase the number of zexts), and if the - // integer add will not overflow. - if (LHSConv->getOperand(0)->getType() == - RHSConv->getOperand(0)->getType() && - (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - willNotOverflowUnsignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0), I)) { - // Insert the new integer add. - Value *NewAdd = Builder.CreateNUWAdd( - LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); - return new ZExtInst(NewAdd, Ty); - } - } - } + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; // (add (xor A, B) (and A, B)) --> (or A, B) // (add (and A, B) (xor A, B)) --> (or A, B) @@ -1391,6 +1228,45 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return Changed ? &I : nullptr; } +/// Factor a common operand out of fadd/fsub of fmul/fdiv. 
+static Instruction *factorizeFAddFSub(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert((I.getOpcode() == Instruction::FAdd || + I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub"); + assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && + "FP factorization requires FMF"); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X, *Y, *Z; + bool IsFMul; + if ((match(Op0, m_OneUse(m_FMul(m_Value(X), m_Value(Z)))) && + match(Op1, m_OneUse(m_c_FMul(m_Value(Y), m_Specific(Z))))) || + (match(Op0, m_OneUse(m_FMul(m_Value(Z), m_Value(X)))) && + match(Op1, m_OneUse(m_c_FMul(m_Value(Y), m_Specific(Z)))))) + IsFMul = true; + else if (match(Op0, m_OneUse(m_FDiv(m_Value(X), m_Value(Z)))) && + match(Op1, m_OneUse(m_FDiv(m_Value(Y), m_Specific(Z))))) + IsFMul = false; + else + return nullptr; + + // (X * Z) + (Y * Z) --> (X + Y) * Z + // (X * Z) - (Y * Z) --> (X - Y) * Z + // (X / Z) + (Y / Z) --> (X + Y) / Z + // (X / Z) - (Y / Z) --> (X - Y) / Z + bool IsFAdd = I.getOpcode() == Instruction::FAdd; + Value *XY = IsFAdd ? Builder.CreateFAddFMF(X, Y, &I) + : Builder.CreateFSubFMF(X, Y, &I); + + // Bail out if we just created a denormal constant. + // TODO: This is copied from a previous implementation. Is it necessary? + const APFloat *C; + if (match(XY, m_APFloat(C)) && !C->isNormal()) + return nullptr; + + return IsFMul ? BinaryOperator::CreateFMulFMF(XY, Z, &I) + : BinaryOperator::CreateFDivFMF(XY, Z, &I); +} + Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifyFAddInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -1400,7 +1276,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I)) @@ -1478,6 +1354,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { + if (Instruction *F = factorizeFAddFSub(I, Builder)) + return F; if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } @@ -1577,7 +1455,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // (A*B)-(A*C) -> A*(B-C) etc @@ -1771,19 +1649,51 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // X - A*-B -> X + A*B // X - -A*B -> X + A*B Value *A, *B; - Constant *CI; if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B))))) return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B)); - // X - A*CI -> X + A*-CI + // X - A*C -> X + A*-C // No need to handle commuted multiply because multiply handling will // ensure constant will be move to the right hand side. - if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) { - Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(CI)); + if (match(Op1, m_Mul(m_Value(A), m_Constant(C))) && !isa<ConstantExpr>(C)) { + Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(C)); return BinaryOperator::CreateAdd(Op0, NewMul); } } + { + // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A + // ~A - Min/Max(O, ~A) -> Max/Min(A, ~O) - A + // Min/Max(~A, O) - ~A -> A - Max/Min(A, ~O) + // Min/Max(O, ~A) - ~A -> A - Max/Min(A, ~O) + // So long as O here is freely invertible, this will be neutral or a win. 
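The comment above relies on bitwise-not being the order-reversing map x -> -1 - x, which swaps min and max: ~min(~A, O) is exactly max(A, ~O), so the subtraction can be rewritten without materializing the inverted min/max. A brute-force check of that identity for the smin case over all 8-bit pairs (plain signed ints here, not LLVM values):

#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
  // ~x == -1 - x is strictly decreasing, so applying it to both operands of a
  // min turns the min into a max: ~min(~a, b) == max(a, ~b).
  for (int A = -128; A <= 127; ++A)
    for (int B = -128; B <= 127; ++B)
      assert(~std::min(~A, B) == std::max(A, ~B));
  std::printf("~min(~a, b) == max(a, ~b) holds for all 8-bit pairs\n");
}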
+ Value *LHS, *RHS, *A; + Value *NotA = Op0, *MinMax = Op1; + SelectPatternFlavor SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor; + if (!SelectPatternResult::isMinOrMax(SPF)) { + NotA = Op1; + MinMax = Op0; + SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor; + } + if (SelectPatternResult::isMinOrMax(SPF) && + match(NotA, m_Not(m_Value(A))) && (NotA == LHS || NotA == RHS)) { + if (NotA == LHS) + std::swap(LHS, RHS); + // LHS is now O above and expected to have at least 2 uses (the min/max) + // NotA is epected to have 2 uses from the min/max and 1 from the sub. + if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && + !NotA->hasNUsesOrMore(4)) { + // Note: We don't generate the inverse max/min, just create the not of + // it and let other folds do the rest. + Value *Not = Builder.CreateNot(MinMax); + if (NotA == Op0) + return BinaryOperator::CreateSub(Not, A); + else + return BinaryOperator::CreateSub(A, Not); + } + } + } + // Optimize pointer differences into the same array into a size. Consider: // &A[10] - &A[0]: we should compile this to "10". Value *LHSOp, *RHSOp; @@ -1819,6 +1729,9 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return SelectInst::Create(Cmp, Neg, A); } + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; + bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { Changed = true; @@ -1838,7 +1751,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // Subtraction from -0.0 is the canonical form of fneg. @@ -1847,13 +1760,27 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) return BinaryOperator::CreateFNegFMF(Op1, &I); + Value *X, *Y; + Constant *C; + + // Fold negation into constant operand. This is limited with one-use because + // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. + // -(X * C) --> X * (-C) + if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + // -(X / C) --> X / (-C) + if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + // -(C / X) --> (-C) / X + if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. // This can also help codegen because fadd is commutative. // Note that if this fsub was really an fneg, the fadd with -0.0 will get // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. - Value *X, *Y; if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); @@ -1869,7 +1796,6 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. 
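The FP folds in this file fall into two groups: exact rewrites such as -(X * C) --> X * (-C) above, where negation commutes with the rounding of the product under the default rounding mode and so no fast-math flags are required, and reassociating rewrites such as factorizeFAddFSub(), which can change the rounding and are therefore guarded by 'reassoc' and 'nsz'. A small double-precision illustration of the difference (values chosen only to expose the rounding):

#include <cstdio>

int main() {
  // Exact: negating the product vs. negating one factor rounds identically,
  // so -(X * C) --> X * (-C) needs no fast-math flags.
  double X = 0.1, C = 3.0;
  std::printf("-(X*C)    = %.17g\nX*(-C)    = %.17g\n", -(X * C), X * (-C));

  // Not exact: (A*Z) + (B*Z) --> (A + B) * Z reassociates and can round
  // differently, hence the 'reassoc' requirement on factorizeFAddFSub().
  double A = 1.0, B = 1e-16, Z = 3.0;
  std::printf("A*Z + B*Z = %.17g\n(A+B)*Z   = %.17g\n", A * Z + B * Z,
              (A + B) * Z);
  // On a typical IEEE-754 double implementation the first pair matches and
  // the second pair differs by one ulp (3.0000000000000004 vs 3).
}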
- Constant *C; if (match(Op1, m_Constant(C)) && !isa<ConstantExpr>(Op1)) return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); @@ -1879,21 +1805,46 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) { - Value *TruncY = Builder.CreateFPTrunc(Y, I.getType()); - return BinaryOperator::CreateFAddFMF(Op0, TruncY, &I); - } + Type *Ty = I.getType(); + if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) { - Value *ExtY = Builder.CreateFPExt(Y, I.getType()); - return BinaryOperator::CreateFAddFMF(Op0, ExtY, &I); - } + if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); - // Handle specials cases for FSub with selects feeding the operation + // Handle special cases for FSub with selects feeding the operation if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { + // (Y - X) - Y --> -X + if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return BinaryOperator::CreateFNegFMF(X, &I); + + // Y - (X + Y) --> -X + // Y - (Y + X) --> -X + if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return BinaryOperator::CreateFNegFMF(X, &I); + + // (X * C) - X --> X * (C - 1.0) + if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { + Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); + return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + } + // X - (X * C) --> X * (1.0 - C) + if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); + return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + } + + if (Instruction *F = factorizeFAddFSub(I, Builder)) + return F; + + // TODO: This performs reassociative folds for FP ops. Some fraction of the + // functionality has been subsumed by simple pattern matching here and in + // InstSimplify. We should let a dedicated reassociation pass handle more + // complex pattern matching and remove this from InstCombine. if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 3d758e2fe7c9..404c2ad7e6e7 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -53,11 +53,11 @@ static unsigned getFCmpCode(FCmpInst::Predicate CC) { /// operands into either a constant true or false, or a brand new ICmp /// instruction. The sign is passed in to determine which kind of predicate to /// use in the new icmp instruction. 
-static Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, +static Value *getNewICmpValue(unsigned Code, bool Sign, Value *LHS, Value *RHS, InstCombiner::BuilderTy &Builder) { ICmpInst::Predicate NewPred; - if (Value *NewConstant = getICmpValue(Sign, Code, LHS, RHS, NewPred)) - return NewConstant; + if (Constant *TorF = getPredForICmpCode(Code, Sign, LHS->getType(), NewPred)) + return TorF; return Builder.CreateICmp(NewPred, LHS, RHS); } @@ -898,6 +898,130 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } +/// General pattern: +/// X & Y +/// +/// Where Y is checking that all the high bits (covered by a mask 4294967168) +/// are uniform, i.e. %arg & 4294967168 can be either 4294967168 or 0 +/// Pattern can be one of: +/// %t = add i32 %arg, 128 +/// %r = icmp ult i32 %t, 256 +/// Or +/// %t0 = shl i32 %arg, 24 +/// %t1 = ashr i32 %t0, 24 +/// %r = icmp eq i32 %t1, %arg +/// Or +/// %t0 = trunc i32 %arg to i8 +/// %t1 = sext i8 %t0 to i32 +/// %r = icmp eq i32 %t1, %arg +/// This pattern is a signed truncation check. +/// +/// And X is checking that some bit in that same mask is zero. +/// I.e. can be one of: +/// %r = icmp sgt i32 %arg, -1 +/// Or +/// %t = and i32 %arg, 2147483648 +/// %r = icmp eq i32 %t, 0 +/// +/// Since we are checking that all the bits in that mask are the same, +/// and a particular bit is zero, what we are really checking is that all the +/// masked bits are zero. +/// So this should be transformed to: +/// %r = icmp ult i32 %arg, 128 +static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, + Instruction &CxtI, + InstCombiner::BuilderTy &Builder) { + assert(CxtI.getOpcode() == Instruction::And); + + // Match icmp ult (add %arg, C01), C1 (C1 == C01 << 1; powers of two) + auto tryToMatchSignedTruncationCheck = [](ICmpInst *ICmp, Value *&X, + APInt &SignBitMask) -> bool { + CmpInst::Predicate Pred; + const APInt *I01, *I1; // powers of two; I1 == I01 << 1 + if (!(match(ICmp, + m_ICmp(Pred, m_Add(m_Value(X), m_Power2(I01)), m_Power2(I1))) && + Pred == ICmpInst::ICMP_ULT && I1->ugt(*I01) && I01->shl(1) == *I1)) + return false; + // Which bit is the new sign bit as per the 'signed truncation' pattern? + SignBitMask = *I01; + return true; + }; + + // One icmp needs to be 'signed truncation check'. + // We need to match this first, else we will mismatch commutative cases. + Value *X1; + APInt HighestBit; + ICmpInst *OtherICmp; + if (tryToMatchSignedTruncationCheck(ICmp1, X1, HighestBit)) + OtherICmp = ICmp0; + else if (tryToMatchSignedTruncationCheck(ICmp0, X1, HighestBit)) + OtherICmp = ICmp1; + else + return nullptr; + + assert(HighestBit.isPowerOf2() && "expected to be power of two (non-zero)"); + + // Try to match/decompose into: icmp eq (X & Mask), 0 + auto tryToDecompose = [](ICmpInst *ICmp, Value *&X, + APInt &UnsetBitsMask) -> bool { + CmpInst::Predicate Pred = ICmp->getPredicate(); + // Can it be decomposed into icmp eq (X & Mask), 0 ? + if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), + Pred, X, UnsetBitsMask, + /*LookThruTrunc=*/false) && + Pred == ICmpInst::ICMP_EQ) + return true; + // Is it icmp eq (X & Mask), 0 already? + const APInt *Mask; + if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) && + Pred == ICmpInst::ICMP_EQ) { + UnsetBitsMask = *Mask; + return true; + } + return false; + }; + + // And the other icmp needs to be decomposable into a bit test. 
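The combined condition documented above can be checked exhaustively on a narrower analogue; here a 16-bit value plays the role of %arg and 0xFF80 the role of the high-bit mask (a standalone sketch of the equivalence, not the fold itself):

    #include <cassert>
    #include <cstdint>

    // "(x + 128) u< 256" says the high bits (mask 0xFF80) are all equal;
    // "x & 0x8000 == 0" says one of those bits is zero. Together they say the
    // masked bits are all zero, i.e. exactly "x u< 128".
    int main() {
      for (uint32_t i = 0; i <= 0xFFFF; ++i) {
        uint16_t x = (uint16_t)i;
        bool uniformHighBits = (uint16_t)(x + 128) < 256;
        bool signBitClear = (x & 0x8000) == 0;
        assert((uniformHighBits && signBitClear) == (x < 128));
      }
      return 0;
    }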
+ Value *X0; + APInt UnsetBitsMask; + if (!tryToDecompose(OtherICmp, X0, UnsetBitsMask)) + return nullptr; + + assert(!UnsetBitsMask.isNullValue() && "empty mask makes no sense."); + + // Are they working on the same value? + Value *X; + if (X1 == X0) { + // Ok as is. + X = X1; + } else if (match(X0, m_Trunc(m_Specific(X1)))) { + UnsetBitsMask = UnsetBitsMask.zext(X1->getType()->getScalarSizeInBits()); + X = X1; + } else + return nullptr; + + // So which bits should be uniform as per the 'signed truncation check'? + // (all the bits starting with (i.e. including) HighestBit) + APInt SignBitsMask = ~(HighestBit - 1U); + + // UnsetBitsMask must have some common bits with SignBitsMask, + if (!UnsetBitsMask.intersects(SignBitsMask)) + return nullptr; + + // Does UnsetBitsMask contain any bits outside of SignBitsMask? + if (!UnsetBitsMask.isSubsetOf(SignBitsMask)) { + APInt OtherHighestBit = (~UnsetBitsMask) + 1U; + if (!OtherHighestBit.isPowerOf2()) + return nullptr; + HighestBit = APIntOps::umin(HighestBit, OtherHighestBit); + } + // Else, if it does not, then all is ok as-is. + + // %r = icmp ult %X, SignBit + return Builder.CreateICmpULT(X, ConstantInt::get(X->getType(), HighestBit), + CxtI.getName() + ".simplified"); +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI) { @@ -909,7 +1033,7 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(PredL, PredR)) { + if (predicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -917,8 +1041,8 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, LHS->getOperand(1) == RHS->getOperand(1)) { Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) & getICmpCode(RHS); - bool isSigned = LHS->isSigned() || RHS->isSigned(); - return getNewICmpValue(isSigned, Code, Op0, Op1, Builder); + bool IsSigned = LHS->isSigned() || RHS->isSigned(); + return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder); } } @@ -937,6 +1061,9 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) return V; + if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); @@ -1004,7 +1131,7 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; // We can't fold (ugt x, C) & (sgt x, C2). - if (!PredicatesFoldable(PredL, PredR)) + if (!predicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. @@ -1408,7 +1535,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // See if we can simplify any instructions used by the instruction whose sole @@ -1635,10 +1762,9 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return nullptr; } -/// Given an OR instruction, check to see if this is a bswap idiom. If so, -/// insert the new intrinsic and return it. 
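For reference, the byte-swap recognition handled here is aimed at open-coded shift/mask trees of roughly the following shape (illustrative 32-bit version; the helper name is arbitrary), which collapse to a single byte-swap operation:

    #include <cassert>
    #include <cstdint>

    // Open-coded 32-bit byte swap built from an or of masked shifts.
    static uint32_t bswap32(uint32_t x) {
      return ((x & 0x000000FFu) << 24) |
             ((x & 0x0000FF00u) << 8)  |
             ((x & 0x00FF0000u) >> 8)  |
             ((x & 0xFF000000u) >> 24);
    }

    int main() {
      assert(bswap32(0x12345678u) == 0x78563412u);
      assert(bswap32(bswap32(0xDEADBEEFu)) == 0xDEADBEEFu);
      return 0;
    }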
-Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); +Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) { + assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'"); + Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1); // Look through zero extends. if (Instruction *Ext = dyn_cast<ZExtInst>(Op0)) @@ -1674,7 +1800,7 @@ Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { return nullptr; SmallVector<Instruction*, 4> Insts; - if (!recognizeBSwapOrBitReverseIdiom(&I, true, false, Insts)) + if (!recognizeBSwapOrBitReverseIdiom(&Or, true, false, Insts)) return nullptr; Instruction *LastInst = Insts.pop_back_val(); LastInst->removeFromParent(); @@ -1684,6 +1810,57 @@ Instruction *InstCombiner::MatchBSwap(BinaryOperator &I) { return LastInst; } +/// Transform UB-safe variants of bitwise rotate to the funnel shift intrinsic. +static Instruction *matchRotate(Instruction &Or) { + // TODO: Can we reduce the code duplication between this and the related + // rotate matching code under visitSelect and visitTrunc? + unsigned Width = Or.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) + return nullptr; + + // First, find an or'd pair of opposite shifts with the same shifted operand: + // or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1) + Value *Or0 = Or.getOperand(0), *Or1 = Or.getOperand(1); + Value *ShVal, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // Match the shift amount operands for a rotate pattern. This always matches + // a subtraction on the R operand. + auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * { + // The shift amount may be masked with negation: + // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) + Value *X; + unsigned Mask = Width - 1; + if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && + match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) + return X; + + return nullptr; + }; + + Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width); + bool SubIsOnLHS = false; + if (!ShAmt) { + ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width); + SubIsOnLHS = true; + } + if (!ShAmt) + return nullptr; + + bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) || + (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl); + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType()); + return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt }); +} + /// If all elements of two constant vectors are 0/-1 and inverses, return true. static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { unsigned NumElts = C1->getType()->getVectorNumElements(); @@ -1704,14 +1881,33 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { /// We have an expression of the form (A & C) | (B & D). If A is a scalar or /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. -static Value *getSelectCondition(Value *A, Value *B, - InstCombiner::BuilderTy &Builder) { - // If these are scalars or vectors of i1, A can be used directly. 
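matchRotate above targets the usual UB-safe rotate idiom, where masking the negated shift amount keeps both shifts in range even for a shift of zero. A minimal standalone rendition of that source-level pattern:

    #include <cassert>
    #include <cstdint>

    // (x << (n & 31)) | (x >> (-n & 31)) is a rotate-left with no out-of-range
    // shift: for n % 32 == 0 both masked amounts are 0 and the result is x.
    static uint32_t rotl32(uint32_t x, uint32_t n) {
      return (x << (n & 31)) | (x >> (-n & 31));
    }

    int main() {
      assert(rotl32(0x80000001u, 1) == 0x00000003u);
      assert(rotl32(0x12345678u, 0) == 0x12345678u);   // no UB at a shift of 0
      assert(rotl32(0x12345678u, 32) == 0x12345678u);  // amount taken mod 32
      return 0;
    }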
+Value *InstCombiner::getSelectCondition(Value *A, Value *B) { + // Step 1: We may have peeked through bitcasts in the caller. + // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); - if (match(A, m_Not(m_Specific(B))) && Ty->isIntOrIntVectorTy(1)) - return A; + if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy()) + return nullptr; + + // Step 2: We need 0 or all-1's bitmasks. + if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits()) + return nullptr; + + // Step 3: If B is the 'not' value of A, we have our answer. + if (match(A, m_Not(m_Specific(B)))) { + // If these are scalars or vectors of i1, A can be used directly. + if (Ty->isIntOrIntVectorTy(1)) + return A; + return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty)); + } - // If A and B are sign-extended, look through the sexts to find the booleans. + // If both operands are constants, see if the constants are inverse bitmasks. + Constant *AConst, *BConst; + if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) + if (AConst == ConstantExpr::getNot(BConst)) + return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty)); + + // Look for more complex patterns. The 'not' op may be hidden behind various + // casts. Look through sexts and bitcasts to find the booleans. Value *Cond; Value *NotB; if (match(A, m_SExt(m_Value(Cond))) && @@ -1727,36 +1923,29 @@ static Value *getSelectCondition(Value *A, Value *B, if (!Ty->isVectorTy()) return nullptr; - // If both operands are constants, see if the constants are inverse bitmasks. - Constant *AC, *BC; - if (match(A, m_Constant(AC)) && match(B, m_Constant(BC)) && - areInverseVectorBitmasks(AC, BC)) { - return Builder.CreateZExtOrTrunc(AC, CmpInst::makeCmpResultType(Ty)); - } - // If both operands are xor'd with constants using the same sexted boolean // operand, see if the constants are inverse bitmasks. - if (match(A, (m_Xor(m_SExt(m_Value(Cond)), m_Constant(AC)))) && - match(B, (m_Xor(m_SExt(m_Specific(Cond)), m_Constant(BC)))) && + // TODO: Use ConstantExpr::getNot()? + if (match(A, (m_Xor(m_SExt(m_Value(Cond)), m_Constant(AConst)))) && + match(B, (m_Xor(m_SExt(m_Specific(Cond)), m_Constant(BConst)))) && Cond->getType()->isIntOrIntVectorTy(1) && - areInverseVectorBitmasks(AC, BC)) { - AC = ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); - return Builder.CreateXor(Cond, AC); + areInverseVectorBitmasks(AConst, BConst)) { + AConst = ConstantExpr::getTrunc(AConst, CmpInst::makeCmpResultType(Ty)); + return Builder.CreateXor(Cond, AConst); } return nullptr; } /// We have an expression of the form (A & C) | (B & D). Try to simplify this /// to "A' ? C : D", where A' is a boolean or vector of booleans. -static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, - InstCombiner::BuilderTy &Builder) { +Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B, + Value *D) { // The potential condition of the select may be bitcasted. In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. Type *OrigType = A->getType(); A = peekThroughBitcast(A, true); B = peekThroughBitcast(B, true); - - if (Value *Cond = getSelectCondition(A, B, Builder)) { + if (Value *Cond = getSelectCondition(A, B)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) // The bitcasts will either all exist or all not exist. The builder will // not create unnecessary casts if the types already match. 
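The select fold above is easiest to picture on a single lane: when A is all-zeros or all-ones and B is its bitwise 'not', the masked merge picks one operand wholesale. A tiny scalar illustration (helper name is arbitrary):

    #include <cassert>
    #include <cstdint>

    // (A & C) | (~A & D) selects C when A is all-ones and D when A is all-zeros,
    // which is exactly a select with A (viewed as a boolean lane) as condition.
    static uint32_t maskedMerge(uint32_t a, uint32_t c, uint32_t d) {
      return (a & c) | (~a & d);
    }

    int main() {
      uint32_t c = 0x12345678u, d = 0x9ABCDEF0u;
      assert(maskedMerge(0xFFFFFFFFu, c, d) == c);  // all-ones lane -> C
      assert(maskedMerge(0x00000000u, c, d) == d);  // all-zeros lane -> D
      return 0;
    }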
@@ -1838,7 +2027,7 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, } // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(PredL, PredR)) { + if (predicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -1846,8 +2035,8 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, LHS->getOperand(1) == RHS->getOperand(1)) { Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) | getICmpCode(RHS); - bool isSigned = LHS->isSigned() || RHS->isSigned(); - return getNewICmpValue(isSigned, Code, Op0, Op1, Builder); + bool IsSigned = LHS->isSigned() || RHS->isSigned(); + return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder); } } @@ -1928,7 +2117,7 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; // We can't fold (ugt x, C) | (sgt x, C2). - if (!PredicatesFoldable(PredL, PredR)) + if (!predicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. @@ -2007,7 +2196,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // See if we can simplify any instructions used by the instruction whose sole @@ -2029,37 +2218,25 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) return FoldedLogic; - // Given an OR instruction, check to see if this is a bswap. - if (Instruction *BSwap = MatchBSwap(I)) + if (Instruction *BSwap = matchBSwap(I)) return BSwap; - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - { - Value *A; - const APInt *C; - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 - if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && - MaskedValueIsZero(Op1, *C, 0, &I)) { - Value *NOr = Builder.CreateOr(A, Op1); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, - ConstantInt::get(NOr->getType(), *C)); - } + if (Instruction *Rotate = matchRotate(I)) + return Rotate; - // Y|(X^C) -> (X|Y)^C iff Y&C == 0 - if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && - MaskedValueIsZero(Op0, *C, 0, &I)) { - Value *NOr = Builder.CreateOr(A, Op0); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, - ConstantInt::get(NOr->getType(), *C)); - } + Value *X, *Y; + const APInt *CV; + if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && + !CV->isAllOnesValue() && MaskedValueIsZero(Y, *CV, 0, &I)) { + // (X ^ C) | Y -> (X | Y) ^ C iff Y & C == 0 + // The check for a 'not' op is for efficiency (if Y is known zero --> ~X). + Value *Or = Builder.CreateOr(X, Y); + return BinaryOperator::CreateXor(Or, ConstantInt::get(I.getType(), *CV)); } - Value *A, *B; - // (A & C)|(B & D) - Value *C = nullptr, *D = nullptr; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { ConstantInt *C1 = dyn_cast<ConstantInt>(C); @@ -2122,21 +2299,21 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // 'or' that it is replacing. if (Op0->hasOneUse() || Op1->hasOneUse()) { // (Cond & C) | (~Cond & D) -> Cond ? C : D, and commuted variants. 
- if (Value *V = matchSelectFromAndOr(A, C, B, D, Builder)) + if (Value *V = matchSelectFromAndOr(A, C, B, D)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(A, C, D, B, Builder)) + if (Value *V = matchSelectFromAndOr(A, C, D, B)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(C, A, B, D, Builder)) + if (Value *V = matchSelectFromAndOr(C, A, B, D)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(C, A, D, B, Builder)) + if (Value *V = matchSelectFromAndOr(C, A, D, B)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(B, D, A, C, Builder)) + if (Value *V = matchSelectFromAndOr(B, D, A, C)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(B, D, C, A, Builder)) + if (Value *V = matchSelectFromAndOr(B, D, C, A)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(D, B, A, C, Builder)) + if (Value *V = matchSelectFromAndOr(D, B, A, C)) return replaceInstUsesWith(I, V); - if (Value *V = matchSelectFromAndOr(D, B, C, A, Builder)) + if (Value *V = matchSelectFromAndOr(D, B, C, A)) return replaceInstUsesWith(I, V); } } @@ -2251,12 +2428,12 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // be simplified by a later pass either, so we try swapping the inner/outer // ORs in the hopes that we'll be able to simplify it this way. // (X|C) | V --> (X|V) | C - ConstantInt *C1; + ConstantInt *CI; if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) && - match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) { + match(Op0, m_Or(m_Value(A), m_ConstantInt(CI)))) { Value *Inner = Builder.CreateOr(A, Op1); Inner->takeName(Op0); - return BinaryOperator::CreateOr(Inner, C1); + return BinaryOperator::CreateOr(Inner, CI); } // Change (or (bool?A:B),(bool?C:D)) --> (bool?(or A,C):(or B,D)) @@ -2339,7 +2516,7 @@ static Instruction *foldXorToXor(BinaryOperator &I, } Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { - if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { + if (predicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -2348,8 +2525,8 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B) Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); - bool isSigned = LHS->isSigned() || RHS->isSigned(); - return getNewICmpValue(isSigned, Code, Op0, Op1, Builder); + bool IsSigned = LHS->isSigned() || RHS->isSigned(); + return getNewICmpValue(Code, IsSigned, Op0, Op1, Builder); } } @@ -2360,7 +2537,8 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); if ((LHS->hasOneUse() || RHS->hasOneUse()) && - LHS0->getType() == RHS0->getType()) { + LHS0->getType() == RHS0->getType() && + LHS0->getType()->isIntOrIntVectorTy()) { // (X > -1) ^ (Y > -1) --> (X ^ Y) < 0 // (X < 0) ^ (Y < 0) --> (X ^ Y) < 0 if ((PredL == CmpInst::ICMP_SGT && match(LHS1, m_AllOnes()) && @@ -2452,6 +2630,32 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, return nullptr; } +// Transform +// ~(x ^ y) +// into: +// (~x) ^ y +// or into +// x ^ (~y) +static Instruction *sinkNotIntoXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y; + // FIXME: one-use check is not needed in 
general, but currently we are unable + // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. (D35182) + if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) + return nullptr; + + // We only want to do the transform if it is free to do. + if (IsFreeToInvert(X, X->hasOneUse())) { + // Ok, good. + } else if (IsFreeToInvert(Y, Y->hasOneUse())) { + std::swap(X, Y); + } else + return nullptr; + + Value *NotX = Builder.CreateNot(X, X->getName() + ".not"); + return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan"); +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -2463,7 +2667,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *NewXor = foldXorToXor(I, Builder)) @@ -2481,9 +2685,15 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); - // A^B --> A|B iff A and B have no bits set in common. Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (haveNoCommonBitsSet(Op0, Op1, DL, &AC, &I, &DT)) + + // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) + // This it a special case in haveNoCommonBitsSet, but the computeKnownBits + // calls in there are unnecessary as SimplifyDemandedInstructionBits should + // have already taken care of those cases. + Value *M; + if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), + m_c_And(m_Deferred(M), m_Value())))) return BinaryOperator::CreateOr(Op0, Op1); // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. @@ -2528,8 +2738,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } // ~(X - Y) --> ~X + Y - if (match(NotVal, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) - return BinaryOperator::CreateAdd(Builder.CreateNot(X), Y); + if (match(NotVal, m_Sub(m_Value(X), m_Value(Y)))) + if (isa<Constant>(X) || NotVal->hasOneUse()) + return BinaryOperator::CreateAdd(Builder.CreateNot(X), Y); // ~(~X >>s Y) --> (X >>s Y) if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) @@ -2539,19 +2750,36 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // the 'not' by inverting the constant and using the opposite shift type. // Canonicalization rules ensure that only a negative constant uses 'ashr', // but we must check that in case that transform has not fired yet. 
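The constant/shift-type flip described above works because inverting replicated sign bits turns them into replicated zeros. An exhaustive 8-bit check (a standalone sketch, assuming the usual arithmetic right shift for negative signed values):

    #include <cassert>
    #include <cstdint>

    // For a negative constant c: ~(c >>s y) == (~c) >>u y. The arithmetic shift
    // replicates 1s into the top bits; after the 'not' they are the 0s that a
    // logical shift of ~c produces directly.
    int main() {
      for (int c = -128; c < 0; ++c)
        for (int y = 0; y < 8; ++y) {
          uint8_t AShr = (uint8_t)((int8_t)c >> y);   // sign-extending shift
          uint8_t LShr = (uint8_t)((uint8_t)~c >> y); // zero-extending shift
          assert((uint8_t)~AShr == LShr);
        }
      return 0;
    }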
+ + // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) Constant *C; if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && - match(C, m_Negative())) { - // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) - Constant *NotC = ConstantExpr::getNot(C); - return BinaryOperator::CreateLShr(NotC, Y); - } + match(C, m_Negative())) + return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y); + // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && - match(C, m_NonNegative())) { - // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) - Constant *NotC = ConstantExpr::getNot(C); - return BinaryOperator::CreateAShr(NotC, Y); + match(C, m_NonNegative())) + return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y); + + // ~(X + C) --> -(C + 1) - X + if (match(Op0, m_Add(m_Value(X), m_Constant(C)))) + return BinaryOperator::CreateSub(ConstantExpr::getNeg(AddOne(C)), X); + } + + // Use DeMorgan and reassociation to eliminate a 'not' op. + Constant *C1; + if (match(Op1, m_Constant(C1))) { + Constant *C2; + if (match(Op0, m_OneUse(m_Or(m_Not(m_Value(X)), m_Constant(C2))))) { + // (~X | C2) ^ C1 --> ((X & ~C2) ^ -1) ^ C1 --> (X & ~C2) ^ ~C1 + Value *And = Builder.CreateAnd(X, ConstantExpr::getNot(C2)); + return BinaryOperator::CreateXor(And, ConstantExpr::getNot(C1)); + } + if (match(Op0, m_OneUse(m_And(m_Not(m_Value(X)), m_Constant(C2))))) { + // (~X & C2) ^ C1 --> ((X | ~C2) ^ -1) ^ C1 --> (X | ~C2) ^ ~C1 + Value *Or = Builder.CreateOr(X, ConstantExpr::getNot(C2)); + return BinaryOperator::CreateXor(Or, ConstantExpr::getNot(C1)); } } @@ -2567,28 +2795,15 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (match(Op1, m_APInt(RHSC))) { Value *X; const APInt *C; - if (match(Op0, m_Sub(m_APInt(C), m_Value(X)))) { - // ~(c-X) == X-c-1 == X+(-c-1) - if (RHSC->isAllOnesValue()) { - Constant *NewC = ConstantInt::get(I.getType(), -(*C) - 1); - return BinaryOperator::CreateAdd(X, NewC); - } - if (RHSC->isSignMask()) { - // (C - X) ^ signmask -> (C + signmask - X) - Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); - return BinaryOperator::CreateSub(NewC, X); - } - } else if (match(Op0, m_Add(m_Value(X), m_APInt(C)))) { - // ~(X-c) --> (-c-1)-X - if (RHSC->isAllOnesValue()) { - Constant *NewC = ConstantInt::get(I.getType(), -(*C) - 1); - return BinaryOperator::CreateSub(NewC, X); - } - if (RHSC->isSignMask()) { - // (X + C) ^ signmask -> (X + C + signmask) - Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); - return BinaryOperator::CreateAdd(X, NewC); - } + if (RHSC->isSignMask() && match(Op0, m_Sub(m_APInt(C), m_Value(X)))) { + // (C - X) ^ signmask -> (C + signmask - X) + Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); + return BinaryOperator::CreateSub(NewC, X); + } + if (RHSC->isSignMask() && match(Op0, m_Add(m_Value(X), m_APInt(C)))) { + // (X + C) ^ signmask -> (X + C + signmask) + Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); + return BinaryOperator::CreateAdd(X, NewC); } // (X|C1)^C2 -> X^(C1^C2) iff X&~C1 == 0 @@ -2635,82 +2850,52 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I)) return FoldedLogic; - { - Value *A, *B; - if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { - if (A == Op0) { // A^(A|B) == A^(B|A) - cast<BinaryOperator>(Op1)->swapOperands(); - std::swap(A, B); - } - if (B == Op0) { // A^(B|A) == (B|A)^A - I.swapOperands(); // Simplified 
below. - std::swap(Op0, Op1); - } - } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) { - if (A == Op0) { // A^(A&B) -> A^(B&A) - cast<BinaryOperator>(Op1)->swapOperands(); - std::swap(A, B); - } - if (B == Op0) { // A^(B&A) -> (B&A)^A - I.swapOperands(); // Simplified below. - std::swap(Op0, Op1); - } - } - } - - { - Value *A, *B; - if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { - if (A == Op1) // (B|A)^B == (A|B)^B - std::swap(A, B); - if (B == Op1) // (A|B)^B == A & ~B - return BinaryOperator::CreateAnd(A, Builder.CreateNot(Op1)); - } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) { - if (A == Op1) // (A&B)^A -> (B&A)^A - std::swap(A, B); - const APInt *C; - if (B == Op1 && // (B&A)^A == ~B & A - !match(Op1, m_APInt(C))) { // Canonical form is (B&C)^C - return BinaryOperator::CreateAnd(Builder.CreateNot(A), Op1); - } - } - } - - { - Value *A, *B, *C, *D; - // (A ^ C)^(A | B) -> ((~A) & B) ^ C - if (match(Op0, m_Xor(m_Value(D), m_Value(C))) && - match(Op1, m_Or(m_Value(A), m_Value(B)))) { - if (D == A) - return BinaryOperator::CreateXor( - Builder.CreateAnd(Builder.CreateNot(A), B), C); - if (D == B) - return BinaryOperator::CreateXor( - Builder.CreateAnd(Builder.CreateNot(B), A), C); - } - // (A | B)^(A ^ C) -> ((~A) & B) ^ C - if (match(Op0, m_Or(m_Value(A), m_Value(B))) && - match(Op1, m_Xor(m_Value(D), m_Value(C)))) { - if (D == A) - return BinaryOperator::CreateXor( - Builder.CreateAnd(Builder.CreateNot(A), B), C); - if (D == B) - return BinaryOperator::CreateXor( - Builder.CreateAnd(Builder.CreateNot(B), A), C); - } - // (A & B) ^ (A ^ B) -> (A | B) - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - // (A ^ B) ^ (A & B) -> (A | B) - if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && - match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - } + // Y ^ (X | Y) --> X & ~Y + // Y ^ (Y | X) --> X & ~Y + if (match(Op1, m_OneUse(m_c_Or(m_Value(X), m_Specific(Op0))))) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Op0)); + // (X | Y) ^ Y --> X & ~Y + // (Y | X) ^ Y --> X & ~Y + if (match(Op0, m_OneUse(m_c_Or(m_Value(X), m_Specific(Op1))))) + return BinaryOperator::CreateAnd(X, Builder.CreateNot(Op1)); + + // Y ^ (X & Y) --> ~X & Y + // Y ^ (Y & X) --> ~X & Y + if (match(Op1, m_OneUse(m_c_And(m_Value(X), m_Specific(Op0))))) + return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(X)); + // (X & Y) ^ Y --> ~X & Y + // (Y & X) ^ Y --> ~X & Y + // Canonical form is (X & C) ^ C; don't touch that. + // TODO: A 'not' op is better for analysis and codegen, but demanded bits must + // be fixed to prefer that (otherwise we get infinite looping). + if (!match(Op1, m_Constant()) && + match(Op0, m_OneUse(m_c_And(m_Value(X), m_Specific(Op1))))) + return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(X)); + + Value *A, *B, *C; + // (A ^ B) ^ (A | C) --> (~A & C) ^ B -- There are 4 commuted variants. + if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))), + m_OneUse(m_c_Or(m_Deferred(A), m_Value(C)))))) + return BinaryOperator::CreateXor( + Builder.CreateAnd(Builder.CreateNot(A), C), B); + + // (A ^ B) ^ (B | C) --> (~B & C) ^ A -- There are 4 commuted variants. 
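Both xor-of-or folds above rest on the same bitwise identity (the B|C form is the A|C form with the roles of A and B swapped); checking it exhaustively on bytes is cheap:

    #include <cassert>
    #include <cstdint>

    // Per-bit: if a == 0 both sides are b ^ c; if a == 1 both sides are b.
    int main() {
      for (int a = 0; a < 256; ++a)
        for (int b = 0; b < 256; ++b)
          for (int c = 0; c < 256; ++c) {
            uint8_t lhs = (uint8_t)((a ^ b) ^ (a | c));
            uint8_t rhs = (uint8_t)((~a & c) ^ b);
            assert(lhs == rhs);
          }
      return 0;
    }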
+ if (match(&I, m_c_Xor(m_OneUse(m_Xor(m_Value(A), m_Value(B))), + m_OneUse(m_c_Or(m_Deferred(B), m_Value(C)))))) + return BinaryOperator::CreateXor( + Builder.CreateAnd(Builder.CreateNot(B), C), A); + + // (A & B) ^ (A ^ B) -> (A | B) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + // (A ^ B) ^ (A & B) -> (A | B) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); // (A & ~B) ^ ~A -> ~(A & B) // (~B & A) ^ ~A -> ~(A & B) - Value *A, *B; if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && match(Op1, m_Not(m_Specific(A)))) return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); @@ -2759,23 +2944,41 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // %res = select i1 %cmp2, i32 %x, i32 %noty // // Same is applicable for smin/umax/umin. - { + if (match(Op1, m_AllOnes()) && Op0->hasOneUse()) { Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(Op0, LHS, RHS).Flavor; - if (Op0->hasOneUse() && SelectPatternResult::isMinOrMax(SPF) && - match(Op1, m_AllOnes())) { - - Value *X; - if (match(RHS, m_Not(m_Value(X)))) - std::swap(RHS, LHS); - - if (match(LHS, m_Not(m_Value(X)))) { + if (SelectPatternResult::isMinOrMax(SPF)) { + // It's possible we get here before the not has been simplified, so make + // sure the input to the not isn't freely invertible. + if (match(LHS, m_Not(m_Value(X))) && !IsFreeToInvert(X, X->hasOneUse())) { Value *NotY = Builder.CreateNot(RHS); return SelectInst::Create( Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); } + + // It's possible we get here before the not has been simplified, so make + // sure the input to the not isn't freely invertible. + if (match(RHS, m_Not(m_Value(Y))) && !IsFreeToInvert(Y, Y->hasOneUse())) { + Value *NotX = Builder.CreateNot(LHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y); + } + + // If both sides are freely invertible, then we can get rid of the xor + // completely. + if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && + IsFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { + Value *NotLHS = Builder.CreateNot(LHS); + Value *NotRHS = Builder.CreateNot(RHS); + return SelectInst::Create( + Builder.CreateICmp(getInverseMinMaxPred(SPF), NotLHS, NotRHS), + NotLHS, NotRHS); + } } } + if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) + return NewXor; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index cbfbd8a53993..aeb25d530d71 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -136,6 +136,14 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { if (Size > 8 || (Size&(Size-1))) return nullptr; // If not 1/2/4/8 bytes, exit. + // If it is an atomic and alignment is less than the size then we will + // introduce the unaligned memory access which will be later transformed + // into libcall in CodeGen. This is not evident performance gain so disable + // it now. + if (isa<AtomicMemTransferInst>(MI)) + if (CopyDstAlign < Size || CopySrcAlign < Size) + return nullptr; + // Use an integer load+store unless we can find something better. 
unsigned SrcAddrSp = cast<PointerType>(MI->getArgOperand(1)->getType())->getAddressSpace(); @@ -174,6 +182,9 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { MI->getMetadata(LLVMContext::MD_mem_parallel_loop_access); if (LoopMemParallelMD) L->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + MDNode *AccessGroupMD = MI->getMetadata(LLVMContext::MD_access_group); + if (AccessGroupMD) + L->setMetadata(LLVMContext::MD_access_group, AccessGroupMD); StoreInst *S = Builder.CreateStore(L, Dest); // Alignment from the mem intrinsic will be better, so use it. @@ -182,6 +193,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { S->setMetadata(LLVMContext::MD_tbaa, CopyMD); if (LoopMemParallelMD) S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); + if (AccessGroupMD) + S->setMetadata(LLVMContext::MD_access_group, AccessGroupMD); if (auto *MT = dyn_cast<MemTransferInst>(MI)) { // non-atomics can be volatile @@ -215,6 +228,18 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { Alignment = MI->getDestAlignment(); assert(Len && "0-sized memory setting should be removed already."); + // Alignment 0 is identity for alignment 1 for memset, but not store. + if (Alignment == 0) + Alignment = 1; + + // If it is an atomic and alignment is less than the size then we will + // introduce the unaligned memory access which will be later transformed + // into libcall in CodeGen. This is not evident performance gain so disable + // it now. + if (isa<AtomicMemSetInst>(MI)) + if (Alignment < Len) + return nullptr; + // memset(s,c,n) -> store s, c (for n=1,2,4,8) if (Len <= 8 && isPowerOf2_32((uint32_t)Len)) { Type *ITy = IntegerType::get(MI->getContext(), Len*8); // n=1 -> i8. @@ -224,9 +249,6 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { Type *NewDstPtrTy = PointerType::get(ITy, DstAddrSp); Dest = Builder.CreateBitCast(Dest, NewDstPtrTy); - // Alignment 0 is identity for alignment 1 for memset, but not store. - if (Alignment == 0) Alignment = 1; - // Extract the fill value and store. uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, @@ -648,7 +670,7 @@ static Value *simplifyX86round(IntrinsicInst &II, } Intrinsic::ID ID = (RoundControl == 2) ? Intrinsic::ceil : Intrinsic::floor; - Value *Res = Builder.CreateIntrinsic(ID, {Src}, &II); + Value *Res = Builder.CreateUnaryIntrinsic(ID, Src, &II); if (!IsScalar) { if (auto *C = dyn_cast<Constant>(Mask)) if (C->isAllOnesValue()) @@ -675,7 +697,8 @@ static Value *simplifyX86round(IntrinsicInst &II, return Builder.CreateInsertElement(Dst, Res, (uint64_t)0); } -static Value *simplifyX86movmsk(const IntrinsicInst &II) { +static Value *simplifyX86movmsk(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { Value *Arg = II.getArgOperand(0); Type *ResTy = II.getType(); Type *ArgTy = Arg->getType(); @@ -688,29 +711,46 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II) { if (!ArgTy->isVectorTy()) return nullptr; - auto *C = dyn_cast<Constant>(Arg); - if (!C) - return nullptr; + if (auto *C = dyn_cast<Constant>(Arg)) { + // Extract signbits of the vector input and pack into integer result. 
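What the movmsk family computes can be modeled per element: take each element's sign bit and pack the bits into an integer, which is what the constant folding below does lane by lane. A scalar sketch (helper name is illustrative):

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // One result bit per element, taken from that element's IEEE sign bit.
    static unsigned movmskPS(const float *v, unsigned n) {
      unsigned mask = 0;
      for (unsigned i = 0; i != n; ++i) {
        uint32_t bits;
        std::memcpy(&bits, &v[i], sizeof(bits));
        mask |= (bits >> 31) << i;
      }
      return mask;
    }

    int main() {
      const float v[4] = {1.0f, -2.0f, -0.0f, 3.0f};
      assert(movmskPS(v, 4) == 0x6u);  // elements 1 and 2 have the sign bit set
      return 0;
    }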
+ APInt Result(ResTy->getPrimitiveSizeInBits(), 0); + for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) { + auto *COp = C->getAggregateElement(I); + if (!COp) + return nullptr; + if (isa<UndefValue>(COp)) + continue; - // Extract signbits of the vector input and pack into integer result. - APInt Result(ResTy->getPrimitiveSizeInBits(), 0); - for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) { - auto *COp = C->getAggregateElement(I); - if (!COp) - return nullptr; - if (isa<UndefValue>(COp)) - continue; + auto *CInt = dyn_cast<ConstantInt>(COp); + auto *CFp = dyn_cast<ConstantFP>(COp); + if (!CInt && !CFp) + return nullptr; - auto *CInt = dyn_cast<ConstantInt>(COp); - auto *CFp = dyn_cast<ConstantFP>(COp); - if (!CInt && !CFp) - return nullptr; + if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative())) + Result.setBit(I); + } + return Constant::getIntegerValue(ResTy, Result); + } - if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative())) - Result.setBit(I); + // Look for a sign-extended boolean source vector as the argument to this + // movmsk. If the argument is bitcast, look through that, but make sure the + // source of that bitcast is still a vector with the same number of elements. + // TODO: We can also convert a bitcast with wider elements, but that requires + // duplicating the bool source sign bits to match the number of elements + // expected by the movmsk call. + Arg = peekThroughBitcast(Arg); + Value *X; + if (Arg->getType()->isVectorTy() && + Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() && + match(Arg, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { + // call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM + unsigned NumElts = X->getType()->getVectorNumElements(); + Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts); + Value *BC = Builder.CreateBitCast(X, ScalarTy); + return Builder.CreateZExtOrTrunc(BC, ResTy); } - return Constant::getIntegerValue(ResTy, Result); + return nullptr; } static Value *simplifyX86insertps(const IntrinsicInst &II, @@ -1133,82 +1173,6 @@ static Value *simplifyX86vpcom(const IntrinsicInst &II, return nullptr; } -static Value *simplifyMinnumMaxnum(const IntrinsicInst &II) { - Value *Arg0 = II.getArgOperand(0); - Value *Arg1 = II.getArgOperand(1); - - // fmin(x, x) -> x - if (Arg0 == Arg1) - return Arg0; - - const auto *C1 = dyn_cast<ConstantFP>(Arg1); - - // fmin(x, nan) -> x - if (C1 && C1->isNaN()) - return Arg0; - - // This is the value because if undef were NaN, we would return the other - // value and cannot return a NaN unless both operands are. 
- // - // fmin(undef, x) -> x - if (isa<UndefValue>(Arg0)) - return Arg1; - - // fmin(x, undef) -> x - if (isa<UndefValue>(Arg1)) - return Arg0; - - Value *X = nullptr; - Value *Y = nullptr; - if (II.getIntrinsicID() == Intrinsic::minnum) { - // fmin(x, fmin(x, y)) -> fmin(x, y) - // fmin(y, fmin(x, y)) -> fmin(x, y) - if (match(Arg1, m_FMin(m_Value(X), m_Value(Y)))) { - if (Arg0 == X || Arg0 == Y) - return Arg1; - } - - // fmin(fmin(x, y), x) -> fmin(x, y) - // fmin(fmin(x, y), y) -> fmin(x, y) - if (match(Arg0, m_FMin(m_Value(X), m_Value(Y)))) { - if (Arg1 == X || Arg1 == Y) - return Arg0; - } - - // TODO: fmin(nnan x, inf) -> x - // TODO: fmin(nnan ninf x, flt_max) -> x - if (C1 && C1->isInfinity()) { - // fmin(x, -inf) -> -inf - if (C1->isNegative()) - return Arg1; - } - } else { - assert(II.getIntrinsicID() == Intrinsic::maxnum); - // fmax(x, fmax(x, y)) -> fmax(x, y) - // fmax(y, fmax(x, y)) -> fmax(x, y) - if (match(Arg1, m_FMax(m_Value(X), m_Value(Y)))) { - if (Arg0 == X || Arg0 == Y) - return Arg1; - } - - // fmax(fmax(x, y), x) -> fmax(x, y) - // fmax(fmax(x, y), y) -> fmax(x, y) - if (match(Arg0, m_FMax(m_Value(X), m_Value(Y)))) { - if (Arg1 == X || Arg1 == Y) - return Arg0; - } - - // TODO: fmax(nnan x, -inf) -> x - // TODO: fmax(nnan ninf x, -flt_max) -> x - if (C1 && C1->isInfinity()) { - // fmax(x, inf) -> inf - if (!C1->isNegative()) - return Arg1; - } - } - return nullptr; -} - static bool maskIsAllOneOrUndef(Value *Mask) { auto *ConstMask = dyn_cast<Constant>(Mask); if (!ConstMask) @@ -1852,6 +1816,17 @@ Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { return nullptr; } +static Instruction *canonicalizeConstantArg0ToArg1(CallInst &Call) { + assert(Call.getNumArgOperands() > 1 && "Need at least 2 args to swap"); + Value *Arg0 = Call.getArgOperand(0), *Arg1 = Call.getArgOperand(1); + if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) { + Call.setArgOperand(0, Arg1); + Call.setArgOperand(1, Arg0); + return &Call; + } + return nullptr; +} + /// CallInst simplification. This mostly only handles folding of intrinsic /// instructions. For normal calls, it allows visitCallSite to do the heavy /// lifting. @@ -2005,18 +1980,49 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return I; break; + case Intrinsic::fshl: + case Intrinsic::fshr: { + const APInt *SA; + if (match(II->getArgOperand(2), m_APInt(SA))) { + Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); + unsigned BitWidth = SA->getBitWidth(); + uint64_t ShiftAmt = SA->urem(BitWidth); + assert(ShiftAmt != 0 && "SimplifyCall should have handled zero shift"); + // Normalize to funnel shift left. + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + // fshl(X, 0, C) -> shl X, C + // fshl(X, undef, C) -> shl X, C + if (match(Op1, m_Zero()) || match(Op1, m_Undef())) + return BinaryOperator::CreateShl( + Op0, ConstantInt::get(II->getType(), ShiftAmt)); + + // fshl(0, X, C) -> lshr X, (BW-C) + // fshl(undef, X, C) -> lshr X, (BW-C) + if (match(Op0, m_Zero()) || match(Op0, m_Undef())) + return BinaryOperator::CreateLShr( + Op1, ConstantInt::get(II->getType(), BitWidth - ShiftAmt)); + } + + // The shift amount (operand 2) of a funnel shift is modulo the bitwidth, + // so only the low bits of the shift amount are demanded if the bitwidth is + // a power-of-2. 
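For reference, 32-bit funnel-shift-left semantics and the shl special case folded above can be written out as a standalone model (this is a sketch of the intrinsic's documented behavior, not the intrinsic itself):

    #include <cassert>
    #include <cstdint>

    // fshl(a, b, c): shift the 64-bit concatenation a:b left by c mod 32 and
    // keep the top 32 bits.
    static uint32_t fshl32(uint32_t a, uint32_t b, uint32_t c) {
      c &= 31;
      if (c == 0)
        return a;
      return (a << c) | (b >> (32 - c));
    }

    int main() {
      uint32_t x = 0x12345678u;
      // fshl(X, 0, C) only pulls in zero bits, so it is a plain shl.
      for (uint32_t c = 0; c < 64; ++c)
        assert(fshl32(x, 0u, c) == x << (c & 31));
      // fshl(X, X, C) is a rotate left.
      assert(fshl32(x, x, 4) == 0x23456781u);
      return 0;
    }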
+ unsigned BitWidth = II->getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(BitWidth)) + break; + APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth)); + KnownBits Op2Known(BitWidth); + if (SimplifyDemandedBits(II, 2, Op2Demanded, Op2Known)) + return &CI; + break; + } case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: - if (isa<Constant>(II->getArgOperand(0)) && - !isa<Constant>(II->getArgOperand(1))) { - // Canonicalize constants into the RHS. - Value *LHS = II->getArgOperand(0); - II->setArgOperand(0, II->getArgOperand(1)); - II->setArgOperand(1, LHS); - return II; - } + if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) + return I; LLVM_FALLTHROUGH; case Intrinsic::usub_with_overflow: @@ -2034,34 +2040,164 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::minnum: - case Intrinsic::maxnum: { + case Intrinsic::uadd_sat: + case Intrinsic::sadd_sat: + if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) + return I; + LLVM_FALLTHROUGH; + case Intrinsic::usub_sat: + case Intrinsic::ssub_sat: { Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - // Canonicalize constants to the RHS. - if (isa<ConstantFP>(Arg0) && !isa<ConstantFP>(Arg1)) { - II->setArgOperand(0, Arg1); - II->setArgOperand(1, Arg0); - return II; + Intrinsic::ID IID = II->getIntrinsicID(); + + // Make use of known overflow information. + OverflowResult OR; + switch (IID) { + default: + llvm_unreachable("Unexpected intrinsic!"); + case Intrinsic::uadd_sat: + OR = computeOverflowForUnsignedAdd(Arg0, Arg1, II); + if (OR == OverflowResult::NeverOverflows) + return BinaryOperator::CreateNUWAdd(Arg0, Arg1); + if (OR == OverflowResult::AlwaysOverflows) + return replaceInstUsesWith(*II, + ConstantInt::getAllOnesValue(II->getType())); + break; + case Intrinsic::usub_sat: + OR = computeOverflowForUnsignedSub(Arg0, Arg1, II); + if (OR == OverflowResult::NeverOverflows) + return BinaryOperator::CreateNUWSub(Arg0, Arg1); + if (OR == OverflowResult::AlwaysOverflows) + return replaceInstUsesWith(*II, + ConstantInt::getNullValue(II->getType())); + break; + case Intrinsic::sadd_sat: + if (willNotOverflowSignedAdd(Arg0, Arg1, *II)) + return BinaryOperator::CreateNSWAdd(Arg0, Arg1); + break; + case Intrinsic::ssub_sat: + if (willNotOverflowSignedSub(Arg0, Arg1, *II)) + return BinaryOperator::CreateNSWSub(Arg0, Arg1); + break; } - // FIXME: Simplifications should be in instsimplify. 
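The unsigned case of the nested saturating-add merge above can be checked exhaustively on bytes; the helper here is a reference model of uadd.sat, not the intrinsic:

    #include <cassert>
    #include <cstdint>

    // Reference 8-bit uadd.sat: clamp the true sum to 0xFF.
    static uint8_t uaddSat8(uint8_t a, uint8_t b) {
      unsigned s = (unsigned)a + b;
      return s > 0xFF ? 0xFF : (uint8_t)s;
    }

    int main() {
      // uadd_sat(uadd_sat(x, c2), c1) == uadd_sat(x, uadd_sat(c1, c2)):
      // if c1 + c2 overflows, both sides saturate; otherwise the constants
      // simply accumulate.
      for (int x = 0; x < 256; ++x)
        for (int c1 = 0; c1 < 256; ++c1)
          for (int c2 = 0; c2 < 256; ++c2)
            assert(uaddSat8(uaddSat8(x, c2), c1) ==
                   uaddSat8(x, uaddSat8(c1, c2)));
      return 0;
    }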
- if (Value *V = simplifyMinnumMaxnum(*II)) - return replaceInstUsesWith(*II, V); + // ssub.sat(X, C) -> sadd.sat(X, -C) if C != MIN + Constant *C; + if (IID == Intrinsic::ssub_sat && match(Arg1, m_Constant(C)) && + C->isNotMinSignedValue()) { + Value *NegVal = ConstantExpr::getNeg(C); + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic( + Intrinsic::sadd_sat, Arg0, NegVal)); + } + + // sat(sat(X + Val2) + Val) -> sat(X + (Val+Val2)) + // sat(sat(X - Val2) - Val) -> sat(X - (Val+Val2)) + // if Val and Val2 have the same sign + if (auto *Other = dyn_cast<IntrinsicInst>(Arg0)) { + Value *X; + const APInt *Val, *Val2; + APInt NewVal; + bool IsUnsigned = + IID == Intrinsic::uadd_sat || IID == Intrinsic::usub_sat; + if (Other->getIntrinsicID() == II->getIntrinsicID() && + match(Arg1, m_APInt(Val)) && + match(Other->getArgOperand(0), m_Value(X)) && + match(Other->getArgOperand(1), m_APInt(Val2))) { + if (IsUnsigned) + NewVal = Val->uadd_sat(*Val2); + else if (Val->isNonNegative() == Val2->isNonNegative()) { + bool Overflow; + NewVal = Val->sadd_ov(*Val2, Overflow); + if (Overflow) { + // Both adds together may add more than SignedMaxValue + // without saturating the final result. + break; + } + } else { + // Cannot fold saturated addition with different signs. + break; + } + return replaceInstUsesWith( + *II, Builder.CreateBinaryIntrinsic( + IID, X, ConstantInt::get(II->getType(), NewVal))); + } + } + break; + } + + case Intrinsic::minnum: + case Intrinsic::maxnum: + case Intrinsic::minimum: + case Intrinsic::maximum: { + if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) + return I; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + Intrinsic::ID IID = II->getIntrinsicID(); Value *X, *Y; if (match(Arg0, m_FNeg(m_Value(X))) && match(Arg1, m_FNeg(m_Value(Y))) && (Arg0->hasOneUse() || Arg1->hasOneUse())) { // If both operands are negated, invert the call and negate the result: - // minnum(-X, -Y) --> -(maxnum(X, Y)) - // maxnum(-X, -Y) --> -(minnum(X, Y)) - Intrinsic::ID NewIID = II->getIntrinsicID() == Intrinsic::maxnum ? 
- Intrinsic::minnum : Intrinsic::maxnum; - Value *NewCall = Builder.CreateIntrinsic(NewIID, { X, Y }, II); + // min(-X, -Y) --> -(max(X, Y)) + // max(-X, -Y) --> -(min(X, Y)) + Intrinsic::ID NewIID; + switch (IID) { + case Intrinsic::maxnum: + NewIID = Intrinsic::minnum; + break; + case Intrinsic::minnum: + NewIID = Intrinsic::maxnum; + break; + case Intrinsic::maximum: + NewIID = Intrinsic::minimum; + break; + case Intrinsic::minimum: + NewIID = Intrinsic::maximum; + break; + default: + llvm_unreachable("unexpected intrinsic ID"); + } + Value *NewCall = Builder.CreateBinaryIntrinsic(NewIID, X, Y, II); Instruction *FNeg = BinaryOperator::CreateFNeg(NewCall); FNeg->copyIRFlags(II); return FNeg; } + + // m(m(X, C2), C1) -> m(X, C) + const APFloat *C1, *C2; + if (auto *M = dyn_cast<IntrinsicInst>(Arg0)) { + if (M->getIntrinsicID() == IID && match(Arg1, m_APFloat(C1)) && + ((match(M->getArgOperand(0), m_Value(X)) && + match(M->getArgOperand(1), m_APFloat(C2))) || + (match(M->getArgOperand(1), m_Value(X)) && + match(M->getArgOperand(0), m_APFloat(C2))))) { + APFloat Res(0.0); + switch (IID) { + case Intrinsic::maxnum: + Res = maxnum(*C1, *C2); + break; + case Intrinsic::minnum: + Res = minnum(*C1, *C2); + break; + case Intrinsic::maximum: + Res = maximum(*C1, *C2); + break; + case Intrinsic::minimum: + Res = minimum(*C1, *C2); + break; + default: + llvm_unreachable("unexpected intrinsic ID"); + } + Instruction *NewCall = Builder.CreateBinaryIntrinsic( + IID, X, ConstantFP::get(Arg0->getType(), Res)); + NewCall->copyIRFlags(II); + return replaceInstUsesWith(*II, NewCall); + } + } + break; } case Intrinsic::fmuladd: { @@ -2079,17 +2215,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { LLVM_FALLTHROUGH; } case Intrinsic::fma: { - Value *Src0 = II->getArgOperand(0); - Value *Src1 = II->getArgOperand(1); - - // Canonicalize constant multiply operand to Src1. 
- if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { - II->setArgOperand(0, Src1); - II->setArgOperand(1, Src0); - std::swap(Src0, Src1); - } + if (Instruction *I = canonicalizeConstantArg0ToArg1(CI)) + return I; // fma fneg(x), fneg(y), z -> fma x, y, z + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); Value *X, *Y; if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) { II->setArgOperand(0, X); @@ -2135,24 +2266,33 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *ExtSrc; if (match(II->getArgOperand(0), m_OneUse(m_FPExt(m_Value(ExtSrc))))) { // Narrow the call: intrinsic (fpext x) -> fpext (intrinsic x) - Value *NarrowII = Builder.CreateIntrinsic(II->getIntrinsicID(), - { ExtSrc }, II); + Value *NarrowII = + Builder.CreateUnaryIntrinsic(II->getIntrinsicID(), ExtSrc, II); return new FPExtInst(NarrowII, II->getType()); } break; } case Intrinsic::cos: case Intrinsic::amdgcn_cos: { - Value *SrcSrc; + Value *X; Value *Src = II->getArgOperand(0); - if (match(Src, m_FNeg(m_Value(SrcSrc))) || - match(Src, m_FAbs(m_Value(SrcSrc)))) { + if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X)))) { // cos(-x) -> cos(x) // cos(fabs(x)) -> cos(x) - II->setArgOperand(0, SrcSrc); + II->setArgOperand(0, X); return II; } - + break; + } + case Intrinsic::sin: { + Value *X; + if (match(II->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) { + // sin(-x) --> -sin(x) + Value *NewSin = Builder.CreateUnaryIntrinsic(Intrinsic::sin, X, II); + Instruction *FNeg = BinaryOperator::CreateFNeg(NewSin); + FNeg->copyFastMathFlags(II); + return FNeg; + } break; } case Intrinsic::ppc_altivec_lvx: @@ -2382,7 +2522,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx_movmsk_pd_256: case Intrinsic::x86_avx_movmsk_ps_256: case Intrinsic::x86_avx2_pmovmskb: - if (Value *V = simplifyX86movmsk(*II)) + if (Value *V = simplifyX86movmsk(*II, Builder)) return replaceInstUsesWith(*II, V); break; @@ -2922,16 +3062,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx_blendv_ps_256: case Intrinsic::x86_avx_blendv_pd_256: case Intrinsic::x86_avx2_pblendvb: { - // Convert blendv* to vector selects if the mask is constant. - // This optimization is convoluted because the intrinsic is defined as - // getting a vector of floats or doubles for the ps and pd versions. - // FIXME: That should be changed. - + // fold (blend A, A, Mask) -> A Value *Op0 = II->getArgOperand(0); Value *Op1 = II->getArgOperand(1); Value *Mask = II->getArgOperand(2); - - // fold (blend A, A, Mask) -> A if (Op0 == Op1) return replaceInstUsesWith(CI, Op0); @@ -2944,6 +3078,33 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Constant *NewSelector = getNegativeIsTrueBoolVec(ConstantMask); return SelectInst::Create(NewSelector, Op1, Op0, "blendv"); } + + // Convert to a vector select if we can bypass casts and find a boolean + // vector condition value. 
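The select conversion above is easiest to see on a scalar model of blendv: each lane follows the sign bit of its mask element, so a mask that is a sign-extended i1 vector (0 or -1 per lane) makes it exactly a per-lane select. A sketch with arbitrary helper names:

    #include <cassert>
    #include <cstdint>

    // result[i] = sign bit of mask[i] ? op1[i] : op0[i]
    static void blendv(const int32_t *op0, const int32_t *op1,
                       const int32_t *mask, int32_t *out, unsigned n) {
      for (unsigned i = 0; i != n; ++i)
        out[i] = (mask[i] < 0) ? op1[i] : op0[i];
    }

    int main() {
      const int32_t a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40};
      const int32_t m[4] = {0, -1, -1, 0};  // sext of <i1 0, 1, 1, 0>
      int32_t r[4];
      blendv(a, b, m, r, 4);
      assert(r[0] == 1 && r[1] == 20 && r[2] == 30 && r[3] == 4);
      return 0;
    }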
+ Value *BoolVec; + Mask = peekThroughBitcast(Mask); + if (match(Mask, m_SExt(m_Value(BoolVec))) && + BoolVec->getType()->isVectorTy() && + BoolVec->getType()->getScalarSizeInBits() == 1) { + assert(Mask->getType()->getPrimitiveSizeInBits() == + II->getType()->getPrimitiveSizeInBits() && + "Not expecting mask and operands with different sizes"); + + unsigned NumMaskElts = Mask->getType()->getVectorNumElements(); + unsigned NumOperandElts = II->getType()->getVectorNumElements(); + if (NumMaskElts == NumOperandElts) + return SelectInst::Create(BoolVec, Op1, Op0); + + // If the mask has less elements than the operands, each mask bit maps to + // multiple elements of the operands. Bitcast back and forth. + if (NumMaskElts < NumOperandElts) { + Value *CastOp0 = Builder.CreateBitCast(Op0, Mask->getType()); + Value *CastOp1 = Builder.CreateBitCast(Op1, Mask->getType()); + Value *Sel = Builder.CreateSelect(BoolVec, CastOp1, CastOp0); + return new BitCastInst(Sel, II->getType()); + } + } + break; } @@ -3275,6 +3436,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, FCmp); } + if (Mask == (N_ZERO | P_ZERO)) { + // Equivalent of == 0. + Value *FCmp = Builder.CreateFCmpOEQ( + Src0, ConstantFP::get(Src0->getType(), 0.0)); + + FCmp->takeName(II); + return replaceInstUsesWith(*II, FCmp); + } + + // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other + if (((Mask & S_NAN) || (Mask & Q_NAN)) && isKnownNeverNaN(Src0, &TLI)) { + II->setArgOperand(1, ConstantInt::get(Src1->getType(), + Mask & ~(S_NAN | Q_NAN))); + return II; + } + const ConstantFP *CVal = dyn_cast<ConstantFP>(Src0); if (!CVal) { if (isa<UndefValue>(Src0)) @@ -3384,22 +3561,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe; - // TODO: Also emit sub if only width is constant. - if (!CWidth && COffset && Offset == 0) { - Constant *KSize = ConstantInt::get(COffset->getType(), IntSize); - Value *ShiftVal = Builder.CreateSub(KSize, II->getArgOperand(2)); - ShiftVal = Builder.CreateZExt(ShiftVal, II->getType()); - - Value *Shl = Builder.CreateShl(Src, ShiftVal); - Value *RightShift = Signed ? Builder.CreateAShr(Shl, ShiftVal) - : Builder.CreateLShr(Shl, ShiftVal); - RightShift->takeName(II); - return replaceInstUsesWith(*II, RightShift); - } - if (!CWidth || !COffset) break; + // The case of Width == 0 is handled above, which makes this tranformation + // safe. If Width == 0, then the ashr and lshr instructions become poison + // value since the shift amount would be equal to the bit size. + assert(Width != 0); + // TODO: This allows folding to undef when the hardware has specific // behavior? if (Offset + Width < IntSize) { @@ -3603,6 +3772,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ? Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp; + Type *Ty = SrcLHS->getType(); + if (auto *CmpType = dyn_cast<IntegerType>(Ty)) { + // Promote to next legal integer type. + unsigned Width = CmpType->getBitWidth(); + unsigned NewWidth = Width; + + // Don't do anything for i1 comparisons. + if (Width == 1) + break; + + if (Width <= 16) + NewWidth = 16; + else if (Width <= 32) + NewWidth = 32; + else if (Width <= 64) + NewWidth = 64; + else if (Width > 64) + break; // Can't handle this. 
+ + if (Width != NewWidth) { + IntegerType *CmpTy = Builder.getIntNTy(NewWidth); + if (CmpInst::isSigned(SrcPred)) { + SrcLHS = Builder.CreateSExt(SrcLHS, CmpTy); + SrcRHS = Builder.CreateSExt(SrcRHS, CmpTy); + } else { + SrcLHS = Builder.CreateZExt(SrcLHS, CmpTy); + SrcRHS = Builder.CreateZExt(SrcRHS, CmpTy); + } + } + } else if (!Ty->isFloatTy() && !Ty->isDoubleTy() && !Ty->isHalfTy()) + break; + Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID, SrcLHS->getType()); Value *Args[] = { SrcLHS, SrcRHS, @@ -3661,7 +3862,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Scan down this block to see if there is another stack restore in the // same block without an intervening call/alloca. BasicBlock::iterator BI(II); - TerminatorInst *TI = II->getParent()->getTerminator(); + Instruction *TI = II->getParent()->getTerminator(); bool CannotRemove = false; for (++BI; &*BI != TI; ++BI) { if (isa<AllocaInst>(BI)) { @@ -3788,8 +3989,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); // isKnownNonNull -> nonnull attribute - if (isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) + if (!II->hasRetAttr(Attribute::NonNull) && + isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) { II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + return II; + } } // TODO: bitcast(relocate(p)) -> relocate(bitcast(p)) @@ -3889,7 +4093,11 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { auto InstCombineRAUW = [this](Instruction *From, Value *With) { replaceInstUsesWith(*From, With); }; - LibCallSimplifier Simplifier(DL, &TLI, ORE, InstCombineRAUW); + auto InstCombineErase = [this](Instruction *I) { + eraseInstFromFunction(*I); + }; + LibCallSimplifier Simplifier(DL, &TLI, ORE, InstCombineRAUW, + InstCombineErase); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index fd59c3a7c0c3..1201ac196ec0 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -492,12 +492,19 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { } /// Rotate left/right may occur in a wider type than necessary because of type -/// promotion rules. Try to narrow all of the component instructions. +/// promotion rules. Try to narrow the inputs and convert to funnel shift. Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { assert((isa<VectorType>(Trunc.getSrcTy()) || shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && "Don't narrow to an illegal scalar type"); + // Bail out on strange types. It is possible to handle some of these patterns + // even with non-power-of-2 sizes, but it is not a likely scenario. + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + if (!isPowerOf2_32(NarrowWidth)) + return nullptr; + // First, find an or'd pair of opposite shifts with the same shifted operand: // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) Value *Or0, *Or1; @@ -514,22 +521,38 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { if (ShiftOpcode0 == ShiftOpcode1) return nullptr; - // The shift amounts must add up to the narrow bit width. 
- Value *ShAmt; - bool SubIsOnLHS; - Type *DestTy = Trunc.getType(); - unsigned NarrowWidth = DestTy->getScalarSizeInBits(); - if (match(ShAmt0, - m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) { - ShAmt = ShAmt1; - SubIsOnLHS = true; - } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), - m_Specific(ShAmt0))))) { - ShAmt = ShAmt0; - SubIsOnLHS = false; - } else { + // Match the shift amount operands for a rotate pattern. This always matches + // a subtraction on the R operand. + auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * { + // The shift amounts may add up to the narrow bit width: + // (shl ShVal, L) | (lshr ShVal, Width - L) + if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) + return L; + + // The shift amount may be masked with negation: + // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1))) + Value *X; + unsigned Mask = Width - 1; + if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) && + match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))) + return X; + + // Same as above, but the shift amount may be extended after masking: + if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) && + match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))) + return X; + return nullptr; + }; + + Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth); + bool SubIsOnLHS = false; + if (!ShAmt) { + ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth); + SubIsOnLHS = true; } + if (!ShAmt) + return nullptr; // The shifted value must have high zeros in the wide type. Typically, this // will be a zext, but it could also be the result of an 'and' or 'shift'. @@ -540,23 +563,15 @@ Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { // We have an unnecessarily wide rotate! // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) - // Narrow it down to eliminate the zext/trunc: - // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1') + // Narrow the inputs and convert to funnel shift intrinsic: + // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt)) Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); - Value *NegShAmt = Builder.CreateNeg(NarrowShAmt); - - // Mask both shift amounts to ensure there's no UB from oversized shifts. - Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1); - Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC); - Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC); - - // Truncate the original value and use narrow ops. Value *X = Builder.CreateTrunc(ShVal, DestTy); - Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt; - Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt; - Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0); - Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1); - return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1); + bool IsFshl = (!SubIsOnLHS && ShiftOpcode0 == BinaryOperator::Shl) || + (SubIsOnLHS && ShiftOpcode1 == BinaryOperator::Shl); + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy); + return IntrinsicInst::Create(F, { X, X, NarrowShAmt }); } /// Try to narrow the width of math or bitwise logic instructions by pulling a @@ -706,12 +721,35 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (SimplifyDemandedInstructionBits(CI)) return &CI; - // Canonicalize trunc x to i1 -> (icmp ne (and x, 1), 0), likewise for vector. 
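For the funnel-shift narrowing in narrowRotate above, a hypothetical input that the new masked-negation matching is meant to cover (names are illustrative):

  define i8 @narrow_rot_example(i8 %v, i32 %amt) {
    %z = zext i8 %v to i32
    %amt7 = and i32 %amt, 7
    %neg = sub i32 0, %amt
    %neg7 = and i32 %neg, 7
    %shl = shl i32 %z, %amt7
    %shr = lshr i32 %z, %neg7
    %or = or i32 %shl, %shr
    %t = trunc i32 %or to i8
    ret i8 %t
  }
  ; expected after the transform (roughly): a trunc of %amt feeding
  ;   %t = call i8 @llvm.fshl.i8(i8 %v, i8 %v, i8 %amt.trunc)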
if (DestTy->getScalarSizeInBits() == 1) { - Constant *One = ConstantInt::get(SrcTy, 1); - Src = Builder.CreateAnd(Src, One); Value *Zero = Constant::getNullValue(Src->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, Src, Zero); + if (DestTy->isIntegerTy()) { + // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only). + // TODO: We canonicalize to more instructions here because we are probably + // lacking equivalent analysis for trunc relative to icmp. There may also + // be codegen concerns. If those trunc limitations were removed, we could + // remove this transform. + Value *And = Builder.CreateAnd(Src, ConstantInt::get(SrcTy, 1)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + + // For vectors, we do not canonicalize all truncs to icmp, so optimize + // patterns that would be covered within visitICmpInst. + Value *X; + const APInt *C; + if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0 + APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C); + Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)), + m_Deferred(X))))) { + // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 + APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1; + Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } } // FIXME: Maybe combine the next two transforms to handle the no cast case @@ -1061,12 +1099,9 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { Value *Src = CI.getOperand(0); Type *SrcTy = Src->getType(), *DestTy = CI.getType(); - // Attempt to extend the entire input expression tree to the destination - // type. Only do this if the dest type is a simple type, don't convert the - // expression tree to something weird like i93 unless the source is also - // strange. + // Try to extend the entire expression tree to the wide destination type. unsigned BitsToClear; - if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && + if (shouldChangeType(SrcTy, DestTy) && canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { assert(BitsToClear <= SrcTy->getScalarSizeInBits() && "Can't clear more bits than in SrcTy"); @@ -1343,12 +1378,8 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { return replaceInstUsesWith(CI, ZExt); } - // Attempt to extend the entire input expression tree to the destination - // type. Only do this if the dest type is a simple type, don't convert the - // expression tree to something weird like i93 unless the source is also - // strange. - if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && - canEvaluateSExtd(Src, DestTy)) { + // Try to extend the entire expression tree to the wide destination type. + if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) { // Okay, we can transform this! Insert the new expression now. 
LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" @@ -1589,8 +1620,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { } // (fptrunc (fneg x)) -> (fneg (fptrunc x)) - if (BinaryOperator::isFNeg(OpI)) { - Value *InnerTrunc = Builder.CreateFPTrunc(OpI->getOperand(1), Ty); + Value *X; + if (match(OpI, m_FNeg(m_Value(X)))) { + Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); return BinaryOperator::CreateFNegFMF(InnerTrunc, OpI); } } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 6de92a4842ab..b5bbb09935e2 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -522,11 +522,9 @@ static Value *evaluateGEPOffsetExpression(User *GEP, InstCombiner &IC, } // Otherwise, there is an index. The computation we will do will be modulo - // the pointer size, so get it. - uint64_t PtrSizeMask = ~0ULL >> (64-IntPtrWidth); - - Offset &= PtrSizeMask; - VariableScale &= PtrSizeMask; + // the pointer size. + Offset = SignExtend64(Offset, IntPtrWidth); + VariableScale = SignExtend64(VariableScale, IntPtrWidth); // To do this transformation, any constant index must be a multiple of the // variable scale factor. For example, we can evaluate "12 + 4*i" as "3 + i", @@ -909,7 +907,8 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } // If all indices are the same, just compare the base pointers. - if (IndicesTheSame) + Type *BaseType = GEPLHS->getOperand(0)->getType(); + if (IndicesTheSame && CmpInst::makeCmpResultType(BaseType) == I.getType()) return new ICmpInst(Cond, GEPLHS->getOperand(0), GEPRHS->getOperand(0)); // If we're comparing GEPs with two base pointers that only differ in type @@ -976,7 +975,7 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, if (NumDifferences == 0) // SAME GEP? return replaceInstUsesWith(I, // No comparison is needed here. - Builder.getInt1(ICmpInst::isTrueWhenEqual(Cond))); + ConstantInt::get(I.getType(), ICmpInst::isTrueWhenEqual(Cond))); else if (NumDifferences == 1 && GEPsInBounds) { Value *LHSV = GEPLHS->getOperand(DiffOperand); @@ -1079,19 +1078,20 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate()))); } -/// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, +/// Fold "icmp pred (X+C), X". +Instruction *InstCombiner::foldICmpAddOpConst(Value *X, const APInt &C, ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" // operators. 
+ assert(!!C && "C should not be zero!"); // (X+1) <u X --> X >u (MAXUINT-1) --> X == 255 // (X+2) <u X --> X >u (MAXUINT-2) --> X > 253 // (X+MAXUINT) <u X --> X >u (MAXUINT-MAXUINT) --> X != 0 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { - Value *R = - ConstantExpr::getSub(ConstantInt::getAllOnesValue(CI->getType()), CI); + Constant *R = ConstantInt::get(X->getType(), + APInt::getMaxValue(C.getBitWidth()) - C); return new ICmpInst(ICmpInst::ICMP_UGT, X, R); } @@ -1099,11 +1099,10 @@ Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, // (X+2) >u X --> X <u (0-2) --> X <u 254 // (X+MAXUINT) >u X --> X <u (0-MAXUINT) --> X <u 1 --> X == 0 if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) - return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantExpr::getNeg(CI)); + return new ICmpInst(ICmpInst::ICMP_ULT, X, + ConstantInt::get(X->getType(), -C)); - unsigned BitWidth = CI->getType()->getPrimitiveSizeInBits(); - ConstantInt *SMax = ConstantInt::get(X->getContext(), - APInt::getSignedMaxValue(BitWidth)); + APInt SMax = APInt::getSignedMaxValue(C.getBitWidth()); // (X+ 1) <s X --> X >s (MAXSINT-1) --> X == 127 // (X+ 2) <s X --> X >s (MAXSINT-2) --> X >s 125 @@ -1112,7 +1111,8 @@ Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, // (X+ -2) <s X --> X >s (MAXSINT- -2) --> X >s 126 // (X+ -1) <s X --> X >s (MAXSINT- -1) --> X != 127 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantExpr::getSub(SMax, CI)); + return new ICmpInst(ICmpInst::ICMP_SGT, X, + ConstantInt::get(X->getType(), SMax - C)); // (X+ 1) >s X --> X <s (MAXSINT-(1-1)) --> X != 127 // (X+ 2) >s X --> X <s (MAXSINT-(2-1)) --> X <s 126 @@ -1122,8 +1122,8 @@ Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, // (X+ -1) >s X --> X <s (MAXSINT-(-1-1)) --> X == -128 assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE); - Constant *C = Builder.getInt(CI->getValue() - 1); - return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C)); + return new ICmpInst(ICmpInst::ICMP_SLT, X, + ConstantInt::get(X->getType(), SMax - (C - 1))); } /// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" -> @@ -1333,17 +1333,12 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { return nullptr; } -// Fold icmp Pred X, C. +/// Fold icmp Pred X, C. +/// TODO: This code structure does not make sense. The saturating add fold +/// should be moved to some other helper and extended as noted below (it is also +/// possible that code has been made unnecessary - do we canonicalize IR to +/// overflow/saturating intrinsics or not?). Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { - CmpInst::Predicate Pred = Cmp.getPredicate(); - Value *X = Cmp.getOperand(0); - - const APInt *C; - if (!match(Cmp.getOperand(1), m_APInt(C))) - return nullptr; - - Value *A = nullptr, *B = nullptr; - // Match the following pattern, which is a common idiom when writing // overflow-safe integer arithmetic functions. The source performs an addition // in wider type and explicitly checks for overflow using comparisons against @@ -1355,37 +1350,62 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { // // sum = a + b // if (sum+128 >u 255) ... 
-> llvm.sadd.with.overflow.i8 - { - ConstantInt *CI2; // I = icmp ugt (add (add A, B), CI2), CI - if (Pred == ICmpInst::ICMP_UGT && - match(X, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) - if (Instruction *Res = processUGT_ADDCST_ADD( - Cmp, A, B, CI2, cast<ConstantInt>(Cmp.getOperand(1)), *this)) - return Res; - } + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0), *Op1 = Cmp.getOperand(1); + Value *A, *B; + ConstantInt *CI, *CI2; // I = icmp ugt (add (add A, B), CI2), CI + if (Pred == ICmpInst::ICMP_UGT && match(Op1, m_ConstantInt(CI)) && + match(Op0, m_Add(m_Add(m_Value(A), m_Value(B)), m_ConstantInt(CI2)))) + if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this)) + return Res; + + return nullptr; +} - // FIXME: Use m_APInt to allow folds for splat constants. - ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); - if (!CI) +/// Canonicalize icmp instructions based on dominating conditions. +Instruction *InstCombiner::foldICmpWithDominatingICmp(ICmpInst &Cmp) { + // This is a cheap/incomplete check for dominance - just match a single + // predecessor with a conditional branch. + BasicBlock *CmpBB = Cmp.getParent(); + BasicBlock *DomBB = CmpBB->getSinglePredecessor(); + if (!DomBB) return nullptr; - // Canonicalize icmp instructions based on dominating conditions. - BasicBlock *Parent = Cmp.getParent(); - BasicBlock *Dom = Parent->getSinglePredecessor(); - auto *BI = Dom ? dyn_cast<BranchInst>(Dom->getTerminator()) : nullptr; - ICmpInst::Predicate Pred2; + Value *DomCond; BasicBlock *TrueBB, *FalseBB; - ConstantInt *CI2; - if (BI && match(BI, m_Br(m_ICmp(Pred2, m_Specific(X), m_ConstantInt(CI2)), - TrueBB, FalseBB)) && - TrueBB != FalseBB) { - ConstantRange CR = - ConstantRange::makeAllowedICmpRegion(Pred, CI->getValue()); + if (!match(DomBB->getTerminator(), m_Br(m_Value(DomCond), TrueBB, FalseBB))) + return nullptr; + + assert((TrueBB == CmpBB || FalseBB == CmpBB) && + "Predecessor block does not point to successor?"); + + // The branch should get simplified. Don't bother simplifying this condition. + if (TrueBB == FalseBB) + return nullptr; + + // Try to simplify this compare to T/F based on the dominating condition. + Optional<bool> Imp = isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); + if (Imp) + return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp)); + + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0), *Y = Cmp.getOperand(1); + ICmpInst::Predicate DomPred; + const APInt *C, *DomC; + if (match(DomCond, m_ICmp(DomPred, m_Specific(X), m_APInt(DomC))) && + match(Y, m_APInt(C))) { + // We have 2 compares of a variable with constants. Calculate the constant + // ranges of those compares to see if we can transform the 2nd compare: + // DomBB: + // DomCond = icmp DomPred X, DomC + // br DomCond, CmpBB, FalseBB + // CmpBB: + // Cmp = icmp Pred X, C + ConstantRange CR = ConstantRange::makeAllowedICmpRegion(Pred, *C); ConstantRange DominatingCR = - (Parent == TrueBB) - ? ConstantRange::makeExactICmpRegion(Pred2, CI2->getValue()) - : ConstantRange::makeExactICmpRegion( - CmpInst::getInversePredicate(Pred2), CI2->getValue()); + (CmpBB == TrueBB) ? 
ConstantRange::makeExactICmpRegion(DomPred, *DomC) + : ConstantRange::makeExactICmpRegion( + CmpInst::getInversePredicate(DomPred), *DomC); ConstantRange Intersection = DominatingCR.intersectWith(CR); ConstantRange Difference = DominatingCR.difference(CR); if (Intersection.isEmptySet()) @@ -1393,23 +1413,20 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { if (Difference.isEmptySet()) return replaceInstUsesWith(Cmp, Builder.getTrue()); - // If this is a normal comparison, it demands all bits. If it is a sign - // bit comparison, it only demands the sign bit. - bool UnusedBit; - bool IsSignBit = isSignBitCheck(Pred, CI->getValue(), UnusedBit); - // Canonicalizing a sign bit comparison that gets used in a branch, // pessimizes codegen by generating branch on zero instruction instead // of a test and branch. So we avoid canonicalizing in such situations // because test and branch instruction has better branch displacement // than compare and branch instruction. + bool UnusedBit; + bool IsSignBit = isSignBitCheck(Pred, *C, UnusedBit); if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp))) return nullptr; - if (auto *AI = Intersection.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*AI)); - if (auto *AD = Difference.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_NE, X, Builder.getInt(*AD)); + if (const APInt *EqC = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*EqC)); + if (const APInt *NeC = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, X, Builder.getInt(*NeC)); } return nullptr; @@ -1498,16 +1515,25 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, } } - // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) - // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && *XorC == ~C && (C + 1).isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); - - // (icmp ult (xor X, C), -C) -> (icmp uge X, C) - // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && *XorC == -C && C.isPowerOf2()) - return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); - + // Mask constant magic can eliminate an 'xor' with unsigned compares. + if (Pred == ICmpInst::ICMP_UGT) { + // (xor X, ~C) >u C --> X <u ~C (when C+1 is a power of 2) + if (*XorC == ~C && (C + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + // (xor X, C) >u C --> X >u C (when C+1 is a power of 2) + if (*XorC == C && (C + 1).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); + } + if (Pred == ICmpInst::ICMP_ULT) { + // (xor X, -C) <u C --> X >u ~C (when C is a power of 2) + if (*XorC == -C && C.isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGT, X, + ConstantInt::get(X->getType(), ~C)); + // (xor X, C) <u C --> X >u ~C (when -C is a power of 2) + if (*XorC == C && (-C).isPowerOf2()) + return new ICmpInst(ICmpInst::ICMP_UGT, X, + ConstantInt::get(X->getType(), ~C)); + } return nullptr; } @@ -1598,6 +1624,13 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, const APInt &C1) { + // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 + // TODO: We canonicalize to the longer form for scalars because we have + // better analysis/folds for icmp, and codegen may be better with icmp. 
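A minimal vector case for the icmp-to-trunc canonicalization handled by the check that follows (types chosen for illustration):

  define <4 x i1> @vec_trunc_example(<4 x i32> %x) {
    %a = and <4 x i32> %x, <i32 1, i32 1, i32 1, i32 1>
    %c = icmp ne <4 x i32> %a, zeroinitializer
    ret <4 x i1> %c
  }
  ; expected (roughly): %c = trunc <4 x i32> %x to <4 x i1>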
+ if (Cmp.getPredicate() == CmpInst::ICMP_NE && Cmp.getType()->isVectorTy() && + C1.isNullValue() && match(And->getOperand(1), m_One())) + return new TruncInst(And->getOperand(0), Cmp.getType()); + const APInt *C2; if (!match(And->getOperand(1), m_APInt(C2))) return nullptr; @@ -2336,13 +2369,19 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, Type *Ty = Add->getType(); CmpInst::Predicate Pred = Cmp.getPredicate(); + if (!Add->hasOneUse()) + return nullptr; + // If the add does not wrap, we can always adjust the compare by subtracting - // the constants. Equality comparisons are handled elsewhere. SGE/SLE are - // canonicalized to SGT/SLT. - if (Add->hasNoSignedWrap() && - (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { + // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE + // are canonicalized to SGT/SLT/UGT/ULT. + if ((Add->hasNoSignedWrap() && + (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) || + (Add->hasNoUnsignedWrap() && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT))) { bool Overflow; - APInt NewC = C.ssub_ov(*C2, Overflow); + APInt NewC = + Cmp.isSigned() ? C.ssub_ov(*C2, Overflow) : C.usub_ov(*C2, Overflow); // If there is overflow, the result must be true or false. // TODO: Can we assert there is no overflow because InstSimplify always // handles those cases? @@ -2366,9 +2405,6 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); } - if (!Add->hasOneUse()) - return nullptr; - // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 @@ -2729,6 +2765,7 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, // Handle icmp {eq|ne} <intrinsic>, Constant. Type *Ty = II->getType(); + unsigned BitWidth = C.getBitWidth(); switch (II->getIntrinsicID()) { case Intrinsic::bswap: Worklist.Add(II); @@ -2737,21 +2774,39 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, return &Cmp; case Intrinsic::ctlz: - case Intrinsic::cttz: + case Intrinsic::cttz: { // ctz(A) == bitwidth(A) -> A == 0 and likewise for != - if (C == C.getBitWidth()) { + if (C == BitWidth) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); Cmp.setOperand(1, ConstantInt::getNullValue(Ty)); return &Cmp; } + + // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set + // and Mask1 has bits 0..C+1 set. Similar for ctl, but for high bits. + // Limit to one use to ensure we don't increase instruction count. + unsigned Num = C.getLimitedValue(BitWidth); + if (Num != BitWidth && II->hasOneUse()) { + bool IsTrailing = II->getIntrinsicID() == Intrinsic::cttz; + APInt Mask1 = IsTrailing ? APInt::getLowBitsSet(BitWidth, Num + 1) + : APInt::getHighBitsSet(BitWidth, Num + 1); + APInt Mask2 = IsTrailing + ? 
APInt::getOneBitSet(BitWidth, Num) + : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); + Cmp.setOperand(0, Builder.CreateAnd(II->getArgOperand(0), Mask1)); + Cmp.setOperand(1, ConstantInt::get(Ty, Mask2)); + Worklist.Add(II); + return &Cmp; + } break; + } case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != bool IsZero = C.isNullValue(); - if (IsZero || C == C.getBitWidth()) { + if (IsZero || C == BitWidth) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); auto *NewOp = @@ -2870,15 +2925,25 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { /// In this case, we are looking for comparisons that look like /// a check for a lossy truncation. /// Folds: -/// x & (-1 >> y) SrcPred x to x DstPred (-1 >> y) +/// icmp SrcPred (x & Mask), x to icmp DstPred x, Mask +/// Where Mask is some pattern that produces all-ones in low bits: +/// (-1 >> y) +/// ((-1 << y) >> y) <- non-canonical, has extra uses +/// ~(-1 << y) +/// ((1 << y) + (-1)) <- non-canonical, has extra uses /// The Mask can be a constant, too. /// For some predicates, the operands are commutative. /// For others, x can only be on a specific side. static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, InstCombiner::BuilderTy &Builder) { ICmpInst::Predicate SrcPred; - Value *X, *M; - auto m_Mask = m_CombineOr(m_LShr(m_AllOnes(), m_Value()), m_LowBitMask()); + Value *X, *M, *Y; + auto m_VariableMask = m_CombineOr( + m_CombineOr(m_Not(m_Shl(m_AllOnes(), m_Value())), + m_Add(m_Shl(m_One(), m_Value()), m_AllOnes())), + m_CombineOr(m_LShr(m_AllOnes(), m_Value()), + m_LShr(m_Shl(m_AllOnes(), m_Value(Y)), m_Deferred(Y)))); + auto m_Mask = m_CombineOr(m_VariableMask, m_LowBitMask()); if (!match(&I, m_c_ICmp(SrcPred, m_c_And(m_CombineAnd(m_Mask, m_Value(M)), m_Value(X)), m_Deferred(X)))) @@ -2924,12 +2989,20 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, // x & (-1 >> y) s>= x -> x s<= (-1 >> y) if (X != I.getOperand(1)) // X must be on RHS of comparison! return nullptr; // Ignore the other case. + if (!match(M, m_Constant())) // Can not do this fold with non-constant. + return nullptr; + if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. + return nullptr; DstPred = ICmpInst::Predicate::ICMP_SLE; break; case ICmpInst::Predicate::ICMP_SLT: // x & (-1 >> y) s< x -> x s> (-1 >> y) if (X != I.getOperand(1)) // X must be on RHS of comparison! return nullptr; // Ignore the other case. + if (!match(M, m_Constant())) // Can not do this fold with non-constant. + return nullptr; + if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. + return nullptr; DstPred = ICmpInst::Predicate::ICMP_SGT; break; case ICmpInst::Predicate::ICMP_SLE: @@ -3034,6 +3107,18 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { return nullptr; const CmpInst::Predicate Pred = I.getPredicate(); + Value *X; + + // Convert add-with-unsigned-overflow comparisons into a 'not' with compare. 
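For the low-bit-mask compare fold described above (foldICmpWithLowBitMaskedVal), a hypothetical scalar case using the (-1 >> y) mask form:

  define i1 @lowmask_example(i32 %x, i32 %y) {
    %mask = lshr i32 -1, %y
    %and = and i32 %x, %mask
    %cmp = icmp eq i32 %and, %x     ; x & (-1 >> y) == x
    ret i1 %cmp
  }
  ; expected (roughly): %cmp = icmp ule i32 %x, %mask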
+ // (Op1 + X) <u Op1 --> ~Op1 <u X + // Op0 >u (Op0 + X) --> X >u ~Op0 + if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) && + Pred == ICmpInst::ICMP_ULT) + return new ICmpInst(Pred, Builder.CreateNot(Op1), X); + if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) && + Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(Pred, X, Builder.CreateNot(Op0)); + bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; if (BO0 && isa<OverflowingBinaryOperator>(BO0)) NoOp0WrapProblem = @@ -4598,6 +4683,83 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, } } +// Transform pattern like: +// (1 << Y) u<= X or ~(-1 << Y) u< X or ((1 << Y)+(-1)) u< X +// (1 << Y) u> X or ~(-1 << Y) u>= X or ((1 << Y)+(-1)) u>= X +// Into: +// (X l>> Y) != 0 +// (X l>> Y) == 0 +static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + ICmpInst::Predicate Pred, NewPred; + Value *X, *Y; + if (match(&Cmp, + m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) { + // We want X to be the icmp's second operand, so swap predicate if it isn't. + if (Cmp.getOperand(0) == X) + Pred = Cmp.getSwappedPredicate(); + + switch (Pred) { + case ICmpInst::ICMP_ULE: + NewPred = ICmpInst::ICMP_NE; + break; + case ICmpInst::ICMP_UGT: + NewPred = ICmpInst::ICMP_EQ; + break; + default: + return nullptr; + } + } else if (match(&Cmp, m_c_ICmp(Pred, + m_OneUse(m_CombineOr( + m_Not(m_Shl(m_AllOnes(), m_Value(Y))), + m_Add(m_Shl(m_One(), m_Value(Y)), + m_AllOnes()))), + m_Value(X)))) { + // The variant with 'add' is not canonical, (the variant with 'not' is) + // we only get it because it has extra uses, and can't be canonicalized, + + // We want X to be the icmp's second operand, so swap predicate if it isn't. + if (Cmp.getOperand(0) == X) + Pred = Cmp.getSwappedPredicate(); + + switch (Pred) { + case ICmpInst::ICMP_ULT: + NewPred = ICmpInst::ICMP_NE; + break; + case ICmpInst::ICMP_UGE: + NewPred = ICmpInst::ICMP_EQ; + break; + default: + return nullptr; + } + } else + return nullptr; + + Value *NewX = Builder.CreateLShr(X, Y, X->getName() + ".highbits"); + Constant *Zero = Constant::getNullValue(NewX->getType()); + return CmpInst::Create(Instruction::ICmp, NewPred, NewX, Zero); +} + +static Instruction *foldVectorCmp(CmpInst &Cmp, + InstCombiner::BuilderTy &Builder) { + // If both arguments of the cmp are shuffles that use the same mask and + // shuffle within a single vector, move the shuffle after the cmp. + Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1); + Value *V1, *V2; + Constant *M; + if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(M))) && + match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(M))) && + V1->getType() == V2->getType() && + (LHS->hasOneUse() || RHS->hasOneUse())) { + // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M + CmpInst::Predicate P = Cmp.getPredicate(); + Value *NewCmp = isa<ICmpInst>(Cmp) ? 
Builder.CreateICmp(P, V1, V2) + : Builder.CreateFCmp(P, V1, V2); + return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M); + } + return nullptr; +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -4645,6 +4807,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithConstant(I)) return Res; + if (Instruction *Res = foldICmpWithDominatingICmp(I)) + return Res; + if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -4857,16 +5022,24 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { return ExtractValueInst::Create(ACXI, 1); { - Value *X; ConstantInt *Cst; + Value *X; + const APInt *C; // icmp X+Cst, X - if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X) - return foldICmpAddOpConst(X, Cst, I.getPredicate()); + if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X) + return foldICmpAddOpConst(X, *C, I.getPredicate()); // icmp X, X+Cst - if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) - return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate()); + if (match(Op1, m_Add(m_Value(X), m_APInt(C))) && Op0 == X) + return foldICmpAddOpConst(X, *C, I.getSwappedPredicate()); } + if (Instruction *Res = foldICmpWithHighBitMask(I, Builder)) + return Res; + + if (I.getType()->isVectorTy()) + if (Instruction *Res = foldVectorCmp(I, Builder)) + return Res; + return Changed ? &I : nullptr; } @@ -5109,6 +5282,117 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt); } +/// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary. +static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI, + Constant *RHSC) { + // When C is not 0.0 and infinities are not allowed: + // (C / X) < 0.0 is a sign-bit test of X + // (C / X) < 0.0 --> X < 0.0 (if C is positive) + // (C / X) < 0.0 --> X > 0.0 (if C is negative, swap the predicate) + // + // Proof: + // Multiply (C / X) < 0.0 by X * X / C. + // - X is non zero, if it is the flag 'ninf' is violated. + // - C defines the sign of X * X * C. Thus it also defines whether to swap + // the predicate. C is also non zero by definition. + // + // Thus X * X / C is non zero and the transformation is valid. [qed] + + FCmpInst::Predicate Pred = I.getPredicate(); + + // Check that predicates are valid. + if ((Pred != FCmpInst::FCMP_OGT) && (Pred != FCmpInst::FCMP_OLT) && + (Pred != FCmpInst::FCMP_OGE) && (Pred != FCmpInst::FCMP_OLE)) + return nullptr; + + // Check that RHS operand is zero. + if (!match(RHSC, m_AnyZeroFP())) + return nullptr; + + // Check fastmath flags ('ninf'). + if (!LHSI->hasNoInfs() || !I.hasNoInfs()) + return nullptr; + + // Check the properties of the dividend. It must not be zero to avoid a + // division by zero (see Proof). + const APFloat *C; + if (!match(LHSI->getOperand(0), m_APFloat(C))) + return nullptr; + + if (C->isZero()) + return nullptr; + + // Get swapped predicate if necessary. + if (C->isNegative()) + Pred = I.getSwappedPredicate(); + + return new FCmpInst(Pred, LHSI->getOperand(1), RHSC, "", &I); +} + +/// Optimize fabs(X) compared with zero. 
+static Instruction *foldFabsWithFcmpZero(FCmpInst &I) { + Value *X; + if (!match(I.getOperand(0), m_Intrinsic<Intrinsic::fabs>(m_Value(X))) || + !match(I.getOperand(1), m_PosZeroFP())) + return nullptr; + + auto replacePredAndOp0 = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) { + I->setPredicate(P); + I->setOperand(0, X); + return I; + }; + + switch (I.getPredicate()) { + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OLT: + // fabs(X) >= 0.0 --> true + // fabs(X) < 0.0 --> false + llvm_unreachable("fcmp should have simplified"); + + case FCmpInst::FCMP_OGT: + // fabs(X) > 0.0 --> X != 0.0 + return replacePredAndOp0(&I, FCmpInst::FCMP_ONE, X); + + case FCmpInst::FCMP_UGT: + // fabs(X) u> 0.0 --> X u!= 0.0 + return replacePredAndOp0(&I, FCmpInst::FCMP_UNE, X); + + case FCmpInst::FCMP_OLE: + // fabs(X) <= 0.0 --> X == 0.0 + return replacePredAndOp0(&I, FCmpInst::FCMP_OEQ, X); + + case FCmpInst::FCMP_ULE: + // fabs(X) u<= 0.0 --> X u== 0.0 + return replacePredAndOp0(&I, FCmpInst::FCMP_UEQ, X); + + case FCmpInst::FCMP_OGE: + // fabs(X) >= 0.0 --> !isnan(X) + assert(!I.hasNoNaNs() && "fcmp should have simplified"); + return replacePredAndOp0(&I, FCmpInst::FCMP_ORD, X); + + case FCmpInst::FCMP_ULT: + // fabs(X) u< 0.0 --> isnan(X) + assert(!I.hasNoNaNs() && "fcmp should have simplified"); + return replacePredAndOp0(&I, FCmpInst::FCMP_UNO, X); + + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_ORD: + case FCmpInst::FCMP_UNO: + // Look through the fabs() because it doesn't change anything but the sign. + // fabs(X) == 0.0 --> X == 0.0, + // fabs(X) != 0.0 --> X != 0.0 + // isnan(fabs(X)) --> isnan(X) + // !isnan(fabs(X)) --> !isnan(X) + return replacePredAndOp0(&I, I.getPredicate(), X); + + default: + return nullptr; + } +} + Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { bool Changed = false; @@ -5153,11 +5437,11 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0)) { + if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) { I.setOperand(0, ConstantFP::getNullValue(Op0->getType())); return &I; } - if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1)) { + if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) { I.setOperand(1, ConstantFP::getNullValue(Op0->getType())); return &I; } @@ -5178,128 +5462,93 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { return nullptr; } - // Handle fcmp with constant RHS - if (Constant *RHSC = dyn_cast<Constant>(Op1)) { - if (Instruction *LHSI = dyn_cast<Instruction>(Op0)) - switch (LHSI->getOpcode()) { - case Instruction::FPExt: { - // fcmp (fpext x), C -> fcmp x, (fptrunc C) if fptrunc is lossless - FPExtInst *LHSExt = cast<FPExtInst>(LHSI); - ConstantFP *RHSF = dyn_cast<ConstantFP>(RHSC); - if (!RHSF) - break; - - const fltSemantics *Sem; - // FIXME: This shouldn't be here.
- if (LHSExt->getSrcTy()->isHalfTy()) - Sem = &APFloat::IEEEhalf(); - else if (LHSExt->getSrcTy()->isFloatTy()) - Sem = &APFloat::IEEEsingle(); - else if (LHSExt->getSrcTy()->isDoubleTy()) - Sem = &APFloat::IEEEdouble(); - else if (LHSExt->getSrcTy()->isFP128Ty()) - Sem = &APFloat::IEEEquad(); - else if (LHSExt->getSrcTy()->isX86_FP80Ty()) - Sem = &APFloat::x87DoubleExtended(); - else if (LHSExt->getSrcTy()->isPPC_FP128Ty()) - Sem = &APFloat::PPCDoubleDouble(); - else - break; - - bool Lossy; - APFloat F = RHSF->getValueAPF(); - F.convert(*Sem, APFloat::rmNearestTiesToEven, &Lossy); - - // Avoid lossy conversions and denormals. Zero is a special case - // that's OK to convert. - APFloat Fabs = F; - Fabs.clearSign(); - if (!Lossy && - ((Fabs.compare(APFloat::getSmallestNormalized(*Sem)) != - APFloat::cmpLessThan) || Fabs.isZero())) - - return new FCmpInst(Pred, LHSExt->getOperand(0), - ConstantFP::get(RHSC->getContext(), F)); - break; - } - case Instruction::PHI: - // Only fold fcmp into the PHI if the phi and fcmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. - if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) - return NV; - break; - case Instruction::SIToFP: - case Instruction::UIToFP: - if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC)) + // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0: + // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0 + if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) { + I.setOperand(1, ConstantFP::getNullValue(Op1->getType())); + return &I; + } + + // Handle fcmp with instruction LHS and constant RHS. + Instruction *LHSI; + Constant *RHSC; + if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) { + switch (LHSI->getOpcode()) { + case Instruction::PHI: + // Only fold fcmp into the PHI if the phi and fcmp are in the same + // block. If in the same block, we're encouraging jump threading. If + // not, we are just pessimizing the code by making an i1 phi. 
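The +0.0 canonicalization added above, on a minimal example (illustrative function; constant shown as written in IR):

  define i1 @negzero_example(float %x) {
    %c = fcmp ogt float %x, -0.000000e+00
    ret i1 %c
  }
  ; expected (roughly): %c = fcmp ogt float %x, 0.000000e+00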
+ if (LHSI->getParent() == I.getParent()) + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; - break; - case Instruction::FSub: { - // fcmp pred (fneg x), C -> fcmp swap(pred) x, -C - Value *Op; - if (match(LHSI, m_FNeg(m_Value(Op)))) - return new FCmpInst(I.getSwappedPredicate(), Op, - ConstantExpr::getFNeg(RHSC)); - break; - } - case Instruction::Load: - if (GetElementPtrInst *GEP = - dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) { - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) - if (GV->isConstant() && GV->hasDefinitiveInitializer() && - !cast<LoadInst>(LHSI)->isVolatile()) - if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) - return Res; - } - break; - case Instruction::Call: { - if (!RHSC->isNullValue()) - break; + break; + case Instruction::SIToFP: + case Instruction::UIToFP: + if (Instruction *NV = foldFCmpIntToFPConst(I, LHSI, RHSC)) + return NV; + break; + case Instruction::FDiv: + if (Instruction *NV = foldFCmpReciprocalAndZero(I, LHSI, RHSC)) + return NV; + break; + case Instruction::Load: + if (auto *GEP = dyn_cast<GetElementPtrInst>(LHSI->getOperand(0))) + if (auto *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0))) + if (GV->isConstant() && GV->hasDefinitiveInitializer() && + !cast<LoadInst>(LHSI)->isVolatile()) + if (Instruction *Res = foldCmpLoadFromIndexedGlobal(GEP, GV, I)) + return Res; + break; + } + } - CallInst *CI = cast<CallInst>(LHSI); - Intrinsic::ID IID = getIntrinsicForCallSite(CI, &TLI); - if (IID != Intrinsic::fabs) - break; + if (Instruction *R = foldFabsWithFcmpZero(I)) + return R; - // Various optimization for fabs compared with zero. - switch (Pred) { - default: - break; - // fabs(x) < 0 --> false - case FCmpInst::FCMP_OLT: - llvm_unreachable("handled by SimplifyFCmpInst"); - // fabs(x) > 0 --> x != 0 - case FCmpInst::FCMP_OGT: - return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0), RHSC); - // fabs(x) <= 0 --> x == 0 - case FCmpInst::FCMP_OLE: - return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0), RHSC); - // fabs(x) >= 0 --> !isnan(x) - case FCmpInst::FCMP_OGE: - return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0), RHSC); - // fabs(x) == 0 --> x == 0 - // fabs(x) != 0 --> x != 0 - case FCmpInst::FCMP_OEQ: - case FCmpInst::FCMP_UEQ: - case FCmpInst::FCMP_ONE: - case FCmpInst::FCMP_UNE: - return new FCmpInst(Pred, CI->getArgOperand(0), RHSC); - } - } + Value *X, *Y; + if (match(Op0, m_FNeg(m_Value(X)))) { + // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y + if (match(Op1, m_FNeg(m_Value(Y)))) + return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I); + + // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C + Constant *C; + if (match(Op1, m_Constant(C))) { + Constant *NegC = ConstantExpr::getFNeg(C); + return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I); + } + } + + if (match(Op0, m_FPExt(m_Value(X)))) { + // fcmp (fpext X), (fpext Y) -> fcmp X, Y + if (match(Op1, m_FPExt(m_Value(Y))) && X->getType() == Y->getType()) + return new FCmpInst(Pred, X, Y, "", &I); + + // fcmp (fpext X), C -> fcmp X, (fptrunc C) if fptrunc is lossless + const APFloat *C; + if (match(Op1, m_APFloat(C))) { + const fltSemantics &FPSem = + X->getType()->getScalarType()->getFltSemantics(); + bool Lossy; + APFloat TruncC = *C; + TruncC.convert(FPSem, APFloat::rmNearestTiesToEven, &Lossy); + + // Avoid lossy conversions and denormals. + // Zero is a special case that's OK to convert. 
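A lossless-constant case for the fpext narrowing above; 1.0 converts exactly to float, so the compare can be done in the narrow type (hypothetical function):

  define i1 @fpext_cmp_example(float %x) {
    %e = fpext float %x to double
    %c = fcmp olt double %e, 1.000000e+00
    ret i1 %c
  }
  ; expected (roughly): %c = fcmp olt float %x, 1.000000e+00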
+ APFloat Fabs = TruncC; + Fabs.clearSign(); + if (!Lossy && + ((Fabs.compare(APFloat::getSmallestNormalized(FPSem)) != + APFloat::cmpLessThan) || Fabs.isZero())) { + Constant *NewC = ConstantFP::get(X->getType(), TruncC); + return new FCmpInst(Pred, X, NewC, "", &I); } + } } - // fcmp pred (fneg x), (fneg y) -> fcmp swap(pred) x, y - Value *X, *Y; - if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) - return new FCmpInst(I.getSwappedPredicate(), X, Y); - - // fcmp (fpext x), (fpext y) -> fcmp x, y - if (FPExtInst *LHSExt = dyn_cast<FPExtInst>(Op0)) - if (FPExtInst *RHSExt = dyn_cast<FPExtInst>(Op1)) - if (LHSExt->getSrcTy() == RHSExt->getSrcTy()) - return new FCmpInst(Pred, LHSExt->getOperand(0), RHSExt->getOperand(0)); + if (I.getType()->isVectorTy()) + if (Instruction *Res = foldVectorCmp(I, Builder)) + return Res; return Changed ? &I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 58ef3d41415c..2de41bd5bef5 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -20,7 +20,6 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -33,6 +32,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" @@ -41,11 +41,14 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #define DEBUG_TYPE "instcombine" +using namespace llvm::PatternMatch; + namespace llvm { class APInt; @@ -79,8 +82,8 @@ class User; /// 5 -> Other instructions static inline unsigned getComplexity(Value *V) { if (isa<Instruction>(V)) { - if (isa<CastInst>(V) || BinaryOperator::isNeg(V) || - BinaryOperator::isFNeg(V) || BinaryOperator::isNot(V)) + if (isa<CastInst>(V) || match(V, m_Neg(m_Value())) || + match(V, m_Not(m_Value())) || match(V, m_FNeg(m_Value()))) return 4; return 5; } @@ -138,7 +141,7 @@ static inline Constant *SubOne(Constant *C) { /// uses of V and only keep uses of ~V. static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { // ~(~(X)) -> X. - if (BinaryOperator::isNot(V)) + if (match(V, m_Not(m_Value()))) return true; // Constants can be considered to be not'ed values. @@ -175,6 +178,10 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { if (isa<Constant>(BO->getOperand(0)) || isa<Constant>(BO->getOperand(1))) return WillInvertAllUses; + // Selects with invertible operands are freely invertible + if (match(V, m_Select(m_Value(), m_Not(m_Value()), m_Not(m_Value())))) + return WillInvertAllUses; + return false; } @@ -496,6 +503,12 @@ private: OverflowResult::NeverOverflows; } + bool willNotOverflowAdd(const Value *LHS, const Value *RHS, + const Instruction &CxtI, bool IsSigned) const { + return IsSigned ? 
willNotOverflowSignedAdd(LHS, RHS, CxtI) + : willNotOverflowUnsignedAdd(LHS, RHS, CxtI); + } + bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForSignedSub(LHS, RHS, &CxtI) == @@ -508,6 +521,12 @@ private: OverflowResult::NeverOverflows; } + bool willNotOverflowSub(const Value *LHS, const Value *RHS, + const Instruction &CxtI, bool IsSigned) const { + return IsSigned ? willNotOverflowSignedSub(LHS, RHS, CxtI) + : willNotOverflowUnsignedSub(LHS, RHS, CxtI); + } + bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForSignedMul(LHS, RHS, &CxtI) == @@ -520,12 +539,29 @@ private: OverflowResult::NeverOverflows; } + bool willNotOverflowMul(const Value *LHS, const Value *RHS, + const Instruction &CxtI, bool IsSigned) const { + return IsSigned ? willNotOverflowSignedMul(LHS, RHS, CxtI) + : willNotOverflowUnsignedMul(LHS, RHS, CxtI); + } + + bool willNotOverflow(BinaryOperator::BinaryOps Opcode, const Value *LHS, + const Value *RHS, const Instruction &CxtI, + bool IsSigned) const { + switch (Opcode) { + case Instruction::Add: return willNotOverflowAdd(LHS, RHS, CxtI, IsSigned); + case Instruction::Sub: return willNotOverflowSub(LHS, RHS, CxtI, IsSigned); + case Instruction::Mul: return willNotOverflowMul(LHS, RHS, CxtI, IsSigned); + default: llvm_unreachable("Unexpected opcode for overflow query"); + } + } + Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); - Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); Instruction *narrowBinOp(TruncInst &Trunc); Instruction *narrowMaskedBinOp(BinaryOperator &And); + Instruction *narrowMathIfNoOverflow(BinaryOperator &I); Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); @@ -553,6 +589,9 @@ private: Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, bool JoinedByAnd, Instruction &CxtI); + Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); + Value *getSelectCondition(Value *A, Value *B); + public: /// Inserts an instruction \p New before instruction \p Old /// @@ -763,13 +802,14 @@ private: Value *simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, APInt DemandedElts, - int DmaskIdx = -1); + int DmaskIdx = -1, + int TFCIdx = -1); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0); /// Canonicalize the position of binops relative to shufflevector. 
- Instruction *foldShuffledBinop(BinaryOperator &Inst); + Instruction *foldVectorBinop(BinaryOperator &Inst); /// Given a binary operator, cast instruction, or select which has a PHI node /// as operand #0, see if we can fold the instruction into the PHI (which is @@ -813,11 +853,12 @@ private: ConstantInt *AndCst = nullptr); Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC); - Instruction *foldICmpAddOpConst(Value *X, ConstantInt *CI, + Instruction *foldICmpAddOpConst(Value *X, const APInt &C, ICmpInst::Predicate Pred); Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); + Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); @@ -880,8 +921,11 @@ private: Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, bool isSigned, bool Inside); Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI); - Instruction *MatchBSwap(BinaryOperator &I); - bool SimplifyStoreAtEndOfBlock(StoreInst &SI); + bool mergeStoreIntoSuccessor(StoreInst &SI); + + /// Given an 'or' instruction, check to see if it is part of a bswap idiom. + /// If so, return the equivalent bswap intrinsic. + Instruction *matchBSwap(BinaryOperator &Or); Instruction *SimplifyAnyMemTransfer(AnyMemTransferInst *MI); Instruction *SimplifyAnyMemSet(AnyMemSetInst *MI); diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 62769f077b47..76ab614090fa 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -19,6 +19,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" @@ -115,13 +116,10 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, } // Lifetime intrinsics can be handled by the caller. - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) { - assert(II->use_empty() && "Lifetime markers have no result to use!"); - ToDelete.push_back(II); - continue; - } + if (I->isLifetimeStartOrEnd()) { + assert(I->use_empty() && "Lifetime markers have no result to use!"); + ToDelete.push_back(I); + continue; } // If this is isn't our memcpy/memmove, reject it as something we can't @@ -197,30 +195,32 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { // Convert: alloca Ty, C - where C is a constant != 1 into: alloca [C x Ty], 1 if (const ConstantInt *C = dyn_cast<ConstantInt>(AI.getArraySize())) { - Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); - AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName()); - New->setAlignment(AI.getAlignment()); - - // Scan to the end of the allocation instructions, to skip over a block of - // allocas if possible...also skip interleaved debug info - // - BasicBlock::iterator It(New); - while (isa<AllocaInst>(*It) || isa<DbgInfoIntrinsic>(*It)) - ++It; - - // Now that I is pointing to the first non-allocation-inst in the block, - // insert our getelementptr instruction... 
- // - Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType()); - Value *NullIdx = Constant::getNullValue(IdxTy); - Value *Idx[2] = {NullIdx, NullIdx}; - Instruction *GEP = - GetElementPtrInst::CreateInBounds(New, Idx, New->getName() + ".sub"); - IC.InsertNewInstBefore(GEP, *It); - - // Now make everything use the getelementptr instead of the original - // allocation. - return IC.replaceInstUsesWith(AI, GEP); + if (C->getValue().getActiveBits() <= 64) { + Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); + AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName()); + New->setAlignment(AI.getAlignment()); + + // Scan to the end of the allocation instructions, to skip over a block of + // allocas if possible...also skip interleaved debug info + // + BasicBlock::iterator It(New); + while (isa<AllocaInst>(*It) || isa<DbgInfoIntrinsic>(*It)) + ++It; + + // Now that I is pointing to the first non-allocation-inst in the block, + // insert our getelementptr instruction... + // + Type *IdxTy = IC.getDataLayout().getIntPtrType(AI.getType()); + Value *NullIdx = Constant::getNullValue(IdxTy); + Value *Idx[2] = {NullIdx, NullIdx}; + Instruction *GEP = + GetElementPtrInst::CreateInBounds(New, Idx, New->getName() + ".sub"); + IC.InsertNewInstBefore(GEP, *It); + + // Now make everything use the getelementptr instead of the original + // allocation. + return IC.replaceInstUsesWith(AI, GEP); + } } if (isa<UndefValue>(AI.getArraySize())) @@ -490,6 +490,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT case LLVMContext::MD_noalias: case LLVMContext::MD_nontemporal: case LLVMContext::MD_mem_parallel_loop_access: + case LLVMContext::MD_access_group: // All of these directly apply. NewLoad->setMetadata(ID, N); break; @@ -549,10 +550,10 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value case LLVMContext::MD_noalias: case LLVMContext::MD_nontemporal: case LLVMContext::MD_mem_parallel_loop_access: + case LLVMContext::MD_access_group: // All of these directly apply. NewStore->setMetadata(ID, N); break; - case LLVMContext::MD_invariant_load: case LLVMContext::MD_nonnull: case LLVMContext::MD_range: @@ -1024,7 +1025,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { if (Value *AvailableVal = FindAvailableLoadedValue( &LI, LI.getParent(), BBI, DefMaxInstsToScan, AA, &IsLoadCSE)) { if (IsLoadCSE) - combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI); + combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI, false); return replaceInstUsesWith( LI, Builder.CreateBitOrPointerCast(AvailableVal, LI.getType(), @@ -1496,64 +1497,45 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (isa<UndefValue>(Val)) return eraseInstFromFunction(SI); - // If this store is the last instruction in the basic block (possibly - // excepting debug info instructions), and if the block ends with an - // unconditional branch, try to move it to the successor block. + // If this store is the second-to-last instruction in the basic block + // (excluding debug info and bitcasts of pointers) and if the block ends with + // an unconditional branch, try to move the store to the successor block. BBI = SI.getIterator(); do { ++BBI; } while (isa<DbgInfoIntrinsic>(BBI) || (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())); + if (BranchInst *BI = dyn_cast<BranchInst>(BBI)) if (BI->isUnconditional()) - if (SimplifyStoreAtEndOfBlock(SI)) - return nullptr; // xform done! 
+ mergeStoreIntoSuccessor(SI); return nullptr; } -/// SimplifyStoreAtEndOfBlock - Turn things like: +/// Try to transform: /// if () { *P = v1; } else { *P = v2 } -/// into a phi node with a store in the successor. -/// -/// Simplify things like: +/// or: /// *P = v1; if () { *P = v2; } /// into a phi node with a store in the successor. -/// -bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { +bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) { assert(SI.isUnordered() && - "this code has not been auditted for volatile or ordered store case"); + "This code has not been audited for volatile or ordered store case."); + // Check if the successor block has exactly 2 incoming edges. BasicBlock *StoreBB = SI.getParent(); - - // Check to see if the successor block has exactly two incoming edges. If - // so, see if the other predecessor contains a store to the same location. - // if so, insert a PHI node (if needed) and move the stores down. BasicBlock *DestBB = StoreBB->getTerminator()->getSuccessor(0); - - // Determine whether Dest has exactly two predecessors and, if so, compute - // the other predecessor. - pred_iterator PI = pred_begin(DestBB); - BasicBlock *P = *PI; - BasicBlock *OtherBB = nullptr; - - if (P != StoreBB) - OtherBB = P; - - if (++PI == pred_end(DestBB)) + if (!DestBB->hasNPredecessors(2)) return false; - P = *PI; - if (P != StoreBB) { - if (OtherBB) - return false; - OtherBB = P; - } - if (++PI != pred_end(DestBB)) - return false; + // Capture the other block (the block that doesn't contain our store). + pred_iterator PredIter = pred_begin(DestBB); + if (*PredIter == StoreBB) + ++PredIter; + BasicBlock *OtherBB = *PredIter; - // Bail out if all the relevant blocks aren't distinct (this can happen, - // for example, if SI is in an infinite loop) + // Bail out if all of the relevant blocks aren't distinct. This can happen, + // for example, if SI is in an infinite loop. if (StoreBB == DestBB || OtherBB == DestBB) return false; @@ -1564,7 +1546,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { return false; // If the other block ends in an unconditional branch, check for the 'if then - // else' case. there is an instruction before the branch. + // else' case. There is an instruction before the branch. StoreInst *OtherStore = nullptr; if (OtherBr->isUnconditional()) { --BBI; @@ -1589,7 +1571,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { return false; // Okay, we know that OtherBr now goes to Dest and StoreBB, so this is an - // if/then triangle. See if there is a store to the same ptr as SI that + // if/then triangle. See if there is a store to the same ptr as SI that // lives in OtherBB. for (;; --BBI) { // Check to see if we find the matching store. @@ -1600,15 +1582,14 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { break; } // If we find something that may be using or overwriting the stored - // value, or if we run out of instructions, we can't do the xform. + // value, or if we run out of instructions, we can't do the transform. if (BBI->mayReadFromMemory() || BBI->mayThrow() || BBI->mayWriteToMemory() || BBI == OtherBB->begin()) return false; } - // In order to eliminate the store in OtherBr, we have to - // make sure nothing reads or overwrites the stored value in - // StoreBB. + // In order to eliminate the store in OtherBr, we have to make sure nothing + // reads or overwrites the stored value in StoreBB. 
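For the if/then/else store merging implemented by mergeStoreIntoSuccessor above, a minimal diamond (typed pointers, hypothetical names):

  define void @merge_store_example(i1 %c, i32* %p, i32 %v1, i32 %v2) {
  entry:
    br i1 %c, label %then, label %else
  then:
    store i32 %v1, i32* %p
    br label %end
  else:
    store i32 %v2, i32* %p
    br label %end
  end:
    ret void
  }
  ; expected (roughly): a single phi of %v1/%v2 in %end feeding one store to %p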
for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) { // FIXME: This should really be AA driven. if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory()) @@ -1618,24 +1599,24 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { // Insert a PHI node now if we need it. Value *MergedVal = OtherStore->getOperand(0); + // The debug locations of the original instructions might differ. Merge them. + DebugLoc MergedLoc = DILocation::getMergedLocation(SI.getDebugLoc(), + OtherStore->getDebugLoc()); if (MergedVal != SI.getOperand(0)) { PHINode *PN = PHINode::Create(MergedVal->getType(), 2, "storemerge"); PN->addIncoming(SI.getOperand(0), SI.getParent()); PN->addIncoming(OtherStore->getOperand(0), OtherBB); MergedVal = InsertNewInstBefore(PN, DestBB->front()); + PN->setDebugLoc(MergedLoc); } - // Advance to a place where it is safe to insert the new store and - // insert it. + // Advance to a place where it is safe to insert the new store and insert it. BBI = DestBB->getFirstInsertionPt(); StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), - SI.isVolatile(), - SI.getAlignment(), - SI.getOrdering(), - SI.getSyncScopeID()); + SI.isVolatile(), SI.getAlignment(), + SI.getOrdering(), SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); - // The debug locations of the original instructions might differ; merge them. - NewSI->applyMergedLocation(SI.getDebugLoc(), OtherStore->getDebugLoc()); + NewSI->setDebugLoc(MergedLoc); // If the two stores had AA tags, merge them. AAMDNodes AATags; diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 63761d427235..7e99f3e4e500 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -133,7 +133,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Value *V = SimplifyUsingDistributiveLaws(I)) @@ -171,14 +171,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (match(&I, m_Mul(m_Value(NewOp), m_Constant(C1)))) { // Replace X*(2^C) with X << C, where C is either a scalar or a vector. if (Constant *NewCst = getLogBase2(NewOp->getType(), C1)) { - unsigned Width = NewCst->getType()->getPrimitiveSizeInBits(); BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap(); if (I.hasNoSignedWrap()) { const APInt *V; - if (match(NewCst, m_APInt(V)) && *V != Width - 1) + if (match(NewCst, m_APInt(V)) && *V != V->getBitWidth() - 1) Shl->setHasNoSignedWrap(); } @@ -245,6 +244,11 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return NewMul; } + // -X * Y --> -(X * Y) + // X * -Y --> -(X * Y) + if (match(&I, m_c_Mul(m_OneUse(m_Neg(m_Value(X))), m_Value(Y)))) + return BinaryOperator::CreateNeg(Builder.CreateMul(X, Y)); + // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X { @@ -323,77 +327,8 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (match(Op1, m_LShr(m_Value(X), m_APInt(C))) && *C == C->getBitWidth() - 1) return BinaryOperator::CreateAnd(Builder.CreateAShr(X, *C), Op0); - // Check for (mul (sext x), y), see if we can merge this into an - // integer mul followed by a sext. 
- if (SExtInst *Op0Conv = dyn_cast<SExtInst>(Op0)) { - // (mul (sext x), cst) --> (sext (mul x, cst')) - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { - if (Op0Conv->hasOneUse()) { - Constant *CI = - ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); - if (ConstantExpr::getSExt(CI, I.getType()) == Op1C && - willNotOverflowSignedMul(Op0Conv->getOperand(0), CI, I)) { - // Insert the new, smaller mul. - Value *NewMul = - Builder.CreateNSWMul(Op0Conv->getOperand(0), CI, "mulconv"); - return new SExtInst(NewMul, I.getType()); - } - } - } - - // (mul (sext x), (sext y)) --> (sext (mul int x, y)) - if (SExtInst *Op1Conv = dyn_cast<SExtInst>(Op1)) { - // Only do this if x/y have the same type, if at last one of them has a - // single use (so we don't increase the number of sexts), and if the - // integer mul will not overflow. - if (Op0Conv->getOperand(0)->getType() == - Op1Conv->getOperand(0)->getType() && - (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && - willNotOverflowSignedMul(Op0Conv->getOperand(0), - Op1Conv->getOperand(0), I)) { - // Insert the new integer mul. - Value *NewMul = Builder.CreateNSWMul( - Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); - return new SExtInst(NewMul, I.getType()); - } - } - } - - // Check for (mul (zext x), y), see if we can merge this into an - // integer mul followed by a zext. - if (auto *Op0Conv = dyn_cast<ZExtInst>(Op0)) { - // (mul (zext x), cst) --> (zext (mul x, cst')) - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { - if (Op0Conv->hasOneUse()) { - Constant *CI = - ConstantExpr::getTrunc(Op1C, Op0Conv->getOperand(0)->getType()); - if (ConstantExpr::getZExt(CI, I.getType()) == Op1C && - willNotOverflowUnsignedMul(Op0Conv->getOperand(0), CI, I)) { - // Insert the new, smaller mul. - Value *NewMul = - Builder.CreateNUWMul(Op0Conv->getOperand(0), CI, "mulconv"); - return new ZExtInst(NewMul, I.getType()); - } - } - } - - // (mul (zext x), (zext y)) --> (zext (mul int x, y)) - if (auto *Op1Conv = dyn_cast<ZExtInst>(Op1)) { - // Only do this if x/y have the same type, if at last one of them has a - // single use (so we don't increase the number of zexts), and if the - // integer mul will not overflow. - if (Op0Conv->getOperand(0)->getType() == - Op1Conv->getOperand(0)->getType() && - (Op0Conv->hasOneUse() || Op1Conv->hasOneUse()) && - willNotOverflowUnsignedMul(Op0Conv->getOperand(0), - Op1Conv->getOperand(0), I)) { - // Insert the new integer mul. 
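The sext/zext multiply narrowing removed here is consolidated into narrowMathIfNoOverflow below. A minimal scalar illustration (plain C++, not IR; variable names are illustrative) of the underlying identity: when the narrow multiply cannot overflow, multiplying the extended value equals extending the narrow product.

#include <cstdint>
#include <cstdio>

int main() {
  int32_t x = 1000;                          // small enough that x * 7 fits in i32
  int64_t wide   = int64_t(x) * 7;           // mul of the sign-extended value
  int64_t narrow = int64_t(int32_t(x) * 7);  // narrow mul, then sign-extend
  std::printf("%lld %lld\n", (long long)wide, (long long)narrow); // 7000 7000
}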
- Value *NewMul = Builder.CreateNUWMul( - Op0Conv->getOperand(0), Op1Conv->getOperand(0), "mulconv"); - return new ZExtInst(NewMul, I.getType()); - } - } - } + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; bool Changed = false; if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) { @@ -418,7 +353,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (SimplifyAssociativeOrCommutative(I)) return &I; - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) @@ -503,7 +438,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) && match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) { Value *XY = Builder.CreateFMulFMF(X, Y, &I); - Value *Sqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, { XY }, &I); + Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I); return replaceInstUsesWith(I, Sqrt); } @@ -933,7 +868,7 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // Handle the integer div common cases @@ -1027,7 +962,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // Handle the integer div common cases @@ -1175,7 +1110,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *R = foldFDivConstantDivisor(I)) @@ -1227,7 +1162,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { IRBuilder<>::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(I.getFastMathFlags()); AttributeList Attrs = CallSite(Op0).getCalledFunction()->getAttributes(); - Value *Res = emitUnaryFloatFnCall(X, TLI.getName(LibFunc_tan), B, Attrs); + Value *Res = emitUnaryFloatFnCall(X, &TLI, LibFunc_tan, LibFunc_tanf, + LibFunc_tanl, B, Attrs); if (IsCot) Res = B.CreateFDiv(ConstantFP::get(I.getType(), 1.0), Res); return replaceInstUsesWith(I, Res); @@ -1304,7 +1240,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *common = commonIRemTransforms(I)) @@ -1351,7 +1287,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; // Handle the integer rem common cases @@ -1425,7 +1361,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index e54a1dd05a24..7603cf4d7958 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -211,20 +211,20 @@ Instruction 
*InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { } // If it requires a conversion for every PHI operand, do not do it. - if (std::all_of(AvailablePtrVals.begin(), AvailablePtrVals.end(), - [&](Value *V) { - return (V->getType() != IntToPtr->getType()) || - isa<IntToPtrInst>(V); - })) + if (all_of(AvailablePtrVals, [&](Value *V) { + return (V->getType() != IntToPtr->getType()) || isa<IntToPtrInst>(V); + })) return nullptr; // If any of the operand that requires casting is a terminator // instruction, do not do it. - if (std::any_of(AvailablePtrVals.begin(), AvailablePtrVals.end(), - [&](Value *V) { - return (V->getType() != IntToPtr->getType()) && - isa<TerminatorInst>(V); - })) + if (any_of(AvailablePtrVals, [&](Value *V) { + if (V->getType() == IntToPtr->getType()) + return false; + + auto *Inst = dyn_cast<Instruction>(V); + return Inst && Inst->isTerminator(); + })) return nullptr; PHINode *NewPtrPHI = PHINode::Create( @@ -608,6 +608,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { LLVMContext::MD_align, LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_access_group, }; for (unsigned ID : KnownIDs) @@ -616,7 +617,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // Add all operands to the new PHI and combine TBAA metadata. for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) { LoadInst *LI = cast<LoadInst>(PN.getIncomingValue(i)); - combineMetadata(NewLI, LI, KnownIDs); + combineMetadata(NewLI, LI, KnownIDs, true); Value *NewInVal = LI->getOperand(0); if (NewInVal != InVal) InVal = nullptr; @@ -649,7 +650,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { // We cannot create a new instruction after the PHI if the terminator is an // EHPad because there is no valid insertion point. - if (TerminatorInst *TI = Phi.getParent()->getTerminator()) + if (Instruction *TI = Phi.getParent()->getTerminator()) if (TI->isEHPad()) return nullptr; @@ -723,7 +724,7 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // We cannot create a new instruction after the PHI if the terminator is an // EHPad because there is no valid insertion point. - if (TerminatorInst *TI = PN.getParent()->getTerminator()) + if (Instruction *TI = PN.getParent()->getTerminator()) if (TI->isEHPad()) return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 796b4021d273..faf58a08976d 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -54,34 +54,62 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder, return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } -/// Fold -/// %A = icmp eq/ne i8 %x, 0 -/// %B = op i8 %x, %z -/// %C = select i1 %A, i8 %B, i8 %y -/// To -/// %C = select i1 %A, i8 %z, i8 %y -/// OP: binop with an identity constant -/// TODO: support for non-commutative and FP opcodes -static Instruction *foldSelectBinOpIdentity(SelectInst &Sel) { - - Value *Cond = Sel.getCondition(); - Value *X, *Z; +/// Replace a select operand based on an equality comparison with the identity +/// constant of a binop. +static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, + const TargetLibraryInfo &TLI) { + // The select condition must be an equality compare with a constant operand. 
+ Value *X; Constant *C; CmpInst::Predicate Pred; - if (!match(Cond, m_ICmp(Pred, m_Value(X), m_Constant(C))) || - !ICmpInst::isEquality(Pred)) + if (!match(Sel.getCondition(), m_Cmp(Pred, m_Value(X), m_Constant(C)))) return nullptr; - bool IsEq = Pred == ICmpInst::ICMP_EQ; - auto *BO = - dyn_cast<BinaryOperator>(IsEq ? Sel.getTrueValue() : Sel.getFalseValue()); - // TODO: support for undefs - if (BO && match(BO, m_c_BinOp(m_Specific(X), m_Value(Z))) && - ConstantExpr::getBinOpIdentity(BO->getOpcode(), X->getType()) == C) { - Sel.setOperand(IsEq ? 1 : 2, Z); - return &Sel; + bool IsEq; + if (ICmpInst::isEquality(Pred)) + IsEq = Pred == ICmpInst::ICMP_EQ; + else if (Pred == FCmpInst::FCMP_OEQ) + IsEq = true; + else if (Pred == FCmpInst::FCMP_UNE) + IsEq = false; + else + return nullptr; + + // A select operand must be a binop. + BinaryOperator *BO; + if (!match(Sel.getOperand(IsEq ? 1 : 2), m_BinOp(BO))) + return nullptr; + + // The compare constant must be the identity constant for that binop. + // If this a floating-point compare with 0.0, any zero constant will do. + Type *Ty = BO->getType(); + Constant *IdC = ConstantExpr::getBinOpIdentity(BO->getOpcode(), Ty, true); + if (IdC != C) { + if (!IdC || !CmpInst::isFPPredicate(Pred)) + return nullptr; + if (!match(IdC, m_AnyZeroFP()) || !match(C, m_AnyZeroFP())) + return nullptr; } - return nullptr; + + // Last, match the compare variable operand with a binop operand. + Value *Y; + if (!BO->isCommutative() && !match(BO, m_BinOp(m_Value(Y), m_Specific(X)))) + return nullptr; + if (!match(BO, m_c_BinOp(m_Value(Y), m_Specific(X)))) + return nullptr; + + // +0.0 compares equal to -0.0, and so it does not behave as required for this + // transform. Bail out if we can not exclude that possibility. + if (isa<FPMathOperator>(BO)) + if (!BO->hasNoSignedZeros() && !CannotBeNegativeZero(Y, &TLI)) + return nullptr; + + // BO = binop Y, X + // S = { select (cmp eq X, C), BO, ? } or { select (cmp ne X, C), ?, BO } + // => + // S = { select (cmp eq X, C), Y, ? } or { select (cmp ne X, C), ?, Y } + Sel.setOperand(IsEq ? 1 : 2, Y); + return &Sel; } /// This folds: @@ -343,13 +371,24 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, return nullptr; } + // If the select condition is a vector, the operands of the original select's + // operands also must be vectors. This may not be the case for getelementptr + // for example. + if (SI.getCondition()->getType()->isVectorTy() && + (!OtherOpT->getType()->isVectorTy() || + !OtherOpF->getType()->isVectorTy())) + return nullptr; + // If we reach here, they do have operations in common. Value *NewSI = Builder.CreateSelect(SI.getCondition(), OtherOpT, OtherOpF, SI.getName() + ".v", &SI); Value *Op0 = MatchIsOpZero ? MatchOp : NewSI; Value *Op1 = MatchIsOpZero ? NewSI : MatchOp; if (auto *BO = dyn_cast<BinaryOperator>(TI)) { - return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + BinaryOperator *NewBO = BinaryOperator::Create(BO->getOpcode(), Op0, Op1); + NewBO->copyIRFlags(TI); + NewBO->andIRFlags(FI); + return NewBO; } if (auto *TGEP = dyn_cast<GetElementPtrInst>(TI)) { auto *FGEP = cast<GetElementPtrInst>(FI); @@ -670,17 +709,18 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, match(Count, m_Trunc(m_Value(V)))) Count = V; + // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the + // input to the cttz/ctlz is used as LHS for the compare instruction. 
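A scalar analogy (plain C++ sketch, not part of the patch) of the select/binop-identity fold just shown: when the select condition proves X equals the identity constant of the binop in one arm, that arm collapses to the binop's other operand.

#include <cstdio>

int beforeFold(int x, int y, int z) {
  return x == 0 ? x + y : z;   // 'x + y' with x known to be 0, the identity of add
}

int afterFold(int x, int y, int z) {
  return x == 0 ? y : z;       // the add is gone
}

int main() {
  std::printf("%d %d\n", beforeFold(0, 5, 9), afterFold(0, 5, 9)); // 5 5
  std::printf("%d %d\n", beforeFold(7, 5, 9), afterFold(7, 5, 9)); // 9 9
}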
+ if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) && + !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) + return nullptr; + + IntrinsicInst *II = cast<IntrinsicInst>(Count); + // Check if the value propagated on zero is a constant number equal to the // sizeof in bits of 'Count'. unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); - if (!match(ValueOnZero, m_SpecificInt(SizeOfInBits))) - return nullptr; - - // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the - // input to the cttz/ctlz is used as LHS for the compare instruction. - if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) || - match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) { - IntrinsicInst *II = cast<IntrinsicInst>(Count); + if (match(ValueOnZero, m_SpecificInt(SizeOfInBits))) { // Explicitly clear the 'undef_on_zero' flag. IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext())); @@ -688,6 +728,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } + // If the ValueOnZero is not the bitwidth, we can at least make use of the + // fact that the cttz/ctlz result will not be used if the input is zero, so + // it's okay to relax it to undef for that case. + if (II->hasOneUse() && !match(II->getArgOperand(1), m_One())) + II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); + return nullptr; } @@ -1054,11 +1100,13 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (C == A || C == B) { // MAX(MAX(A, B), B) -> MAX(A, B) // MIN(MIN(a, b), a) -> MIN(a, b) + // TODO: This could be done in instsimplify. if (SPF1 == SPF2 && SelectPatternResult::isMinOrMax(SPF1)) return replaceInstUsesWith(Outer, Inner); // MAX(MIN(a, b), a) -> a // MIN(MAX(a, b), a) -> a + // TODO: This could be done in instsimplify. if ((SPF1 == SPF_SMIN && SPF2 == SPF_SMAX) || (SPF1 == SPF_SMAX && SPF2 == SPF_SMIN) || (SPF1 == SPF_UMIN && SPF2 == SPF_UMAX) || @@ -1071,6 +1119,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, if (match(B, m_APInt(CB)) && match(C, m_APInt(CC))) { // MIN(MIN(A, 23), 97) -> MIN(A, 23) // MAX(MAX(A, 97), 23) -> MAX(A, 97) + // TODO: This could be done in instsimplify. if ((SPF1 == SPF_UMIN && CB->ule(*CC)) || (SPF1 == SPF_SMIN && CB->sle(*CC)) || (SPF1 == SPF_UMAX && CB->uge(*CC)) || @@ -1091,6 +1140,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) + // TODO: This could be done in instsimplify. if (SPF1 == SPF2 && (SPF1 == SPF_ABS || SPF1 == SPF_NABS)) { return replaceInstUsesWith(Outer, Inner); } @@ -1503,6 +1553,60 @@ static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, return SelectInst::Create(CmpABC, MinMaxOp, ThirdOp); } +/// Try to reduce a rotate pattern that includes a compare and select into a +/// funnel shift intrinsic. Example: +/// rotl32(a, b) --> (b == 0 ? a : ((a >> (32 - b)) | (a << b))) +/// --> call llvm.fshl.i32(a, a, b) +static Instruction *foldSelectRotate(SelectInst &Sel) { + // The false value of the select must be a rotate of the true value. 
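In C++ terms, the cttz/ctlz select fold above removes a source-level zero guard. A sketch assuming C++20 <bit>: "x == 0 ? 32 : countr_zero(x)" has exactly the semantics of a cttz whose is-zero-undef flag is false, so the compare and select can be dropped.

#include <bit>
#include <cstdint>
#include <cstdio>

int guardedCttz(uint32_t x) {
  // Source-level guard against the zero input.
  return x == 0 ? 32 : std::countr_zero(x);
}

int foldedCttz(uint32_t x) {
  // std::countr_zero is defined to return 32 for a zero 32-bit input,
  // matching cttz with is-zero-undef == false.
  return std::countr_zero(x);
}

int main() {
  std::printf("%d %d\n", guardedCttz(0), foldedCttz(0));   // 32 32
  std::printf("%d %d\n", guardedCttz(8), foldedCttz(8));   // 3 3
}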
+ Value *Or0, *Or1; + if (!match(Sel.getFalseValue(), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + return nullptr; + + Value *TVal = Sel.getTrueValue(); + Value *SA0, *SA1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(TVal), m_Value(SA1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // We have one of these patterns so far: + // select ?, TVal, (or (lshr TVal, SA0), (shl TVal, SA1)) + // select ?, TVal, (or (shl TVal, SA0), (lshr TVal, SA1)) + // This must be a power-of-2 rotate for a bitmasking transform to be valid. + unsigned Width = Sel.getType()->getScalarSizeInBits(); + if (!isPowerOf2_32(Width)) + return nullptr; + + // Check the shift amounts to see if they are an opposite pair. + Value *ShAmt; + if (match(SA1, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA0))))) + ShAmt = SA0; + else if (match(SA0, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(SA1))))) + ShAmt = SA1; + else + return nullptr; + + // Finally, see if the select is filtering out a shift-by-zero. + Value *Cond = Sel.getCondition(); + ICmpInst::Predicate Pred; + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()))) || + Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // This is a rotate that avoids shift-by-bitwidth UB in a suboptimal way. + // Convert to funnel shift intrinsic. + bool IsFshl = (ShAmt == SA0 && ShiftOpcode0 == BinaryOperator::Shl) || + (ShAmt == SA1 && ShiftOpcode1 == BinaryOperator::Shl); + Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr; + Function *F = Intrinsic::getDeclaration(Sel.getModule(), IID, Sel.getType()); + return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -1617,31 +1721,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { - // Transform (X == Y) ? X : Y -> Y - if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, FalseVal); - } - // Transform (X une Y) ? X : Y -> X - if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, TrueVal); - } - // Canonicalize to use ordered comparisons by swapping the select // operands. 
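The rotl32 pattern named in the comment above looks like this at the source level. A standalone C++20 sketch (not part of the patch): the compare and select guard the shift-by-32 UB, and the whole expression is a plain rotate-left, i.e. a funnel shift with both inputs equal.

#include <bit>
#include <cstdint>
#include <cstdio>

uint32_t guardedRotl(uint32_t a, uint32_t b) {
  // b is assumed to be in [0, 31]; the select avoids the UB of a >> 32.
  return b == 0 ? a : ((a >> (32 - b)) | (a << b));
}

int main() {
  for (uint32_t b : {0u, 1u, 13u, 31u})
    std::printf("%08x %08x\n", (unsigned)guardedRotl(0x12345678u, b),
                (unsigned)std::rotl(0x12345678u, (int)b));   // columns match
}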
// @@ -1660,31 +1739,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // NOTE: if we wanted to, this is where to detect MIN/MAX } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ - // Transform (X == Y) ? Y : X -> X - if (FCI->getPredicate() == FCmpInst::FCMP_OEQ) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, FalseVal); - } - // Transform (X une Y) ? Y : X -> Y - if (FCI->getPredicate() == FCmpInst::FCMP_UNE) { - // This is not safe in general for floating point: - // consider X== -0, Y== +0. - // It becomes safe if either operand is a nonzero constant. - ConstantFP *CFPt, *CFPf; - if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) && - !CFPt->getValueAPF().isZero()) || - ((CFPf = dyn_cast<ConstantFP>(FalseVal)) && - !CFPf->getValueAPF().isZero())) - return replaceInstUsesWith(SI, TrueVal); - } - // Canonicalize to use ordered comparisons by swapping the select // operands. // @@ -1717,7 +1771,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(X)))) || (X == TrueVal && Pred == FCmpInst::FCMP_OGT && match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(X))))) { - Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI); return replaceInstUsesWith(SI, Fabs); } // With nsz: @@ -1730,7 +1784,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE)) || (X == TrueVal && match(FalseVal, m_FNeg(m_Specific(X))) && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE)))) { - Value *Fabs = Builder.CreateIntrinsic(Intrinsic::fabs, { X }, FCI); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, X, FCI); return replaceInstUsesWith(SI, Fabs); } } @@ -1759,10 +1813,23 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *FoldI = foldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; - Value *LHS, *RHS, *LHS2, *RHS2; + Value *LHS, *RHS; Instruction::CastOps CastOp; SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp); auto SPF = SPR.Flavor; + if (SPF) { + Value *LHS2, *RHS2; + if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) + if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS), SPF2, LHS2, + RHS2, SI, SPF, RHS)) + return R; + if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) + if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS), SPF2, LHS2, + RHS2, SI, SPF, LHS)) + return R; + // TODO. 
+ // ABS(-X) -> ABS(X) + } if (SelectPatternResult::isMinOrMax(SPF)) { // Canonicalize so that @@ -1797,39 +1864,40 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // MAX(~a, ~b) -> ~MIN(a, b) + // MAX(~a, C) -> ~MIN(a, ~C) // MIN(~a, ~b) -> ~MAX(a, b) - Value *A, *B; - if (match(LHS, m_Not(m_Value(A))) && match(RHS, m_Not(m_Value(B))) && - (LHS->getNumUses() <= 2 || RHS->getNumUses() <= 2)) { - CmpInst::Predicate InvertedPred = getInverseMinMaxPred(SPF); - Value *InvertedCmp = Builder.CreateICmp(InvertedPred, A, B); - Value *NewSel = Builder.CreateSelect(InvertedCmp, A, B); - return BinaryOperator::CreateNot(NewSel); - } + // MIN(~a, C) -> ~MAX(a, ~C) + auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { + Value *A; + if (match(X, m_Not(m_Value(A))) && !X->hasNUsesOrMore(3) && + !IsFreeToInvert(A, A->hasOneUse()) && + // Passing false to only consider m_Not and constants. + IsFreeToInvert(Y, false)) { + Value *B = Builder.CreateNot(Y); + Value *NewMinMax = createMinMax(Builder, getInverseMinMaxFlavor(SPF), + A, B); + // Copy the profile metadata. + if (MDNode *MD = SI.getMetadata(LLVMContext::MD_prof)) { + cast<SelectInst>(NewMinMax)->setMetadata(LLVMContext::MD_prof, MD); + // Swap the metadata if the operands are swapped. + if (X == SI.getFalseValue() && Y == SI.getTrueValue()) + cast<SelectInst>(NewMinMax)->swapProfMetadata(); + } - if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return BinaryOperator::CreateNot(NewMinMax); + } + + return nullptr; + }; + + if (Instruction *I = moveNotAfterMinMax(LHS, RHS)) + return I; + if (Instruction *I = moveNotAfterMinMax(RHS, LHS)) return I; - } - if (SPF) { - // MAX(MAX(a, b), a) -> MAX(a, b) - // MIN(MIN(a, b), a) -> MIN(a, b) - // MAX(MIN(a, b), a) -> a - // MIN(MAX(a, b), a) -> a - // ABS(ABS(a)) -> ABS(a) - // NABS(NABS(a)) -> NABS(a) - if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) - if (Instruction *R = foldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2, - SI, SPF, RHS)) - return R; - if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) - if (Instruction *R = foldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2, - SI, SPF, LHS)) - return R; + if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) + return I; } - - // TODO. - // ABS(-X) -> ABS(X) } // See if we can fold the select into a phi node if the condition is a select. @@ -1934,10 +2002,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - if (BinaryOperator::isNot(CondVal)) { - SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); + Value *NotCond; + if (match(CondVal, m_Not(m_Value(NotCond)))) { + SI.setOperand(0, NotCond); SI.setOperand(1, FalseVal); SI.setOperand(2, TrueVal); + SI.swapProfMetadata(); return &SI; } @@ -1952,24 +2022,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - // See if we can determine the result of this select based on a dominating - // condition. - BasicBlock *Parent = SI.getParent(); - if (BasicBlock *Dom = Parent->getSinglePredecessor()) { - auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); - if (PBI && PBI->isConditional() && - PBI->getSuccessor(0) != PBI->getSuccessor(1) && - (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { - bool CondIsTrue = PBI->getSuccessor(0) == Parent; - Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), SI.getCondition(), DL, CondIsTrue); - if (Implication) { - Value *V = *Implication ? 
TrueVal : FalseVal; - return replaceInstUsesWith(SI, V); - } - } - } - // If we can compute the condition, there's no need for a select. // Like the above fold, we are attempting to reduce compile-time cost by // putting this fold here with limitations rather than in InstSimplify. @@ -1991,8 +2043,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Select = foldSelectCmpXchg(SI)) return Select; - if (Instruction *Select = foldSelectBinOpIdentity(SI)) + if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI)) return Select; + if (Instruction *Rot = foldSelectRotate(SI)) + return Rot; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 1ca75f3989d4..c562d45a9e2b 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -593,7 +593,7 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *V = commonShiftTransforms(I)) @@ -697,7 +697,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *R = commonShiftTransforms(I)) @@ -725,9 +725,9 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { Value *X; const APInt *ShOp1; - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { - unsigned ShlAmt = ShOp1->getZExtValue(); - if (ShlAmt < ShAmt) { + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { + if (ShOp1->ult(ShAmt)) { + unsigned ShlAmt = ShOp1->getZExtValue(); Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) @@ -740,7 +740,8 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); } - if (ShlAmt > ShAmt) { + if (ShOp1->ugt(ShAmt)) { + unsigned ShlAmt = ShOp1->getZExtValue(); Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) @@ -753,7 +754,7 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - assert(ShlAmt == ShAmt); + assert(*ShOp1 == ShAmt); // (X << C) >>u C --> X & (-1 >>u C) APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); @@ -825,7 +826,7 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldShuffledBinop(I)) + if (Instruction *X = foldVectorBinop(I)) return X; if (Instruction *R = commonShiftTransforms(I)) diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 425f5ce384be..9bf87d024607 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -314,11 +314,32 @@ Value 
*InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known.One = std::move(IKnownOne); break; } - case Instruction::Select: - // If this is a select as part of a min/max pattern, don't simplify any - // further in case we break the structure. + case Instruction::Select: { Value *LHS, *RHS; - if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN) + SelectPatternFlavor SPF = matchSelectPattern(I, LHS, RHS).Flavor; + if (SPF == SPF_UMAX) { + // UMax(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-zero bit of C. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(RHS, m_APInt(C)) && CTZ >= C->getActiveBits()) + return LHS; + } else if (SPF == SPF_UMIN) { + // UMin(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-one bit of C. + // This comes from using DeMorgans on the above umax example. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(RHS, m_APInt(C)) && + CTZ >= C->getBitWidth() - C->countLeadingOnes()) + return LHS; + } + + // If this is a select as part of any other min/max pattern, don't simplify + // any further in case we break the structure. + if (SPF != SPF_UNKNOWN) return nullptr; if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Depth + 1) || @@ -336,6 +357,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known.One = RHSKnown.One & LHSKnown.One; Known.Zero = RHSKnown.Zero & LHSKnown.Zero; break; + } case Instruction::ZExt: case Instruction::Trunc: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); @@ -668,6 +690,30 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // TODO: Could compute known zero/one bits based on the input. break; } + case Intrinsic::fshr: + case Intrinsic::fshl: { + const APInt *SA; + if (!match(I->getOperand(2), m_APInt(SA))) + break; + + // Normalize to funnel shift left. APInt shifts of BitWidth are well- + // defined, so no need to special-case zero shifts here. + uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt)); + APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt)); + if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) + return I; + + Known.Zero = LHSKnown.Zero.shl(ShiftAmt) | + RHSKnown.Zero.lshr(BitWidth - ShiftAmt); + Known.One = LHSKnown.One.shl(ShiftAmt) | + RHSKnown.One.lshr(BitWidth - ShiftAmt); + break; + } case Intrinsic::x86_mmx_pmovmskb: case Intrinsic::x86_sse_movmsk_ps: case Intrinsic::x86_sse2_movmsk_pd: @@ -923,11 +969,24 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1, /// Implement SimplifyDemandedVectorElts for amdgcn buffer and image intrinsics. 
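A small standalone check (plain C++, not part of the patch) of the new UMax demanded-bits rule above: when the caller only demands bits above the highest set bit of the constant, umax(A, C) and A agree on every demanded bit, so the max can be replaced by A.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t DemandedMask = 0xFFFFFF00u; // lowest demanded bit is bit 8
  const uint32_t C = 0x3Fu;                  // active bits: 6, all below bit 8
  for (uint32_t a : {0u, 1u, 0x3Fu, 0x40u, 0x12345678u, 0xFFFFFFFFu}) {
    uint32_t viaMax = std::max(a, C) & DemandedMask;
    uint32_t direct = a & DemandedMask;
    std::printf("%08x: %08x %08x\n", (unsigned)a, (unsigned)viaMax,
                (unsigned)direct);           // the two columns are always equal
  }
}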
Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, APInt DemandedElts, - int DMaskIdx) { + int DMaskIdx, + int TFCIdx) { unsigned VWidth = II->getType()->getVectorNumElements(); if (VWidth == 1) return nullptr; + // Need to change to new instruction format + ConstantInt *TFC = nullptr; + bool TFELWEEnabled = false; + if (TFCIdx > 0) { + TFC = dyn_cast<ConstantInt>(II->getArgOperand(TFCIdx)); + TFELWEEnabled = TFC->getZExtValue() & 0x1 // TFE + || TFC->getZExtValue() & 0x2; // LWE + } + + if (TFELWEEnabled) + return nullptr; // TFE not yet supported + ConstantInt *NewDMask = nullptr; if (DMaskIdx < 0) { @@ -1052,8 +1111,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, UndefElts = 0; - // Handle ConstantAggregateZero, ConstantVector, ConstantDataSequential. - if (Constant *C = dyn_cast<Constant>(V)) { + if (auto *C = dyn_cast<Constant>(V)) { // Check if this is identity. If so, return 0 since we are not simplifying // anything. if (DemandedElts.isAllOnesValue()) @@ -1061,7 +1119,6 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, Type *EltTy = cast<VectorType>(V->getType())->getElementType(); Constant *Undef = UndefValue::get(EltTy); - SmallVector<Constant*, 16> Elts; for (unsigned i = 0; i != VWidth; ++i) { if (!DemandedElts[i]) { // If not demanded, set to undef. @@ -1109,9 +1166,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (!I) return nullptr; // Only analyze instructions. bool MadeChange = false; + auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum, + APInt Demanded, APInt &Undef) { + auto *II = dyn_cast<IntrinsicInst>(Inst); + Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum); + if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) { + if (II) + II->setArgOperand(OpNum, V); + else + Inst->setOperand(OpNum, V); + MadeChange = true; + } + }; + APInt UndefElts2(VWidth, 0); APInt UndefElts3(VWidth, 0); - Value *TmpV; switch (I->getOpcode()) { default: break; @@ -1122,9 +1191,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, if (!Idx) { // Note that we can't propagate undef elt info, because we don't know // which elt is getting updated. - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(I, 0, DemandedElts, UndefElts2); break; } @@ -1134,9 +1201,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt PreInsertDemandedElts = DemandedElts; if (IdxNo < VWidth) PreInsertDemandedElts.clearBit(IdxNo); - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), PreInsertDemandedElts, - UndefElts, Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + + simplifyAndSetOp(I, 0, PreInsertDemandedElts, UndefElts); // If this is inserting an element that isn't demanded, remove this // insertelement. 
@@ -1169,14 +1235,10 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } APInt LHSUndefElts(LHSVWidth, 0); - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded, - LHSUndefElts, Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); APInt RHSUndefElts(LHSVWidth, 0); - TmpV = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded, - RHSUndefElts, Depth + 1); - if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } + simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts); bool NewUndefElts = false; unsigned LHSIdx = -1u, LHSValIdx = -1u; @@ -1260,32 +1322,43 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } case Instruction::Select: { - APInt LeftDemanded(DemandedElts), RightDemanded(DemandedElts); - if (ConstantVector* CV = dyn_cast<ConstantVector>(I->getOperand(0))) { + // If this is a vector select, try to transform the select condition based + // on the current demanded elements. + SelectInst *Sel = cast<SelectInst>(I); + if (Sel->getCondition()->getType()->isVectorTy()) { + // TODO: We are not doing anything with UndefElts based on this call. + // It is overwritten below based on the other select operands. If an + // element of the select condition is known undef, then we are free to + // choose the output value from either arm of the select. If we know that + // one of those values is undef, then the output can be undef. + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); + } + + // Next, see if we can transform the arms of the select. + APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts); + if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) { for (unsigned i = 0; i < VWidth; i++) { + // isNullValue() always returns false when called on a ConstantExpr. + // Skip constant expressions to avoid propagating incorrect information. Constant *CElt = CV->getAggregateElement(i); - // Method isNullValue always returns false when called on a - // ConstantExpr. If CElt is a ConstantExpr then skip it in order to - // to avoid propagating incorrect information. if (isa<ConstantExpr>(CElt)) continue; + // TODO: If a select condition element is undef, we can demand from + // either side. If one side is known undef, choosing that side would + // propagate undef. if (CElt->isNullValue()) - LeftDemanded.clearBit(i); + DemandedLHS.clearBit(i); else - RightDemanded.clearBit(i); + DemandedRHS.clearBit(i); } } - TmpV = SimplifyDemandedVectorElts(I->getOperand(1), LeftDemanded, UndefElts, - Depth + 1); - if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } - - TmpV = SimplifyDemandedVectorElts(I->getOperand(2), RightDemanded, - UndefElts2, Depth + 1); - if (TmpV) { I->setOperand(2, TmpV); MadeChange = true; } + simplifyAndSetOp(I, 1, DemandedLHS, UndefElts2); + simplifyAndSetOp(I, 2, DemandedRHS, UndefElts3); - // Output elements are undefined if both are undefined. - UndefElts &= UndefElts2; + // Output elements are undefined if the element from each arm is undefined. + // TODO: This can be improved. See comment in select condition handling. 
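A sketch (standalone C++, not part of the patch) of how the constant vector-select condition above splits the demanded elements between the two arms: lanes whose condition bit is set only demand the true arm, the rest only demand the false arm.

#include <bitset>
#include <cstdio>

int main() {
  std::bitset<4> Demanded("1111");   // caller demands every lane
  std::bitset<4> Cond("1010");       // constant per-lane condition (MSB..LSB)
  std::bitset<4> DemandedLHS = Demanded & Cond;   // lanes taken from the true arm
  std::bitset<4> DemandedRHS = Demanded & ~Cond;  // lanes taken from the false arm
  std::printf("LHS=%s RHS=%s\n", DemandedLHS.to_string().c_str(),
              DemandedRHS.to_string().c_str());   // LHS=1010 RHS=0101
}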
+ UndefElts = UndefElts2 & UndefElts3; break; } case Instruction::BitCast: { @@ -1323,12 +1396,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { - I->setOperand(0, TmpV); - MadeChange = true; - } + simplifyAndSetOp(I, 0, InputDemandedElts, UndefElts2); if (VWidth == InVWidth) { UndefElts = UndefElts2; @@ -1353,29 +1421,9 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } break; } - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: - case Instruction::Add: - case Instruction::Sub: - case Instruction::Mul: - // div/rem demand all inputs, because they don't want divide by zero. - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, UndefElts, - Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } - TmpV = SimplifyDemandedVectorElts(I->getOperand(1), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; } - - // Output elements are undefined if both are undefined. Consider things - // like undef&0. The result is known zero, not undef. - UndefElts &= UndefElts2; - break; case Instruction::FPTrunc: case Instruction::FPExt: - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts, UndefElts, - Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); break; case Instruction::Call: { @@ -1395,9 +1443,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Only the lower element is used. DemandedElts = 1; - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 0, DemandedElts, UndefElts); // Only the lower element is undefined. The high elements are zero. UndefElts = UndefElts[0]; @@ -1406,9 +1452,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Unary scalar-as-vector operations that work column-wise. case Intrinsic::x86_sse_rcp_ss: case Intrinsic::x86_sse_rsqrt_ss: - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 0, DemandedElts, UndefElts); // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { @@ -1428,9 +1472,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::x86_sse2_min_sd: case Intrinsic::x86_sse2_max_sd: case Intrinsic::x86_sse2_cmp_sd: { - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 0, DemandedElts, UndefElts); // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { @@ -1440,9 +1482,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Only lower element is used for operand 1. DemandedElts = 1; - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 1, DemandedElts, UndefElts2); // Lower element is undefined if both lower elements are undefined. // Consider things like undef&0. The result is known zero, not undef. 
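A model (plain C++, mirroring the scalar-as-vector x86 *_ss semantics the code above relies on) of why an op whose lane 0 is not demanded collapses to its first operand: the upper lanes of an *_ss result are copied from operand 0, so lane 0 is the only computed lane.

#include <cstdio>

struct V4 { float f[4]; };

V4 add_ss(V4 a, V4 b) {            // models addss on a 4 x float register
  V4 r = a;                        // upper lanes pass through from a
  r.f[0] = a.f[0] + b.f[0];        // only lane 0 is computed
  return r;
}

int main() {
  V4 a = {{1, 2, 3, 4}}, b = {{10, 20, 30, 40}};
  V4 r = add_ss(a, b);
  // If a user never reads lane 0, every lane it does read equals a's lane.
  std::printf("%g %g %g\n", r.f[1], r.f[2], r.f[3]);   // 2 3 4
}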
@@ -1459,9 +1499,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Don't use the low element of operand 0. APInt DemandedElts2 = DemandedElts; DemandedElts2.clearBit(0); - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts2, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 0, DemandedElts2, UndefElts); // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { @@ -1471,9 +1509,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Only lower element is used for operand 1. DemandedElts = 1; - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 1, DemandedElts, UndefElts2); // Take the high undef elements from operand 0 and take the lower element // from operand 1. @@ -1497,9 +1533,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::x86_avx512_mask_sub_sd_round: case Intrinsic::x86_avx512_mask_max_sd_round: case Intrinsic::x86_avx512_mask_min_sd_round: - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts, - UndefElts, Depth + 1); - if (TmpV) { II->setArgOperand(0, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 0, DemandedElts, UndefElts); // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { @@ -1509,12 +1543,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // Only lower element is used for operand 1 and 2. DemandedElts = 1; - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts, - UndefElts2, Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } - TmpV = SimplifyDemandedVectorElts(II->getArgOperand(2), DemandedElts, - UndefElts3, Depth + 1); - if (TmpV) { II->setArgOperand(2, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 1, DemandedElts, UndefElts2); + simplifyAndSetOp(II, 2, DemandedElts, UndefElts3); // Lower element is undefined if all three lower elements are undefined. // Consider things like undef&0. The result is known zero, not undef. @@ -1559,14 +1589,8 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } // Demand elements from the operand. - auto *Op = II->getArgOperand(OpNum); APInt OpUndefElts(InnerVWidth, 0); - TmpV = SimplifyDemandedVectorElts(Op, OpDemandedElts, OpUndefElts, - Depth + 1); - if (TmpV) { - II->setArgOperand(OpNum, TmpV); - MadeChange = true; - } + simplifyAndSetOp(II, OpNum, OpDemandedElts, OpUndefElts); // Pack the operand's UNDEF elements, one lane at a time. 
OpUndefElts = OpUndefElts.zext(VWidth); @@ -1594,10 +1618,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // PERMV case Intrinsic::x86_avx2_permd: case Intrinsic::x86_avx2_permps: { - Value *Op1 = II->getArgOperand(1); - TmpV = SimplifyDemandedVectorElts(Op1, DemandedElts, UndefElts, - Depth + 1); - if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + simplifyAndSetOp(II, 1, DemandedElts, UndefElts); break; } @@ -1611,16 +1632,40 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; case Intrinsic::amdgcn_buffer_load: case Intrinsic::amdgcn_buffer_load_format: + case Intrinsic::amdgcn_raw_buffer_load: + case Intrinsic::amdgcn_raw_buffer_load_format: + case Intrinsic::amdgcn_struct_buffer_load: + case Intrinsic::amdgcn_struct_buffer_load_format: return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); default: { if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) - return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts, 0); + return simplifyAMDGCNMemoryIntrinsicDemanded( + II, DemandedElts, 0, II->getNumArgOperands() - 2); break; } - } + } // switch on IntrinsicID break; + } // case Call + } // switch on Opcode + + // TODO: We bail completely on integer div/rem and shifts because they have + // UB/poison potential, but that should be refined. + BinaryOperator *BO; + if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) { + simplifyAndSetOp(I, 0, DemandedElts, UndefElts); + simplifyAndSetOp(I, 1, DemandedElts, UndefElts2); + + // Any change to an instruction with potential poison must clear those flags + // because we can not guarantee those constraints now. Other analysis may + // determine that it is safe to re-apply the flags. + if (MadeChange) + BO->dropPoisonGeneratingFlags(); + + // Output elements are undefined if both are undefined. Consider things + // like undef & 0. The result is known zero, not undef. + UndefElts &= UndefElts2; } - } + return MadeChange ? I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 1c2de6352fa5..0ad1fc0e791f 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -46,40 +46,34 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" /// Return true if the value is cheaper to scalarize than it is to leave as a -/// vector operation. isConstant indicates whether we're extracting one known -/// element. If false we're extracting a variable index. -static bool cheapToScalarize(Value *V, bool isConstant) { - if (Constant *C = dyn_cast<Constant>(V)) { - if (isConstant) return true; +/// vector operation. IsConstantExtractIndex indicates whether we are extracting +/// one known element from a vector constant. +/// +/// FIXME: It's possible to create more instructions than previously existed. +static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { + // If we can pick a scalar constant value out of a vector, that is free. + if (auto *C = dyn_cast<Constant>(V)) + return IsConstantExtractIndex || C->getSplatValue(); + + // An insertelement to the same constant index as our extract will simplify + // to the scalar inserted element. An insertelement to a different constant + // index is irrelevant to our extract. 
+ if (match(V, m_InsertElement(m_Value(), m_Value(), m_ConstantInt()))) + return IsConstantExtractIndex; + + if (match(V, m_OneUse(m_Load(m_Value())))) + return true; - // If all elts are the same, we can extract it and use any of the values. - if (Constant *Op0 = C->getAggregateElement(0U)) { - for (unsigned i = 1, e = V->getType()->getVectorNumElements(); i != e; - ++i) - if (C->getAggregateElement(i) != Op0) - return false; + Value *V0, *V1; + if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1))))) + if (cheapToScalarize(V0, IsConstantExtractIndex) || + cheapToScalarize(V1, IsConstantExtractIndex)) return true; - } - } - Instruction *I = dyn_cast<Instruction>(V); - if (!I) return false; - // Insert element gets simplified to the inserted element or is deleted if - // this is constant idx extract element and its a constant idx insertelt. - if (I->getOpcode() == Instruction::InsertElement && isConstant && - isa<ConstantInt>(I->getOperand(2))) - return true; - if (I->getOpcode() == Instruction::Load && I->hasOneUse()) - return true; - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) - if (BO->hasOneUse() && - (cheapToScalarize(BO->getOperand(0), isConstant) || - cheapToScalarize(BO->getOperand(1), isConstant))) - return true; - if (CmpInst *CI = dyn_cast<CmpInst>(I)) - if (CI->hasOneUse() && - (cheapToScalarize(CI->getOperand(0), isConstant) || - cheapToScalarize(CI->getOperand(1), isConstant))) + CmpInst::Predicate UnusedPred; + if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1))))) + if (cheapToScalarize(V0, IsConstantExtractIndex) || + cheapToScalarize(V1, IsConstantExtractIndex)) return true; return false; @@ -166,92 +160,176 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { return &EI; } +static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, + InstCombiner::BuilderTy &Builder, + bool IsBigEndian) { + Value *X; + uint64_t ExtIndexC; + if (!match(Ext.getVectorOperand(), m_BitCast(m_Value(X))) || + !X->getType()->isVectorTy() || + !match(Ext.getIndexOperand(), m_ConstantInt(ExtIndexC))) + return nullptr; + + // If this extractelement is using a bitcast from a vector of the same number + // of elements, see if we can find the source element from the source vector: + // extelt (bitcast VecX), IndexC --> bitcast X[IndexC] + Type *SrcTy = X->getType(); + Type *DestTy = Ext.getType(); + unsigned NumSrcElts = SrcTy->getVectorNumElements(); + unsigned NumElts = Ext.getVectorOperandType()->getNumElements(); + if (NumSrcElts == NumElts) + if (Value *Elt = findScalarElement(X, ExtIndexC)) + return new BitCastInst(Elt, DestTy); + + // If the source elements are wider than the destination, try to shift and + // truncate a subset of scalar bits of an insert op. + if (NumSrcElts < NumElts) { + Value *Scalar; + uint64_t InsIndexC; + if (!match(X, m_InsertElement(m_Value(), m_Value(Scalar), + m_ConstantInt(InsIndexC)))) + return nullptr; + + // The extract must be from the subset of vector elements that we inserted + // into. Example: if we inserted element 1 of a <2 x i64> and we are + // extracting an i16 (narrowing ratio = 4), then this extract must be from 1 + // of elements 4-7 of the bitcasted vector. + unsigned NarrowingRatio = NumElts / NumSrcElts; + if (ExtIndexC / NarrowingRatio != InsIndexC) + return nullptr; + + // We are extracting part of the original scalar. How that scalar is + // inserted into the vector depends on the endian-ness. 
Example: + // Vector Byte Elt Index: 0 1 2 3 4 5 6 7 + // +--+--+--+--+--+--+--+--+ + // inselt <2 x i32> V, <i32> S, 1: |V0|V1|V2|V3|S0|S1|S2|S3| + // extelt <4 x i16> V', 3: | |S2|S3| + // +--+--+--+--+--+--+--+--+ + // If this is little-endian, S2|S3 are the MSB of the 32-bit 'S' value. + // If this is big-endian, S2|S3 are the LSB of the 32-bit 'S' value. + // In this example, we must right-shift little-endian. Big-endian is just a + // truncate. + unsigned Chunk = ExtIndexC % NarrowingRatio; + if (IsBigEndian) + Chunk = NarrowingRatio - 1 - Chunk; + + // Bail out if this is an FP vector to FP vector sequence. That would take + // more instructions than we started with unless there is no shift, and it + // may not be handled as well in the backend. + bool NeedSrcBitcast = SrcTy->getScalarType()->isFloatingPointTy(); + bool NeedDestBitcast = DestTy->isFloatingPointTy(); + if (NeedSrcBitcast && NeedDestBitcast) + return nullptr; + + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); + unsigned ShAmt = Chunk * DestWidth; + + // TODO: This limitation is more strict than necessary. We could sum the + // number of new instructions and subtract the number eliminated to know if + // we can proceed. + if (!X->hasOneUse() || !Ext.getVectorOperand()->hasOneUse()) + if (NeedSrcBitcast || NeedDestBitcast) + return nullptr; + + if (NeedSrcBitcast) { + Type *SrcIntTy = IntegerType::getIntNTy(Scalar->getContext(), SrcWidth); + Scalar = Builder.CreateBitCast(Scalar, SrcIntTy); + } + + if (ShAmt) { + // Bail out if we could end with more instructions than we started with. + if (!Ext.getVectorOperand()->hasOneUse()) + return nullptr; + Scalar = Builder.CreateLShr(Scalar, ShAmt); + } + + if (NeedDestBitcast) { + Type *DestIntTy = IntegerType::getIntNTy(Scalar->getContext(), DestWidth); + return new BitCastInst(Builder.CreateTrunc(Scalar, DestIntTy), DestTy); + } + return new TruncInst(Scalar, DestTy); + } + + return nullptr; +} + Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { - if (Value *V = SimplifyExtractElementInst(EI.getVectorOperand(), - EI.getIndexOperand(), + Value *SrcVec = EI.getVectorOperand(); + Value *Index = EI.getIndexOperand(); + if (Value *V = SimplifyExtractElementInst(SrcVec, Index, SQ.getWithInstruction(&EI))) return replaceInstUsesWith(EI, V); - // If vector val is constant with all elements the same, replace EI with - // that element. We handle a known element # below. - if (Constant *C = dyn_cast<Constant>(EI.getOperand(0))) - if (cheapToScalarize(C, false)) - return replaceInstUsesWith(EI, C->getAggregateElement(0U)); - // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. - if (ConstantInt *IdxC = dyn_cast<ConstantInt>(EI.getOperand(1))) { - unsigned VectorWidth = EI.getVectorOperandType()->getNumElements(); + auto *IndexC = dyn_cast<ConstantInt>(Index); + if (IndexC) { + unsigned NumElts = EI.getVectorOperandType()->getNumElements(); // InstSimplify should handle cases where the index is invalid. - if (!IdxC->getValue().ule(VectorWidth)) + if (!IndexC->getValue().ule(NumElts)) return nullptr; - unsigned IndexVal = IdxC->getZExtValue(); - // This instruction only demands the single element from the input vector. // If the input vector has a single use, simplify it based on this use // property. 
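The byte diagram above can be checked with a small scalar model (standalone C++, assuming a little-endian host, not part of the patch): extracting a 16-bit lane from the bitcast of a <2 x i32> that was just written by an insert is a shift plus truncate of the inserted 32-bit scalar.

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint32_t vec[2] = {0x11223344u, 0u};
  uint32_t s = 0xAABBCCDDu;
  vec[1] = s;                            // inselt <2 x i32> V, i32 S, 1

  uint16_t lanes[4];
  std::memcpy(lanes, vec, sizeof(vec));  // bitcast <2 x i32> to <4 x i16>

  // extelt <4 x i16> V', 3: NarrowingRatio = 2, chunk = 3 % 2 = 1,
  // so on little-endian this is (S >> 16) truncated to i16.
  uint16_t viaVector = lanes[3];
  uint16_t viaScalar = uint16_t(s >> 16);
  std::printf("%04x %04x\n", (unsigned)viaVector, (unsigned)viaScalar); // aabb aabb
}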
- if (EI.getOperand(0)->hasOneUse() && VectorWidth != 1) { - APInt UndefElts(VectorWidth, 0); - APInt DemandedMask(VectorWidth, 0); - DemandedMask.setBit(IndexVal); - if (Value *V = SimplifyDemandedVectorElts(EI.getOperand(0), DemandedMask, + if (SrcVec->hasOneUse() && NumElts != 1) { + APInt UndefElts(NumElts, 0); + APInt DemandedElts(NumElts, 0); + DemandedElts.setBit(IndexC->getZExtValue()); + if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) { EI.setOperand(0, V); return &EI; } } - // If this extractelement is directly using a bitcast from a vector of - // the same number of elements, see if we can find the source element from - // it. In this case, we will end up needing to bitcast the scalars. - if (BitCastInst *BCI = dyn_cast<BitCastInst>(EI.getOperand(0))) { - if (VectorType *VT = dyn_cast<VectorType>(BCI->getOperand(0)->getType())) - if (VT->getNumElements() == VectorWidth) - if (Value *Elt = findScalarElement(BCI->getOperand(0), IndexVal)) - return new BitCastInst(Elt, EI.getType()); - } + if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) + return I; // If there's a vector PHI feeding a scalar use through this extractelement // instruction, try to scalarize the PHI. - if (PHINode *PN = dyn_cast<PHINode>(EI.getOperand(0))) { - Instruction *scalarPHI = scalarizePHI(EI, PN); - if (scalarPHI) - return scalarPHI; - } + if (auto *Phi = dyn_cast<PHINode>(SrcVec)) + if (Instruction *ScalarPHI = scalarizePHI(EI, Phi)) + return ScalarPHI; } - if (Instruction *I = dyn_cast<Instruction>(EI.getOperand(0))) { - // Push extractelement into predecessor operation if legal and - // profitable to do so. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) { - if (I->hasOneUse() && - cheapToScalarize(BO, isa<ConstantInt>(EI.getOperand(1)))) { - Value *newEI0 = - Builder.CreateExtractElement(BO->getOperand(0), EI.getOperand(1), - EI.getName()+".lhs"); - Value *newEI1 = - Builder.CreateExtractElement(BO->getOperand(1), EI.getOperand(1), - EI.getName()+".rhs"); - return BinaryOperator::CreateWithCopiedFlags(BO->getOpcode(), - newEI0, newEI1, BO); - } - } else if (InsertElementInst *IE = dyn_cast<InsertElementInst>(I)) { + BinaryOperator *BO; + if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) { + // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index) + Value *X = BO->getOperand(0), *Y = BO->getOperand(1); + Value *E0 = Builder.CreateExtractElement(X, Index); + Value *E1 = Builder.CreateExtractElement(Y, Index); + return BinaryOperator::CreateWithCopiedFlags(BO->getOpcode(), E0, E1, BO); + } + + Value *X, *Y; + CmpInst::Predicate Pred; + if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) && + cheapToScalarize(SrcVec, IndexC)) { + // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index) + Value *E0 = Builder.CreateExtractElement(X, Index); + Value *E1 = Builder.CreateExtractElement(Y, Index); + return CmpInst::Create(cast<CmpInst>(SrcVec)->getOpcode(), Pred, E0, E1); + } + + if (auto *I = dyn_cast<Instruction>(SrcVec)) { + if (auto *IE = dyn_cast<InsertElementInst>(I)) { // Extracting the inserted element? - if (IE->getOperand(2) == EI.getOperand(1)) + if (IE->getOperand(2) == Index) return replaceInstUsesWith(EI, IE->getOperand(1)); // If the inserted and extracted elements are constants, they must not // be the same value, extract from the pre-inserted value instead. 
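The extractelement-of-binop scalarization above (extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index)) only fires when cheapToScalarize() approves one of the binop operands. A minimal qualifying input, sketched here with invented names, uses an insertelement at the extracted lane:

define i32 @scalarize_extract_of_binop(<4 x i32> %x, i32 %s) {
  %ins = insertelement <4 x i32> undef, i32 %s, i32 2
  %mul = mul <4 x i32> %ins, %x
  %e = extractelement <4 x i32> %mul, i32 2
  ret i32 %e
}
; expected (roughly): %xe = extractelement <4 x i32> %x, i32 2
;                     %e  = mul i32 %s, %xe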
- if (isa<Constant>(IE->getOperand(2)) && isa<Constant>(EI.getOperand(1))) { - Worklist.AddValue(EI.getOperand(0)); + if (isa<Constant>(IE->getOperand(2)) && IndexC) { + Worklist.AddValue(SrcVec); EI.setOperand(0, IE->getOperand(0)); return &EI; } - } else if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(I)) { + } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) { // If this is extracting an element from a shufflevector, figure out where // it came from and extract from the appropriate input element instead. - if (ConstantInt *Elt = dyn_cast<ConstantInt>(EI.getOperand(1))) { + if (auto *Elt = dyn_cast<ConstantInt>(Index)) { int SrcIdx = SVI->getMaskValue(Elt->getZExtValue()); Value *Src; unsigned LHSWidth = @@ -270,13 +348,12 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { ConstantInt::get(Int32Ty, SrcIdx, false)); } - } else if (CastInst *CI = dyn_cast<CastInst>(I)) { + } else if (auto *CI = dyn_cast<CastInst>(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). // Bitcasts can change the number of vector elements, and they cost // nothing. if (CI->hasOneUse() && (CI->getOpcode() != Instruction::BitCast)) { - Value *EE = Builder.CreateExtractElement(CI->getOperand(0), - EI.getIndexOperand()); + Value *EE = Builder.CreateExtractElement(CI->getOperand(0), Index); Worklist.AddValue(EE); return CastInst::Create(CI->getOpcode(), EE, EI.getType()); } @@ -791,43 +868,62 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp)) replaceInstUsesWith(IE, VecOp); - // If the inserted element was extracted from some other vector, and if the - // indexes are constant, try to turn this into a shufflevector operation. - if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)) { - if (isa<ConstantInt>(EI->getOperand(1)) && isa<ConstantInt>(IdxOp)) { - unsigned NumInsertVectorElts = IE.getType()->getNumElements(); - unsigned NumExtractVectorElts = - EI->getOperand(0)->getType()->getVectorNumElements(); - unsigned ExtractedIdx = - cast<ConstantInt>(EI->getOperand(1))->getZExtValue(); - unsigned InsertedIdx = cast<ConstantInt>(IdxOp)->getZExtValue(); - - if (ExtractedIdx >= NumExtractVectorElts) // Out of range extract. - return replaceInstUsesWith(IE, VecOp); - - if (InsertedIdx >= NumInsertVectorElts) // Out of range insert. - return replaceInstUsesWith(IE, UndefValue::get(IE.getType())); - - // If we are extracting a value from a vector, then inserting it right - // back into the same place, just use the input vector. - if (EI->getOperand(0) == VecOp && ExtractedIdx == InsertedIdx) - return replaceInstUsesWith(IE, VecOp); - - // If this insertelement isn't used by some other insertelement, turn it - // (and any insertelements it points to), into one big shuffle. - if (!IE.hasOneUse() || !isa<InsertElementInst>(IE.user_back())) { - SmallVector<Constant*, 16> Mask; - ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); - - // The proposed shuffle may be trivial, in which case we shouldn't - // perform the combine. - if (LR.first != &IE && LR.second != &IE) { - // We now have a shuffle of LHS, RHS, Mask. - if (LR.second == nullptr) - LR.second = UndefValue::get(LR.first->getType()); - return new ShuffleVectorInst(LR.first, LR.second, - ConstantVector::get(Mask)); - } + // If the inserted element was extracted from some other vector and both + // indexes are constant, try to turn this into a shuffle. 
+ uint64_t InsertedIdx, ExtractedIdx; + Value *ExtVecOp; + if (match(IdxOp, m_ConstantInt(InsertedIdx)) && + match(ScalarOp, m_ExtractElement(m_Value(ExtVecOp), + m_ConstantInt(ExtractedIdx)))) { + unsigned NumInsertVectorElts = IE.getType()->getNumElements(); + unsigned NumExtractVectorElts = ExtVecOp->getType()->getVectorNumElements(); + if (ExtractedIdx >= NumExtractVectorElts) // Out of range extract. + return replaceInstUsesWith(IE, VecOp); + + if (InsertedIdx >= NumInsertVectorElts) // Out of range insert. + return replaceInstUsesWith(IE, UndefValue::get(IE.getType())); + + // If we are extracting a value from a vector, then inserting it right + // back into the same place, just use the input vector. + if (ExtVecOp == VecOp && ExtractedIdx == InsertedIdx) + return replaceInstUsesWith(IE, VecOp); + + // TODO: Looking at the user(s) to determine if this insert is a + // fold-to-shuffle opportunity does not match the usual instcombine + // constraints. We should decide if the transform is worthy based only + // on this instruction and its operands, but that may not work currently. + // + // Here, we are trying to avoid creating shuffles before reaching + // the end of a chain of extract-insert pairs. This is complicated because + // we do not generally form arbitrary shuffle masks in instcombine + // (because those may codegen poorly), but collectShuffleElements() does + // exactly that. + // + // The rules for determining what is an acceptable target-independent + // shuffle mask are fuzzy because they evolve based on the backend's + // capabilities and real-world impact. + auto isShuffleRootCandidate = [](InsertElementInst &Insert) { + if (!Insert.hasOneUse()) + return true; + auto *InsertUser = dyn_cast<InsertElementInst>(Insert.user_back()); + if (!InsertUser) + return true; + return false; + }; + + // Try to form a shuffle from a chain of extract-insert ops. + if (isShuffleRootCandidate(IE)) { + SmallVector<Constant*, 16> Mask; + ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); + + // The proposed shuffle may be trivial, in which case we shouldn't + // perform the combine. + if (LR.first != &IE && LR.second != &IE) { + // We now have a shuffle of LHS, RHS, Mask. + if (LR.second == nullptr) + LR.second = UndefValue::get(LR.first->getType()); + return new ShuffleVectorInst(LR.first, LR.second, + ConstantVector::get(Mask)); } } } @@ -857,7 +953,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { /// Return true if we can evaluate the specified expression tree if the vector /// elements were shuffled in a different order. -static bool CanEvaluateShuffled(Value *V, ArrayRef<int> Mask, +static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, unsigned Depth = 5) { // We can always reorder the elements of a constant. if (isa<Constant>(V)) @@ -904,8 +1000,15 @@ static bool CanEvaluateShuffled(Value *V, ArrayRef<int> Mask, case Instruction::FPTrunc: case Instruction::FPExt: case Instruction::GetElementPtr: { + // Bail out if we would create longer vector ops. We could allow creating + // longer vector ops, but that may result in more expensive codegen. We + // would also need to limit the transform to avoid undefined behavior for + // integer div/rem. 
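For reference, the fold-to-shuffle logic above (collectShuffleElements() rooted at the last insert of an extract/insert chain) turns a chain like the following into a single shufflevector. The test is illustrative only, not taken from the patch:

define <4 x float> @chain_to_shuffle(<4 x float> %a, <4 x float> %b) {
  %e0 = extractelement <4 x float> %b, i32 0
  %i0 = insertelement <4 x float> %a, float %e0, i32 1
  %e1 = extractelement <4 x float> %b, i32 2
  %i1 = insertelement <4 x float> %i0, float %e1, i32 3
  ret <4 x float> %i1
}
; expected (roughly):
;   shufflevector <4 x float> %a, <4 x float> %b, <4 x i32> <i32 0, i32 4, i32 2, i32 6>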
+ Type *ITy = I->getType(); + if (ITy->isVectorTy() && Mask.size() > ITy->getVectorNumElements()) + return false; for (Value *Operand : I->operands()) { - if (!CanEvaluateShuffled(Operand, Mask, Depth-1)) + if (!canEvaluateShuffled(Operand, Mask, Depth - 1)) return false; } return true; @@ -925,7 +1028,7 @@ static bool CanEvaluateShuffled(Value *V, ArrayRef<int> Mask, SeenOnce = true; } } - return CanEvaluateShuffled(I->getOperand(0), Mask, Depth-1); + return canEvaluateShuffled(I->getOperand(0), Mask, Depth - 1); } } return false; @@ -1009,12 +1112,12 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { llvm_unreachable("failed to rebuild vector instructions"); } -Value * -InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { +static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // Mask.size() does not need to be equal to the number of vector elements. assert(V->getType()->isVectorTy() && "can't reorder non-vector elements"); Type *EltTy = V->getType()->getScalarType(); + Type *I32Ty = IntegerType::getInt32Ty(V->getContext()); if (isa<UndefValue>(V)) return UndefValue::get(VectorType::get(EltTy, Mask.size())); @@ -1025,9 +1128,9 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { SmallVector<Constant *, 16> MaskValues; for (int i = 0, e = Mask.size(); i != e; ++i) { if (Mask[i] == -1) - MaskValues.push_back(UndefValue::get(Builder.getInt32Ty())); + MaskValues.push_back(UndefValue::get(I32Ty)); else - MaskValues.push_back(Builder.getInt32(Mask[i])); + MaskValues.push_back(ConstantInt::get(I32Ty, Mask[i])); } return ConstantExpr::getShuffleVector(C, UndefValue::get(C->getType()), ConstantVector::get(MaskValues)); @@ -1069,7 +1172,7 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { SmallVector<Value*, 8> NewOps; bool NeedsRebuild = (Mask.size() != I->getType()->getVectorNumElements()); for (int i = 0, e = I->getNumOperands(); i != e; ++i) { - Value *V = EvaluateInDifferentElementOrder(I->getOperand(i), Mask); + Value *V = evaluateInDifferentElementOrder(I->getOperand(i), Mask); NewOps.push_back(V); NeedsRebuild |= (V != I->getOperand(i)); } @@ -1096,11 +1199,11 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // If element is not in Mask, no need to handle the operand 1 (element to // be inserted). Just evaluate values in operand 0 according to Mask. if (!Found) - return EvaluateInDifferentElementOrder(I->getOperand(0), Mask); + return evaluateInDifferentElementOrder(I->getOperand(0), Mask); - Value *V = EvaluateInDifferentElementOrder(I->getOperand(0), Mask); + Value *V = evaluateInDifferentElementOrder(I->getOperand(0), Mask); return InsertElementInst::Create(V, I->getOperand(1), - Builder.getInt32(Index), "", I); + ConstantInt::get(I32Ty, Index), "", I); } } llvm_unreachable("failed to reorder elements of vector instruction!"); @@ -1350,12 +1453,144 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, return NewBO; } +/// Match a shuffle-select-shuffle pattern where the shuffles are widening and +/// narrowing (concatenating with undef and extracting back to the original +/// length). This allows replacing the wide select with a narrow select. +static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, + InstCombiner::BuilderTy &Builder) { + // This must be a narrowing identity shuffle. It extracts the 1st N elements + // of the 1st vector operand of a shuffle. 
+ if (!match(Shuf.getOperand(1), m_Undef()) || !Shuf.isIdentityWithExtract()) + return nullptr; + + // The vector being shuffled must be a vector select that we can eliminate. + // TODO: The one-use requirement could be eased if X and/or Y are constants. + Value *Cond, *X, *Y; + if (!match(Shuf.getOperand(0), + m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) + return nullptr; + + // We need a narrow condition value. It must be extended with undef elements + // and have the same number of elements as this shuffle. + unsigned NarrowNumElts = Shuf.getType()->getVectorNumElements(); + Value *NarrowCond; + if (!match(Cond, m_OneUse(m_ShuffleVector(m_Value(NarrowCond), m_Undef(), + m_Constant()))) || + NarrowCond->getType()->getVectorNumElements() != NarrowNumElts || + !cast<ShuffleVectorInst>(Cond)->isIdentityWithPadding()) + return nullptr; + + // shuf (sel (shuf NarrowCond, undef, WideMask), X, Y), undef, NarrowMask) --> + // sel NarrowCond, (shuf X, undef, NarrowMask), (shuf Y, undef, NarrowMask) + Value *Undef = UndefValue::get(X->getType()); + Value *NarrowX = Builder.CreateShuffleVector(X, Undef, Shuf.getMask()); + Value *NarrowY = Builder.CreateShuffleVector(Y, Undef, Shuf.getMask()); + return SelectInst::Create(NarrowCond, NarrowX, NarrowY); +} + +/// Try to combine 2 shuffles into 1 shuffle by concatenating a shuffle mask. +static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + if (!Shuf.isIdentityWithExtract() || !isa<UndefValue>(Op1)) + return nullptr; + + Value *X, *Y; + Constant *Mask; + if (!match(Op0, m_ShuffleVector(m_Value(X), m_Value(Y), m_Constant(Mask)))) + return nullptr; + + // We are extracting a subvector from a shuffle. Remove excess elements from + // the 1st shuffle mask to eliminate the extract. + // + // This transform is conservatively limited to identity extracts because we do + // not allow arbitrary shuffle mask creation as a target-independent transform + // (because we can't guarantee that will lower efficiently). + // + // If the extracting shuffle has an undef mask element, it transfers to the + // new shuffle mask. Otherwise, copy the original mask element. Example: + // shuf (shuf X, Y, <C0, C1, C2, undef, C4>), undef, <0, undef, 2, 3> --> + // shuf X, Y, <C0, undef, C2, undef> + unsigned NumElts = Shuf.getType()->getVectorNumElements(); + SmallVector<Constant *, 16> NewMask(NumElts); + assert(NumElts < Mask->getType()->getVectorNumElements() && + "Identity with extract must have less elements than its inputs"); + + for (unsigned i = 0; i != NumElts; ++i) { + Constant *ExtractMaskElt = Shuf.getMask()->getAggregateElement(i); + Constant *MaskElt = Mask->getAggregateElement(i); + NewMask[i] = isa<UndefValue>(ExtractMaskElt) ? ExtractMaskElt : MaskElt; + } + return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask)); +} + +/// Try to replace a shuffle with an insertelement. +static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { + Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1); + SmallVector<int, 16> Mask = Shuf.getShuffleMask(); + + // The shuffle must not change vector sizes. + // TODO: This restriction could be removed if the insert has only one use + // (because the transform would require a new length-changing shuffle). 
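A reduced example (illustrative, not from the patch) of the shuffle-select-shuffle pattern that narrowVectorSelect() above rewrites:

define <2 x i32> @narrow_sel(<2 x i1> %cond, <4 x i32> %x, <4 x i32> %y) {
  %widecond = shufflevector <2 x i1> %cond, <2 x i1> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
  %sel = select <4 x i1> %widecond, <4 x i32> %x, <4 x i32> %y
  %r = shufflevector <4 x i32> %sel, <4 x i32> undef, <2 x i32> <i32 0, i32 1>
  ret <2 x i32> %r
}
; expected (roughly):
;   %nx = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> <i32 0, i32 1>
;   %ny = shufflevector <4 x i32> %y, <4 x i32> undef, <2 x i32> <i32 0, i32 1>
;   %r  = select <2 x i1> %cond, <2 x i32> %nx, <2 x i32> %ny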
+ int NumElts = Mask.size(); + if (NumElts != (int)(V0->getType()->getVectorNumElements())) + return nullptr; + + // shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC' + auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) { + // We need an insertelement with a constant index. + if (!match(V0, m_InsertElement(m_Value(), m_Value(Scalar), + m_ConstantInt(IndexC)))) + return false; + + // Test the shuffle mask to see if it splices the inserted scalar into the + // operand 1 vector of the shuffle. + int NewInsIndex = -1; + for (int i = 0; i != NumElts; ++i) { + // Ignore undef mask elements. + if (Mask[i] == -1) + continue; + + // The shuffle takes elements of operand 1 without lane changes. + if (Mask[i] == NumElts + i) + continue; + + // The shuffle must choose the inserted scalar exactly once. + if (NewInsIndex != -1 || Mask[i] != IndexC->getSExtValue()) + return false; + + // The shuffle is placing the inserted scalar into element i. + NewInsIndex = i; + } + + assert(NewInsIndex != -1 && "Did not fold shuffle with unused operand?"); + + // Index is updated to the potentially translated insertion lane. + IndexC = ConstantInt::get(IndexC->getType(), NewInsIndex); + return true; + }; + + // If the shuffle is unnecessary, insert the scalar operand directly into + // operand 1 of the shuffle. Example: + // shuffle (insert ?, S, 1), V1, <1, 5, 6, 7> --> insert V1, S, 0 + Value *Scalar; + ConstantInt *IndexC; + if (isShufflingScalarIntoOp1(Scalar, IndexC)) + return InsertElementInst::Create(V1, Scalar, IndexC); + + // Try again after commuting shuffle. Example: + // shuffle V0, (insert ?, S, 0), <0, 1, 2, 4> --> + // shuffle (insert ?, S, 0), V0, <4, 5, 6, 0> --> insert V0, S, 3 + std::swap(V0, V1); + ShuffleVectorInst::commuteShuffleMask(Mask, NumElts); + if (isShufflingScalarIntoOp1(Scalar, IndexC)) + return InsertElementInst::Create(V1, Scalar, IndexC); + + return nullptr; +} + Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); - SmallVector<int, 16> Mask = SVI.getShuffleMask(); - Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); - if (auto *V = SimplifyShuffleVectorInst( LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); @@ -1363,9 +1598,10 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = foldSelectShuffle(SVI, Builder, DL)) return I; - bool MadeChange = false; - unsigned VWidth = SVI.getType()->getVectorNumElements(); + if (Instruction *I = narrowVectorSelect(SVI, Builder)) + return I; + unsigned VWidth = SVI.getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { @@ -1374,18 +1610,22 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return &SVI; } + if (Instruction *I = foldIdentityExtractShuffle(SVI)) + return I; + + // This transform has the potential to lose undef knowledge, so it is + // intentionally placed after SimplifyDemandedVectorElts(). 
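The in-code example for foldShuffleWithInsert() above corresponds to IR like this reduced, illustrative test:

define <4 x i32> @shuffle_of_insert(<4 x i32> %v1, i32 %s) {
  %ins = insertelement <4 x i32> undef, i32 %s, i32 1
  %r = shufflevector <4 x i32> %ins, <4 x i32> %v1, <4 x i32> <i32 1, i32 5, i32 6, i32 7>
  ret <4 x i32> %r
}
; expected: %r = insertelement <4 x i32> %v1, i32 %s, i32 0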
+ if (Instruction *I = foldShuffleWithInsert(SVI)) + return I; + + SmallVector<int, 16> Mask = SVI.getShuffleMask(); + Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); unsigned LHSWidth = LHS->getType()->getVectorNumElements(); + bool MadeChange = false; // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). if (LHS == RHS || isa<UndefValue>(LHS)) { - if (isa<UndefValue>(LHS) && LHS == RHS) { - // shuffle(undef,undef,mask) -> undef. - Value *Result = (VWidth == LHSWidth) - ? LHS : UndefValue::get(SVI.getType()); - return replaceInstUsesWith(SVI, Result); - } - // Remap any references to RHS to use LHS. SmallVector<Constant*, 16> Elts; for (unsigned i = 0, e = LHSWidth; i != VWidth; ++i) { @@ -1421,8 +1661,8 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (isRHSID) return replaceInstUsesWith(SVI, RHS); } - if (isa<UndefValue>(RHS) && CanEvaluateShuffled(LHS, Mask)) { - Value *V = EvaluateInDifferentElementOrder(LHS, Mask); + if (isa<UndefValue>(RHS) && canEvaluateShuffled(LHS, Mask)) { + Value *V = evaluateInDifferentElementOrder(LHS, Mask); return replaceInstUsesWith(SVI, V); } diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index cff0d5447290..be7d43bbcf2c 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -57,7 +57,6 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -97,6 +96,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> #include <cstdint> @@ -120,6 +120,10 @@ DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); static cl::opt<bool> +EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), + cl::init(true)); + +static cl::opt<bool> EnableExpensiveCombines("expensive-combines", cl::desc("Enable expensive instruction combines")); @@ -179,7 +183,10 @@ bool InstCombiner::shouldChangeType(unsigned FromWidth, /// a fundamental type in IR, and there are many specialized optimizations for /// i1 types. bool InstCombiner::shouldChangeType(Type *From, Type *To) const { - assert(From->isIntegerTy() && To->isIntegerTy()); + // TODO: This could be extended to allow vectors. Datalayout changes might be + // needed to properly support that. + if (!From->isIntegerTy() || !To->isIntegerTy()) + return false; unsigned FromWidth = From->getPrimitiveSizeInBits(); unsigned ToWidth = To->getPrimitiveSizeInBits(); @@ -747,8 +754,9 @@ Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, /// Given a 'sub' instruction, return the RHS of the instruction if the LHS is a /// constant zero (which is the 'negate' form). Value *InstCombiner::dyn_castNegVal(Value *V) const { - if (BinaryOperator::isNeg(V)) - return BinaryOperator::getNegArgument(V); + Value *NegV; + if (match(V, m_Neg(m_Value(NegV)))) + return NegV; // Constants can be considered to be negated values if they can be folded. 
if (ConstantInt *C = dyn_cast<ConstantInt>(V)) @@ -1351,22 +1359,46 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } while (true); } -Instruction *InstCombiner::foldShuffledBinop(BinaryOperator &Inst) { +Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { if (!Inst.getType()->isVectorTy()) return nullptr; + BinaryOperator::BinaryOps Opcode = Inst.getOpcode(); + unsigned NumElts = cast<VectorType>(Inst.getType())->getNumElements(); + Value *LHS = Inst.getOperand(0), *RHS = Inst.getOperand(1); + assert(cast<VectorType>(LHS->getType())->getNumElements() == NumElts); + assert(cast<VectorType>(RHS->getType())->getNumElements() == NumElts); + + // If both operands of the binop are vector concatenations, then perform the + // narrow binop on each pair of the source operands followed by concatenation + // of the results. + Value *L0, *L1, *R0, *R1; + Constant *Mask; + if (match(LHS, m_ShuffleVector(m_Value(L0), m_Value(L1), m_Constant(Mask))) && + match(RHS, m_ShuffleVector(m_Value(R0), m_Value(R1), m_Specific(Mask))) && + LHS->hasOneUse() && RHS->hasOneUse() && + cast<ShuffleVectorInst>(LHS)->isConcat()) { + // This transform does not have the speculative execution constraint as + // below because the shuffle is a concatenation. The new binops are + // operating on exactly the same elements as the existing binop. + // TODO: We could ease the mask requirement to allow different undef lanes, + // but that requires an analysis of the binop-with-undef output value. + Value *NewBO0 = Builder.CreateBinOp(Opcode, L0, R0); + if (auto *BO = dyn_cast<BinaryOperator>(NewBO0)) + BO->copyIRFlags(&Inst); + Value *NewBO1 = Builder.CreateBinOp(Opcode, L1, R1); + if (auto *BO = dyn_cast<BinaryOperator>(NewBO1)) + BO->copyIRFlags(&Inst); + return new ShuffleVectorInst(NewBO0, NewBO1, Mask); + } + // It may not be safe to reorder shuffles and things like div, urem, etc. // because we may trap when executing those ops on unknown vector elements. // See PR20059. if (!isSafeToSpeculativelyExecute(&Inst)) return nullptr; - unsigned VWidth = cast<VectorType>(Inst.getType())->getNumElements(); - Value *LHS = Inst.getOperand(0), *RHS = Inst.getOperand(1); - assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth); - assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth); - auto createBinOpShuffle = [&](Value *X, Value *Y, Constant *M) { - Value *XY = Builder.CreateBinOp(Inst.getOpcode(), X, Y); + Value *XY = Builder.CreateBinOp(Opcode, X, Y); if (auto *BO = dyn_cast<BinaryOperator>(XY)) BO->copyIRFlags(&Inst); return new ShuffleVectorInst(XY, UndefValue::get(XY->getType()), M); @@ -1375,7 +1407,6 @@ Instruction *InstCombiner::foldShuffledBinop(BinaryOperator &Inst) { // If both arguments of the binary operation are shuffles that use the same // mask and shuffle within a single vector, move the shuffle after the binop. 
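The new concat case at the top of foldVectorBinop() (both operands are concatenations with the same mask) corresponds to IR like the following sketch; the names are invented:

define <4 x i32> @binop_of_concats(<2 x i32> %a, <2 x i32> %b, <2 x i32> %c, <2 x i32> %d) {
  %lhs = shufflevector <2 x i32> %a, <2 x i32> %b, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  %rhs = shufflevector <2 x i32> %c, <2 x i32> %d, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  %r = add <4 x i32> %lhs, %rhs
  ret <4 x i32> %r
}
; expected (roughly):
;   %lo = add <2 x i32> %a, %c
;   %hi = add <2 x i32> %b, %d
;   %r  = shufflevector <2 x i32> %lo, <2 x i32> %hi, <4 x i32> <i32 0, i32 1, i32 2, i32 3>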
Value *V1, *V2; - Constant *Mask; if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))) && match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(Mask))) && V1->getType() == V2->getType() && @@ -1393,42 +1424,69 @@ Instruction *InstCombiner::foldShuffledBinop(BinaryOperator &Inst) { if (match(&Inst, m_c_BinOp( m_OneUse(m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))), m_Constant(C))) && - V1->getType() == Inst.getType()) { + V1->getType()->getVectorNumElements() <= NumElts) { + assert(Inst.getType()->getScalarType() == V1->getType()->getScalarType() && + "Shuffle should not change scalar type"); + // Find constant NewC that has property: // shuffle(NewC, ShMask) = C // If such constant does not exist (example: ShMask=<0,0> and C=<1,2>) // reorder is not possible. A 1-to-1 mapping is not required. Example: // ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <undef,5,6,undef> + bool ConstOp1 = isa<Constant>(RHS); SmallVector<int, 16> ShMask; ShuffleVectorInst::getShuffleMask(Mask, ShMask); - SmallVector<Constant *, 16> - NewVecC(VWidth, UndefValue::get(C->getType()->getScalarType())); + unsigned SrcVecNumElts = V1->getType()->getVectorNumElements(); + UndefValue *UndefScalar = UndefValue::get(C->getType()->getScalarType()); + SmallVector<Constant *, 16> NewVecC(SrcVecNumElts, UndefScalar); bool MayChange = true; - for (unsigned I = 0; I < VWidth; ++I) { + for (unsigned I = 0; I < NumElts; ++I) { + Constant *CElt = C->getAggregateElement(I); if (ShMask[I] >= 0) { - assert(ShMask[I] < (int)VWidth); - Constant *CElt = C->getAggregateElement(I); + assert(ShMask[I] < (int)NumElts && "Not expecting narrowing shuffle"); Constant *NewCElt = NewVecC[ShMask[I]]; - if (!CElt || (!isa<UndefValue>(NewCElt) && NewCElt != CElt)) { + // Bail out if: + // 1. The constant vector contains a constant expression. + // 2. The shuffle needs an element of the constant vector that can't + // be mapped to a new constant vector. + // 3. This is a widening shuffle that copies elements of V1 into the + // extended elements (extending with undef is allowed). + if (!CElt || (!isa<UndefValue>(NewCElt) && NewCElt != CElt) || + I >= SrcVecNumElts) { MayChange = false; break; } NewVecC[ShMask[I]] = CElt; } + // If this is a widening shuffle, we must be able to extend with undef + // elements. If the original binop does not produce an undef in the high + // lanes, then this transform is not safe. + // TODO: We could shuffle those non-undef constant values into the + // result by using a constant vector (rather than an undef vector) + // as operand 1 of the new binop, but that might be too aggressive + // for target-independent shuffle creation. + if (I >= SrcVecNumElts) { + Constant *MaybeUndef = + ConstOp1 ? ConstantExpr::get(Opcode, UndefScalar, CElt) + : ConstantExpr::get(Opcode, CElt, UndefScalar); + if (!isa<UndefValue>(MaybeUndef)) { + MayChange = false; + break; + } + } } if (MayChange) { Constant *NewC = ConstantVector::get(NewVecC); // It may not be safe to execute a binop on a vector with undef elements // because the entire instruction can be folded to undef or create poison // that did not exist in the original code. 
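The constant-reordering case being extended here, Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask), can be illustrated with a splat. This is a sketch, not a test from the patch:

define <4 x i32> @add_splat_const(<4 x i32> %x) {
  %splat = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> <i32 2, i32 2, i32 2, i32 2>
  %r = add <4 x i32> %splat, <i32 42, i32 42, i32 42, i32 42>
  ret <4 x i32> %r
}
; expected (roughly): NewC = <undef, undef, 42, undef>, so
;   %a = add <4 x i32> %x, <i32 undef, i32 undef, i32 42, i32 undef>
;   %r = shufflevector <4 x i32> %a, <4 x i32> undef, <4 x i32> <i32 2, i32 2, i32 2, i32 2>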
- bool ConstOp1 = isa<Constant>(Inst.getOperand(1)); if (Inst.isIntDivRem() || (Inst.isShift() && ConstOp1)) - NewC = getSafeVectorConstantForBinop(Inst.getOpcode(), NewC, ConstOp1); + NewC = getSafeVectorConstantForBinop(Opcode, NewC, ConstOp1); // Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask) // Op(C, shuffle(V1, Mask)) -> shuffle(Op(NewC, V1), Mask) - Value *NewLHS = isa<Constant>(LHS) ? NewC : V1; - Value *NewRHS = isa<Constant>(LHS) ? V1 : NewC; + Value *NewLHS = ConstOp1 ? V1 : NewC; + Value *NewRHS = ConstOp1 ? NewC : V1; return createBinOpShuffle(NewLHS, NewRHS, Mask); } } @@ -1436,6 +1494,62 @@ Instruction *InstCombiner::foldShuffledBinop(BinaryOperator &Inst) { return nullptr; } +/// Try to narrow the width of a binop if at least 1 operand is an extend of +/// of a value. This requires a potentially expensive known bits check to make +/// sure the narrow op does not overflow. +Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) { + // We need at least one extended operand. + Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1); + + // If this is a sub, we swap the operands since we always want an extension + // on the RHS. The LHS can be an extension or a constant. + if (BO.getOpcode() == Instruction::Sub) + std::swap(Op0, Op1); + + Value *X; + bool IsSext = match(Op0, m_SExt(m_Value(X))); + if (!IsSext && !match(Op0, m_ZExt(m_Value(X)))) + return nullptr; + + // If both operands are the same extension from the same source type and we + // can eliminate at least one (hasOneUse), this might work. + CastInst::CastOps CastOpc = IsSext ? Instruction::SExt : Instruction::ZExt; + Value *Y; + if (!(match(Op1, m_ZExtOrSExt(m_Value(Y))) && X->getType() == Y->getType() && + cast<Operator>(Op1)->getOpcode() == CastOpc && + (Op0->hasOneUse() || Op1->hasOneUse()))) { + // If that did not match, see if we have a suitable constant operand. + // Truncating and extending must produce the same constant. + Constant *WideC; + if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC))) + return nullptr; + Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType()); + if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC) + return nullptr; + Y = NarrowC; + } + + // Swap back now that we found our operands. + if (BO.getOpcode() == Instruction::Sub) + std::swap(X, Y); + + // Both operands have narrow versions. Last step: the math must not overflow + // in the narrow width. 
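A sketch of an input that narrowMathIfNoOverflow() can shrink; the masks make the no-overflow proof trivial for the known-bits check (illustrative only):

define i32 @narrow_add(i16 %x, i16 %y) {
  %mx = and i16 %x, 127
  %my = and i16 %y, 127
  %zx = zext i16 %mx to i32
  %zy = zext i16 %my to i32
  %r = add i32 %zx, %zy
  ret i32 %r
}
; both inputs are at most 127, so the sum fits in i16 and the add can be
; performed narrow and extended afterward:
;   %narrow = add nuw i16 %mx, %my
;   %r = zext i16 %narrow to i32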
+ if (!willNotOverflow(BO.getOpcode(), X, Y, BO, IsSext)) + return nullptr; + + // bo (ext X), (ext Y) --> ext (bo X, Y) + // bo (ext X), C --> ext (bo X, C') + Value *NarrowBO = Builder.CreateBinOp(BO.getOpcode(), X, Y, "narrow"); + if (auto *NewBinOp = dyn_cast<BinaryOperator>(NarrowBO)) { + if (IsSext) + NewBinOp->setHasNoSignedWrap(); + else + NewBinOp->setHasNoUnsignedWrap(); + } + return CastInst::Create(CastOpc, NarrowBO, BO.getType()); +} + Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); Type *GEPType = GEP.getType(); @@ -1963,9 +2077,22 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { areMatchingArrayAndVecTypes(GEPEltType, SrcEltType)) || (GEPEltType->isVectorTy() && SrcEltType->isArrayTy() && areMatchingArrayAndVecTypes(SrcEltType, GEPEltType)))) { - GEP.setOperand(0, SrcOp); - GEP.setSourceElementType(SrcEltType); - return &GEP; + + // Create a new GEP here, as using `setOperand()` followed by + // `setSourceElementType()` won't actually update the type of the + // existing GEP Value. Causing issues if this Value is accessed when + // constructing an AddrSpaceCastInst + Value *NGEP = + GEP.isInBounds() + ? Builder.CreateInBoundsGEP(nullptr, SrcOp, {Ops[1], Ops[2]}) + : Builder.CreateGEP(nullptr, SrcOp, {Ops[1], Ops[2]}); + NGEP->takeName(&GEP); + + // Preserve GEP address space to satisfy users + if (NGEP->getType()->getPointerAddressSpace() != GEP.getAddressSpace()) + return new AddrSpaceCastInst(NGEP, GEPType); + + return replaceInstUsesWith(GEP, NGEP); } // See if we can simplify: @@ -2137,14 +2264,21 @@ static bool isAllocSiteRemovable(Instruction *AI, } Instruction *InstCombiner::visitAllocSite(Instruction &MI) { - // If we have a malloc call which is only used in any amount of comparisons - // to null and free calls, delete the calls and replace the comparisons with - // true or false as appropriate. + // If we have a malloc call which is only used in any amount of comparisons to + // null and free calls, delete the calls and replace the comparisons with true + // or false as appropriate. + + // This is based on the principle that we can substitute our own allocation + // function (which will never return null) rather than knowledge of the + // specific function being called. In some sense this can change the permitted + // outputs of a program (when we convert a malloc to an alloca, the fact that + // the allocation is now on the stack is potentially visible, for example), + // but we believe in a permissible manner. SmallVector<WeakTrackingVH, 64> Users; // If we are removing an alloca with a dbg.declare, insert dbg.value calls // before each store. - TinyPtrVector<DbgInfoIntrinsic *> DIIs; + TinyPtrVector<DbgVariableIntrinsic *> DIIs; std::unique_ptr<DIBuilder> DIB; if (isa<AllocaInst>(MI)) { DIIs = FindDbgAddrUses(&MI); @@ -2215,14 +2349,14 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { /// The move is performed only if the block containing the call to free /// will be removed, i.e.: /// 1. it has only one predecessor P, and P has two successors -/// 2. it contains the call and an unconditional branch +/// 2. it contains the call, noops, and an unconditional branch /// 3. its successor is the same as its predecessor's successor /// /// The profitability is out-of concern here and this function should /// be called only if the caller knows this transformation would be /// profitable (e.g., for code size). 
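A reduced example of the pattern tryToMoveFreeBeforeNullTest() handles (illustrative; instcombine only attempts this when optimizing for size):

define void @free_if_not_null(i8* %p) minsize {
entry:
  %isnull = icmp eq i8* %p, null
  br i1 %isnull, label %done, label %dofree
dofree:
  call void @free(i8* %p)
  br label %done
done:
  ret void
}
declare void @free(i8*)
; free(null) is a no-op, so the call can be hoisted in front of the null test;
; the then-empty block is left for later CFG cleanup. The changes below also
; allow no-op casts in the free block and match the test through pointer casts.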
-static Instruction * -tryToMoveFreeBeforeNullTest(CallInst &FI) { +static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI, + const DataLayout &DL) { Value *Op = FI.getArgOperand(0); BasicBlock *FreeInstrBB = FI.getParent(); BasicBlock *PredBB = FreeInstrBB->getSinglePredecessor(); @@ -2235,20 +2369,34 @@ tryToMoveFreeBeforeNullTest(CallInst &FI) { return nullptr; // Validate constraint #2: Does this block contains only the call to - // free and an unconditional branch? - // FIXME: We could check if we can speculate everything in the - // predecessor block - if (FreeInstrBB->size() != 2) - return nullptr; + // free, noops, and an unconditional branch? BasicBlock *SuccBB; - if (!match(FreeInstrBB->getTerminator(), m_UnconditionalBr(SuccBB))) + Instruction *FreeInstrBBTerminator = FreeInstrBB->getTerminator(); + if (!match(FreeInstrBBTerminator, m_UnconditionalBr(SuccBB))) return nullptr; + // If there are only 2 instructions in the block, at this point, + // this is the call to free and unconditional. + // If there are more than 2 instructions, check that they are noops + // i.e., they won't hurt the performance of the generated code. + if (FreeInstrBB->size() != 2) { + for (const Instruction &Inst : *FreeInstrBB) { + if (&Inst == &FI || &Inst == FreeInstrBBTerminator) + continue; + auto *Cast = dyn_cast<CastInst>(&Inst); + if (!Cast || !Cast->isNoopCast(DL)) + return nullptr; + } + } // Validate the rest of constraint #1 by matching on the pred branch. - TerminatorInst *TI = PredBB->getTerminator(); + Instruction *TI = PredBB->getTerminator(); BasicBlock *TrueBB, *FalseBB; ICmpInst::Predicate Pred; - if (!match(TI, m_Br(m_ICmp(Pred, m_Specific(Op), m_Zero()), TrueBB, FalseBB))) + if (!match(TI, m_Br(m_ICmp(Pred, + m_CombineOr(m_Specific(Op), + m_Specific(Op->stripPointerCasts())), + m_Zero()), + TrueBB, FalseBB))) return nullptr; if (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE) return nullptr; @@ -2259,7 +2407,17 @@ tryToMoveFreeBeforeNullTest(CallInst &FI) { assert(FreeInstrBB == (Pred == ICmpInst::ICMP_EQ ? FalseBB : TrueBB) && "Broken CFG: missing edge from predecessor to successor"); - FI.moveBefore(TI); + // At this point, we know that everything in FreeInstrBB can be moved + // before TI. + for (BasicBlock::iterator It = FreeInstrBB->begin(), End = FreeInstrBB->end(); + It != End;) { + Instruction &Instr = *It++; + if (&Instr == FreeInstrBBTerminator) + break; + Instr.moveBefore(TI); + } + assert(FreeInstrBB->size() == 1 && + "Only the branch instruction should remain"); return &FI; } @@ -2286,7 +2444,7 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { // into // free(foo); if (MinimizeSize) - if (Instruction *I = tryToMoveFreeBeforeNullTest(FI)) + if (Instruction *I = tryToMoveFreeBeforeNullTest(FI, DL)) return I; return nullptr; @@ -2379,9 +2537,11 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { unsigned NewWidth = Known.getBitWidth() - std::max(LeadingKnownZeros, LeadingKnownOnes); // Shrink the condition operand if the new type is smaller than the old type. - // This may produce a non-standard type for the switch, but that's ok because - // the backend should extend back to a legal type for the target. - if (NewWidth > 0 && NewWidth < Known.getBitWidth()) { + // But do not shrink to a non-standard type, because backend can't generate + // good code for that yet. + // TODO: We can make it aggressive again after fixing PR39569. 
+ if (NewWidth > 0 && NewWidth < Known.getBitWidth() && + shouldChangeType(Known.getBitWidth(), NewWidth)) { IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); Builder.SetInsertPoint(&SI); Value *NewCond = Builder.CreateTrunc(Cond, Ty, "trunc"); @@ -2902,7 +3062,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { // Cannot move control-flow-involving, volatile loads, vaarg, etc. if (isa<PHINode>(I) || I->isEHPad() || I->mayHaveSideEffects() || - isa<TerminatorInst>(I)) + I->isTerminator()) return false; // Do not sink alloca instructions out of the entry block. @@ -2934,7 +3094,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { // Also sink all related debug uses from the source basic block. Otherwise we // get debug use before the def. - SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; findDbgUsers(DbgUsers, I); for (auto *DII : DbgUsers) { if (DII->getParent() == SrcBlock) { @@ -3000,7 +3160,7 @@ bool InstCombiner::run() { } // See if we can trivially sink this instruction to a successor basic block. - if (I->hasOneUse()) { + if (EnableCodeSinking && I->hasOneUse()) { BasicBlock *BB = I->getParent(); Instruction *UserInst = cast<Instruction>(*I->user_begin()); BasicBlock *UserParent; @@ -3183,7 +3343,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, // Recursively visit successors. If this is a branch or switch on a // constant, only visit the reachable successor. - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { if (BI->isConditional() && isa<ConstantInt>(BI->getCondition())) { bool CondVal = cast<ConstantInt>(BI->getCondition())->getZExtValue(); @@ -3198,7 +3358,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, } } - for (BasicBlock *SuccBB : TI->successors()) + for (BasicBlock *SuccBB : successors(TI)) Worklist.push_back(SuccBB); } while (!Worklist.empty()); diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 6af44354225c..f1558c75cb90 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -109,6 +109,7 @@ static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kNetBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kNetBSD_ShadowOffset64 = 1ULL << 46; +static const uint64_t kNetBSDKasan_ShadowOffset64 = 0xdfff900000000000; static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; @@ -344,10 +345,14 @@ static cl::opt<uint32_t> ClForceExperiment( cl::init(0)); static cl::opt<bool> - ClUsePrivateAliasForGlobals("asan-use-private-alias", - cl::desc("Use private aliases for global" - " variables"), - cl::Hidden, cl::init(false)); + ClUsePrivateAlias("asan-use-private-alias", + cl::desc("Use private aliases for global variables"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> + ClUseOdrIndicator("asan-use-odr-indicator", + cl::desc("Use odr indicators to improve ODR reporting"), + cl::Hidden, cl::init(false)); static cl::opt<bool> ClUseGlobalsGC("asan-globals-live-support", @@ -436,8 +441,11 @@ public: for (auto MDN : Globals->operands()) { // Metadata node contains the global and the fields of "Entry". 
assert(MDN->getNumOperands() == 5); - auto *GV = mdconst::extract_or_null<GlobalVariable>(MDN->getOperand(0)); + auto *V = mdconst::extract_or_null<Constant>(MDN->getOperand(0)); // The optimizer may optimize away a global entirely. + if (!V) continue; + auto *StrippedV = V->stripPointerCasts(); + auto *GV = dyn_cast<GlobalVariable>(StrippedV); if (!GV) continue; // We can already have an entry for GV if it was merged with another // global. @@ -538,11 +546,14 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.Offset = kPPC64_ShadowOffset64; else if (IsSystemZ) Mapping.Offset = kSystemZ_ShadowOffset64; - else if (IsFreeBSD) + else if (IsFreeBSD && !IsMIPS64) Mapping.Offset = kFreeBSD_ShadowOffset64; - else if (IsNetBSD) - Mapping.Offset = kNetBSD_ShadowOffset64; - else if (IsPS4CPU) + else if (IsNetBSD) { + if (IsKasan) + Mapping.Offset = kNetBSDKasan_ShadowOffset64; + else + Mapping.Offset = kNetBSD_ShadowOffset64; + } else if (IsPS4CPU) Mapping.Offset = kPS4CPU_ShadowOffset64; else if (IsLinux && IsX86_64) { if (IsKasan) @@ -731,9 +742,12 @@ public: explicit AddressSanitizerModule(bool CompileKernel = false, bool Recover = false, - bool UseGlobalsGC = true) - : ModulePass(ID), - UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), + bool UseGlobalsGC = true, + bool UseOdrIndicator = false) + : ModulePass(ID), UseGlobalsGC(UseGlobalsGC && ClUseGlobalsGC), + // Enable aliases as they should have no downside with ODR indicators. + UsePrivateAlias(UseOdrIndicator || ClUsePrivateAlias), + UseOdrIndicator(UseOdrIndicator || ClUseOdrIndicator), // Not a typo: ClWithComdat is almost completely pointless without // ClUseGlobalsGC (because then it only works on modules without // globals, which are rare); it is a prerequisite for ClUseGlobalsGC; @@ -742,11 +756,10 @@ public: // ClWithComdat and ClUseGlobalsGC unless the frontend says it's ok to // do globals-gc. UseCtorComdat(UseGlobalsGC && ClWithComdat) { - this->Recover = ClRecover.getNumOccurrences() > 0 ? - ClRecover : Recover; - this->CompileKernel = ClEnableKasan.getNumOccurrences() > 0 ? - ClEnableKasan : CompileKernel; - } + this->Recover = ClRecover.getNumOccurrences() > 0 ? ClRecover : Recover; + this->CompileKernel = + ClEnableKasan.getNumOccurrences() > 0 ? ClEnableKasan : CompileKernel; + } bool runOnModule(Module &M) override; StringRef getPassName() const override { return "AddressSanitizerModule"; } @@ -790,6 +803,8 @@ private: bool CompileKernel; bool Recover; bool UseGlobalsGC; + bool UsePrivateAlias; + bool UseOdrIndicator; bool UseCtorComdat; Type *IntptrTy; LLVMContext *C; @@ -990,7 +1005,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { if (ID == Intrinsic::localescape) LocalEscapeCall = &II; if (!ASan.UseAfterScope) return; - if (ID != Intrinsic::lifetime_start && ID != Intrinsic::lifetime_end) + if (!II.isLifetimeStartOrEnd()) return; // Found lifetime intrinsic, add ASan instrumentation if necessary. 
ConstantInt *Size = dyn_cast<ConstantInt>(II.getArgOperand(0)); @@ -1089,9 +1104,11 @@ INITIALIZE_PASS( ModulePass *llvm::createAddressSanitizerModulePass(bool CompileKernel, bool Recover, - bool UseGlobalsGC) { + bool UseGlobalsGC, + bool UseOdrIndicator) { assert(!CompileKernel || Recover); - return new AddressSanitizerModule(CompileKernel, Recover, UseGlobalsGC); + return new AddressSanitizerModule(CompileKernel, Recover, UseGlobalsGC, + UseOdrIndicator); } static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { @@ -1100,25 +1117,11 @@ static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { return Res; } -// Create a constant for Str so that we can pass it to the run-time lib. -static GlobalVariable *createPrivateGlobalForString(Module &M, StringRef Str, - bool AllowMerging) { - Constant *StrConst = ConstantDataArray::getString(M.getContext(), Str); - // We use private linkage for module-local strings. If they can be merged - // with another one, we set the unnamed_addr attribute. - GlobalVariable *GV = - new GlobalVariable(M, StrConst->getType(), true, - GlobalValue::PrivateLinkage, StrConst, kAsanGenPrefix); - if (AllowMerging) GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - GV->setAlignment(1); // Strings may not be merged w/o setting align 1. - return GV; -} - /// Create a global describing a source location. static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, LocationMetadata MD) { Constant *LocData[] = { - createPrivateGlobalForString(M, MD.Filename, true), + createPrivateGlobalForString(M, MD.Filename, true, kAsanGenPrefix), ConstantInt::get(Type::getInt32Ty(M.getContext()), MD.LineNo), ConstantInt::get(Type::getInt32Ty(M.getContext()), MD.ColumnNo), }; @@ -1132,6 +1135,10 @@ static GlobalVariable *createPrivateGlobalForSourceLoc(Module &M, /// Check if \p G has been created by a trusted compiler pass. static bool GlobalWasGeneratedByCompiler(GlobalVariable *G) { + // Do not instrument @llvm.global_ctors, @llvm.used, etc. + if (G->getName().startswith("llvm.")) + return true; + // Do not instrument asan globals. if (G->getName().startswith(kAsanGenPrefix) || G->getName().startswith(kSanCovGenPrefix) || @@ -1379,7 +1386,7 @@ static void instrumentMaskedLoadOrStore(AddressSanitizer *Pass, } else { IRBuilder<> IRB(I); Value *MaskElem = IRB.CreateExtractElement(Mask, Idx); - TerminatorInst *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false); + Instruction *ThenTerm = SplitBlockAndInsertIfThen(MaskElem, I, false); InsertBefore = ThenTerm; } @@ -1532,8 +1539,9 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Value *TagCheck = IRB.CreateICmpEQ(Tag, ConstantInt::get(IntptrTy, kMyriadDDRTag)); - TerminatorInst *TagCheckTerm = SplitBlockAndInsertIfThen( - TagCheck, InsertBefore, false, MDBuilder(*C).createBranchWeights(1, 100000)); + Instruction *TagCheckTerm = + SplitBlockAndInsertIfThen(TagCheck, InsertBefore, false, + MDBuilder(*C).createBranchWeights(1, 100000)); assert(cast<BranchInst>(TagCheckTerm)->isUnconditional()); IRB.SetInsertPoint(TagCheckTerm); InsertBefore = TagCheckTerm; @@ -1549,12 +1557,12 @@ void AddressSanitizer::instrumentAddress(Instruction *OrigIns, Value *Cmp = IRB.CreateICmpNE(ShadowValue, CmpVal); size_t Granularity = 1ULL << Mapping.Scale; - TerminatorInst *CrashTerm = nullptr; + Instruction *CrashTerm = nullptr; if (ClAlwaysSlowPath || (TypeSize < 8 * Granularity)) { // We use branch weights for the slow path check, to indicate that the slow // path is rarely taken. This seems to be the case for SPEC benchmarks. 
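For reference, the simplified check that instrumentAddress() emits for an 8-byte store on x86_64 Linux has roughly the following shape. The shift of 3 and the 0x7fff8000 offset are that target's defaults and are not taken from this diff; sub-8-byte accesses additionally get the slow-path compare guarded by the branch weights mentioned above.

define void @asan_store8_sketch(i64* %p) {
entry:
  %a = ptrtoint i64* %p to i64
  %s = lshr i64 %a, 3
  %sa = add i64 %s, 2147450880        ; 0x7fff8000
  %sp = inttoptr i64 %sa to i8*
  %k = load i8, i8* %sp
  %bad = icmp ne i8 %k, 0
  br i1 %bad, label %report, label %ok
report:
  call void @__asan_report_store8(i64 %a)
  unreachable
ok:
  store i64 0, i64* %p
  ret void
}
declare void @__asan_report_store8(i64)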
- TerminatorInst *CheckTerm = SplitBlockAndInsertIfThen( + Instruction *CheckTerm = SplitBlockAndInsertIfThen( Cmp, InsertBefore, false, MDBuilder(*C).createBranchWeights(1, 100000)); assert(cast<BranchInst>(CheckTerm)->isUnconditional()); BasicBlock *NextBB = CheckTerm->getSuccessor(0); @@ -1653,14 +1661,6 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { if (!Ty->isSized()) return false; if (!G->hasInitializer()) return false; if (GlobalWasGeneratedByCompiler(G)) return false; // Our own globals. - // Touch only those globals that will not be defined in other modules. - // Don't handle ODR linkage types and COMDATs since other modules may be built - // without ASan. - if (G->getLinkage() != GlobalVariable::ExternalLinkage && - G->getLinkage() != GlobalVariable::PrivateLinkage && - G->getLinkage() != GlobalVariable::InternalLinkage) - return false; - if (G->hasComdat()) return false; // Two problems with thread-locals: // - The address of the main thread's copy can't be computed at link-time. // - Need to poison all copies, not just the main thread's one. @@ -1668,6 +1668,33 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // For now, just ignore this Global if the alignment is large. if (G->getAlignment() > MinRedzoneSizeForGlobal()) return false; + // For non-COFF targets, only instrument globals known to be defined by this + // TU. + // FIXME: We can instrument comdat globals on ELF if we are using the + // GC-friendly metadata scheme. + if (!TargetTriple.isOSBinFormatCOFF()) { + if (!G->hasExactDefinition() || G->hasComdat()) + return false; + } else { + // On COFF, don't instrument non-ODR linkages. + if (G->isInterposable()) + return false; + } + + // If a comdat is present, it must have a selection kind that implies ODR + // semantics: no duplicates, any, or exact match. + if (Comdat *C = G->getComdat()) { + switch (C->getSelectionKind()) { + case Comdat::Any: + case Comdat::ExactMatch: + case Comdat::NoDuplicates: + break; + case Comdat::Largest: + case Comdat::SameSize: + return false; + } + } + if (G->hasSection()) { StringRef Section = G->getSection(); @@ -2082,7 +2109,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool // We shouldn't merge same module names, as this string serves as unique // module ID in runtime. GlobalVariable *ModuleName = createPrivateGlobalForString( - M, M.getModuleIdentifier(), /*AllowMerging*/ false); + M, M.getModuleIdentifier(), /*AllowMerging*/ false, kAsanGenPrefix); for (size_t i = 0; i < n; i++) { static const uint64_t kMaxGlobalRedzone = 1 << 18; @@ -2094,7 +2121,7 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool // if it's available, otherwise just write the name of global variable). GlobalVariable *Name = createPrivateGlobalForString( M, MD.Name.empty() ? NameForGlobal : MD.Name, - /*AllowMerging*/ true); + /*AllowMerging*/ true, kAsanGenPrefix); Type *Ty = G->getValueType(); uint64_t SizeInBytes = DL.getTypeAllocSize(Ty); @@ -2121,7 +2148,12 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool new GlobalVariable(M, NewTy, G->isConstant(), Linkage, NewInitializer, "", G, G->getThreadLocalMode()); NewGlobal->copyAttributesFrom(G); + NewGlobal->setComdat(G->getComdat()); NewGlobal->setAlignment(MinRZ); + // Don't fold globals with redzones. 
ODR violation detector and redzone + // poisoning implicitly creates a dependence on the global's address, so it + // is no longer valid for it to be marked unnamed_addr. + NewGlobal->setUnnamedAddr(GlobalValue::UnnamedAddr::None); // Move null-terminated C strings to "__asan_cstring" section on Darwin. if (TargetTriple.isOSBinFormatMachO() && !G->hasSection() && @@ -2162,12 +2194,18 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool bool CanUsePrivateAliases = TargetTriple.isOSBinFormatELF() || TargetTriple.isOSBinFormatMachO() || TargetTriple.isOSBinFormatWasm(); - if (CanUsePrivateAliases && ClUsePrivateAliasForGlobals) { + if (CanUsePrivateAliases && UsePrivateAlias) { // Create local alias for NewGlobal to avoid crash on ODR between // instrumented and non-instrumented libraries. - auto *GA = GlobalAlias::create(GlobalValue::InternalLinkage, - NameForGlobal + M.getName(), NewGlobal); + InstrumentedGlobal = + GlobalAlias::create(GlobalValue::PrivateLinkage, "", NewGlobal); + } + // ODR should not happen for local linkage. + if (NewGlobal->hasLocalLinkage()) { + ODRIndicator = ConstantExpr::getIntToPtr(ConstantInt::get(IntptrTy, -1), + IRB.getInt8PtrTy()); + } else if (UseOdrIndicator) { // With local aliases, we need to provide another externally visible // symbol __odr_asan_XXX to detect ODR violation. auto *ODRIndicatorSym = @@ -2181,7 +2219,6 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool ODRIndicatorSym->setDLLStorageClass(NewGlobal->getDLLStorageClass()); ODRIndicatorSym->setAlignment(1); ODRIndicator = ODRIndicatorSym; - InstrumentedGlobal = GA; } Constant *Initializer = ConstantStruct::get( @@ -2996,7 +3033,7 @@ void FunctionStackPoisoner::processStaticAllocas() { IntptrPtrTy); GlobalVariable *StackDescriptionGlobal = createPrivateGlobalForString(*F.getParent(), DescriptionString, - /*AllowMerging*/ true); + /*AllowMerging*/ true, kAsanGenPrefix); Value *Description = IRB.CreatePointerCast(StackDescriptionGlobal, IntptrTy); IRB.CreateStore(Description, BasePlus1); // Write the PC to redzone[2]. @@ -3054,7 +3091,7 @@ void FunctionStackPoisoner::processStaticAllocas() { // <This is not a fake stack; unpoison the redzones> Value *Cmp = IRBRet.CreateICmpNE(FakeStack, Constant::getNullValue(IntptrTy)); - TerminatorInst *ThenTerm, *ElseTerm; + Instruction *ThenTerm, *ElseTerm; SplitBlockAndInsertIfThenElse(Cmp, Ret, &ThenTerm, &ElseTerm); IRBuilder<> IRBPoison(ThenTerm); diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp index e13db08e263c..a0c78e0468c6 100644 --- a/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -47,21 +47,17 @@ STATISTIC(ChecksUnable, "Bounds checks unable to add"); using BuilderTy = IRBuilder<TargetFolder>; -/// Adds run-time bounds checks to memory accessing instructions. +/// Gets the conditions under which memory accessing instructions will overflow. /// /// \p Ptr is the pointer that will be read/written, and \p InstVal is either /// the result from the load or the value being stored. It is used to determine /// the size of memory block that is touched. /// -/// \p GetTrapBB is a callable that returns the trap BB to use on failure. -/// -/// Returns true if any change was made to the IR, false otherwise. 
-template <typename GetTrapBBT> -static bool instrumentMemAccess(Value *Ptr, Value *InstVal, - const DataLayout &DL, TargetLibraryInfo &TLI, - ObjectSizeOffsetEvaluator &ObjSizeEval, - BuilderTy &IRB, GetTrapBBT GetTrapBB, - ScalarEvolution &SE) { +/// Returns the condition under which the access will overflow. +static Value *getBoundsCheckCond(Value *Ptr, Value *InstVal, + const DataLayout &DL, TargetLibraryInfo &TLI, + ObjectSizeOffsetEvaluator &ObjSizeEval, + BuilderTy &IRB, ScalarEvolution &SE) { uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType()); LLVM_DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) << " bytes\n"); @@ -70,7 +66,7 @@ static bool instrumentMemAccess(Value *Ptr, Value *InstVal, if (!ObjSizeEval.bothKnown(SizeOffset)) { ++ChecksUnable; - return false; + return nullptr; } Value *Size = SizeOffset.first; @@ -107,13 +103,23 @@ static bool instrumentMemAccess(Value *Ptr, Value *InstVal, Or = IRB.CreateOr(Cmp1, Or); } + return Or; +} + +/// Adds run-time bounds checks to memory accessing instructions. +/// +/// \p Or is the condition that should guard the trap. +/// +/// \p GetTrapBB is a callable that returns the trap BB to use on failure. +template <typename GetTrapBBT> +static void insertBoundsCheck(Value *Or, BuilderTy IRB, GetTrapBBT GetTrapBB) { // check if the comparison is always false ConstantInt *C = dyn_cast_or_null<ConstantInt>(Or); if (C) { ++ChecksSkipped; // If non-zero, nothing to do. if (!C->getZExtValue()) - return true; + return; } ++ChecksAdded; @@ -127,12 +133,11 @@ static bool instrumentMemAccess(Value *Ptr, Value *InstVal, // FIXME: We should really handle this differently to bypass the splitting // the block. BranchInst::Create(GetTrapBB(IRB), OldBB); - return true; + return; } // Create the conditional branch. BranchInst::Create(GetTrapBB(IRB), Cont, Or, OldBB); - return true; } static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, @@ -143,11 +148,25 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory // touching instructions - std::vector<Instruction *> WorkList; + SmallVector<std::pair<Instruction *, Value *>, 4> TrapInfo; for (Instruction &I : instructions(F)) { - if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<AtomicCmpXchgInst>(I) || - isa<AtomicRMWInst>(I)) - WorkList.push_back(&I); + Value *Or = nullptr; + BuilderTy IRB(I.getParent(), BasicBlock::iterator(&I), TargetFolder(DL)); + if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { + Or = getBoundsCheckCond(LI->getPointerOperand(), LI, DL, TLI, + ObjSizeEval, IRB, SE); + } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { + Or = getBoundsCheckCond(SI->getPointerOperand(), SI->getValueOperand(), + DL, TLI, ObjSizeEval, IRB, SE); + } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(&I)) { + Or = getBoundsCheckCond(AI->getPointerOperand(), AI->getCompareOperand(), + DL, TLI, ObjSizeEval, IRB, SE); + } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(&I)) { + Or = getBoundsCheckCond(AI->getPointerOperand(), AI->getValOperand(), DL, + TLI, ObjSizeEval, IRB, SE); + } + if (Or) + TrapInfo.push_back(std::make_pair(&I, Or)); } // Create a trapping basic block on demand using a callback. Depending on @@ -176,29 +195,14 @@ static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI, return TrapBB; }; - bool MadeChange = false; - for (Instruction *Inst : WorkList) { + // Add the checks. 
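A minimal input for the refactored pass (illustrative only): getBoundsCheckCond() computes the object size and access offset, and insertBoundsCheck() branches to a trap block when the access can be out of bounds.

define i8 @load_with_check(i64 %i) {
  %buf = alloca [16 x i8]
  %p = getelementptr inbounds [16 x i8], [16 x i8]* %buf, i64 0, i64 %i
  %v = load i8, i8* %p
  ret i8 %v
}
; the inserted check compares the 1-byte access at offset %i against the
; 16-byte object size and branches to a block that calls llvm.trap on failure.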
+ for (const auto &Entry : TrapInfo) { + Instruction *Inst = Entry.first; BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), TargetFolder(DL)); - if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { - MadeChange |= instrumentMemAccess(LI->getPointerOperand(), LI, DL, TLI, - ObjSizeEval, IRB, GetTrapBB, SE); - } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - MadeChange |= - instrumentMemAccess(SI->getPointerOperand(), SI->getValueOperand(), - DL, TLI, ObjSizeEval, IRB, GetTrapBB, SE); - } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(Inst)) { - MadeChange |= - instrumentMemAccess(AI->getPointerOperand(), AI->getCompareOperand(), - DL, TLI, ObjSizeEval, IRB, GetTrapBB, SE); - } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(Inst)) { - MadeChange |= - instrumentMemAccess(AI->getPointerOperand(), AI->getValOperand(), DL, - TLI, ObjSizeEval, IRB, GetTrapBB, SE); - } else { - llvm_unreachable("unknown Instruction type"); - } + insertBoundsCheck(Entry.second, IRB, GetTrapBB); } - return MadeChange; + + return !TrapInfo.empty(); } PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) { diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h index cc9b149d0b6a..e178ef386e68 100644 --- a/lib/Transforms/Instrumentation/CFGMST.h +++ b/lib/Transforms/Instrumentation/CFGMST.h @@ -119,7 +119,7 @@ public: static const uint32_t CriticalEdgeMultiplier = 1000; for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); uint64_t BBWeight = (BFI != nullptr ? BFI->getBlockFreq(&*BB).getFrequency() : 2); uint64_t Weight = 2; diff --git a/lib/Transforms/Instrumentation/CGProfile.cpp b/lib/Transforms/Instrumentation/CGProfile.cpp index 9606b3da2475..cdcd01726906 100644 --- a/lib/Transforms/Instrumentation/CGProfile.cpp +++ b/lib/Transforms/Instrumentation/CGProfile.cpp @@ -88,11 +88,10 @@ void CGProfilePass::addModuleFlags( std::vector<Metadata *> Nodes; for (auto E : Counts) { - SmallVector<Metadata *, 3> Vals; - Vals.push_back(ValueAsMetadata::get(E.first.first)); - Vals.push_back(ValueAsMetadata::get(E.first.second)); - Vals.push_back(MDB.createConstant( - ConstantInt::get(Type::getInt64Ty(Context), E.second))); + Metadata *Vals[] = {ValueAsMetadata::get(E.first.first), + ValueAsMetadata::get(E.first.second), + MDB.createConstant(ConstantInt::get( + Type::getInt64Ty(Context), E.second))}; Nodes.push_back(MDNode::get(Context, Vals)); } diff --git a/lib/Transforms/Instrumentation/CMakeLists.txt b/lib/Transforms/Instrumentation/CMakeLists.txt index 5d0084823190..94461849d509 100644 --- a/lib/Transforms/Instrumentation/CMakeLists.txt +++ b/lib/Transforms/Instrumentation/CMakeLists.txt @@ -2,6 +2,7 @@ add_llvm_library(LLVMInstrumentation AddressSanitizer.cpp BoundsChecking.cpp CGProfile.cpp + ControlHeightReduction.cpp DataFlowSanitizer.cpp GCOVProfiling.cpp MemorySanitizer.cpp diff --git a/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp new file mode 100644 index 000000000000..1ada0b713092 --- /dev/null +++ b/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -0,0 +1,2074 @@ +//===-- ControlHeightReduction.cpp - Control Height Reduction -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This pass merges conditional blocks of code and reduces the number of +// conditional branches in the hot paths based on profiles. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Instrumentation/ControlHeightReduction.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/Analysis/RegionIterator.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#include <set> +#include <sstream> + +using namespace llvm; + +#define DEBUG_TYPE "chr" + +#define CHR_DEBUG(X) LLVM_DEBUG(X) + +static cl::opt<bool> ForceCHR("force-chr", cl::init(false), cl::Hidden, + cl::desc("Apply CHR for all functions")); + +static cl::opt<double> CHRBiasThreshold( + "chr-bias-threshold", cl::init(0.99), cl::Hidden, + cl::desc("CHR considers a branch bias greater than this ratio as biased")); + +static cl::opt<unsigned> CHRMergeThreshold( + "chr-merge-threshold", cl::init(2), cl::Hidden, + cl::desc("CHR merges a group of N branches/selects where N >= this value")); + +static cl::opt<std::string> CHRModuleList( + "chr-module-list", cl::init(""), cl::Hidden, + cl::desc("Specify file to retrieve the list of modules to apply CHR to")); + +static cl::opt<std::string> CHRFunctionList( + "chr-function-list", cl::init(""), cl::Hidden, + cl::desc("Specify file to retrieve the list of functions to apply CHR to")); + +static StringSet<> CHRModules; +static StringSet<> CHRFunctions; + +static void parseCHRFilterFiles() { + if (!CHRModuleList.empty()) { + auto FileOrErr = MemoryBuffer::getFile(CHRModuleList); + if (!FileOrErr) { + errs() << "Error: Couldn't read the chr-module-list file " << CHRModuleList << "\n"; + std::exit(1); + } + StringRef Buf = FileOrErr->get()->getBuffer(); + SmallVector<StringRef, 0> Lines; + Buf.split(Lines, '\n'); + for (StringRef Line : Lines) { + Line = Line.trim(); + if (!Line.empty()) + CHRModules.insert(Line); + } + } + if (!CHRFunctionList.empty()) { + auto FileOrErr = MemoryBuffer::getFile(CHRFunctionList); + if (!FileOrErr) { + errs() << "Error: Couldn't read the chr-function-list file " << CHRFunctionList << "\n"; + std::exit(1); + } + StringRef Buf = FileOrErr->get()->getBuffer(); + SmallVector<StringRef, 0> Lines; + Buf.split(Lines, '\n'); + for (StringRef Line : Lines) { + Line = Line.trim(); + if (!Line.empty()) + CHRFunctions.insert(Line); + } + } +} + +namespace { +class ControlHeightReductionLegacyPass : public FunctionPass { +public: + static char ID; + + ControlHeightReductionLegacyPass() : FunctionPass(ID) { + initializeControlHeightReductionLegacyPassPass( + *PassRegistry::getPassRegistry()); + parseCHRFilterFiles(); + } + + bool runOnFunction(Function &F) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { + 
AU.addRequired<BlockFrequencyInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + AU.addRequired<RegionInfoPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; +} // end anonymous namespace + +char ControlHeightReductionLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ControlHeightReductionLegacyPass, + "chr", + "Reduce control height in the hot paths", + false, false) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) +INITIALIZE_PASS_END(ControlHeightReductionLegacyPass, + "chr", + "Reduce control height in the hot paths", + false, false) + +FunctionPass *llvm::createControlHeightReductionLegacyPass() { + return new ControlHeightReductionLegacyPass(); +} + +namespace { + +struct CHRStats { + CHRStats() : NumBranches(0), NumBranchesDelta(0), + WeightedNumBranchesDelta(0) {} + void print(raw_ostream &OS) const { + OS << "CHRStats: NumBranches " << NumBranches + << " NumBranchesDelta " << NumBranchesDelta + << " WeightedNumBranchesDelta " << WeightedNumBranchesDelta; + } + uint64_t NumBranches; // The original number of conditional branches / + // selects + uint64_t NumBranchesDelta; // The decrease of the number of conditional + // branches / selects in the hot paths due to CHR. + uint64_t WeightedNumBranchesDelta; // NumBranchesDelta weighted by the profile + // count at the scope entry. +}; + +// RegInfo - some properties of a Region. +struct RegInfo { + RegInfo() : R(nullptr), HasBranch(false) {} + RegInfo(Region *RegionIn) : R(RegionIn), HasBranch(false) {} + Region *R; + bool HasBranch; + SmallVector<SelectInst *, 8> Selects; +}; + +typedef DenseMap<Region *, DenseSet<Instruction *>> HoistStopMapTy; + +// CHRScope - a sequence of regions to CHR together. It corresponds to a +// sequence of conditional blocks. It can have subscopes which correspond to +// nested conditional blocks. Nested CHRScopes form a tree. +class CHRScope { + public: + CHRScope(RegInfo RI) : BranchInsertPoint(nullptr) { + assert(RI.R && "Null RegionIn"); + RegInfos.push_back(RI); + } + + Region *getParentRegion() { + assert(RegInfos.size() > 0 && "Empty CHRScope"); + Region *Parent = RegInfos[0].R->getParent(); + assert(Parent && "Unexpected to call this on the top-level region"); + return Parent; + } + + BasicBlock *getEntryBlock() { + assert(RegInfos.size() > 0 && "Empty CHRScope"); + return RegInfos.front().R->getEntry(); + } + + BasicBlock *getExitBlock() { + assert(RegInfos.size() > 0 && "Empty CHRScope"); + return RegInfos.back().R->getExit(); + } + + bool appendable(CHRScope *Next) { + // The next scope is appendable only if this scope is directly connected to + // it (which implies it post-dominates this scope) and this scope dominates + // it (no edge to the next scope outside this scope). + BasicBlock *NextEntry = Next->getEntryBlock(); + if (getExitBlock() != NextEntry) + // Not directly connected. + return false; + Region *LastRegion = RegInfos.back().R; + for (BasicBlock *Pred : predecessors(NextEntry)) + if (!LastRegion->contains(Pred)) + // There's an edge going into the entry of the next scope from outside + // of this scope. 
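To make the appendability rule above concrete, here is a hedged source-level picture of two adjacent, highly biased conditions forming one CHR scope and the control flow CHR aims to produce. The functions and the assumed profile bias are invented, not taken from this patch; both versions compute the same result for every input.

int beforeCHR(int a, int b, int x) {
  if (a != 0)   // region 1: almost always taken according to the profile
    x += 1;
  if (b != 0)   // region 2: entered directly from region 1's exit
    x += 2;
  return x;     // hot path pays for two conditional branches
}

int afterCHR(int a, int b, int x) {
  if ((a != 0) & (b != 0)) { // one merged branch guards the hot path
    x += 1;                  // both bodies execute branch-free here
    x += 2;
  } else {                   // cold path keeps the original control flow
    if (a != 0)
      x += 1;
    if (b != 0)
      x += 2;
  }
  return x;
}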
+ return false; + return true; + } + + void append(CHRScope *Next) { + assert(RegInfos.size() > 0 && "Empty CHRScope"); + assert(Next->RegInfos.size() > 0 && "Empty CHRScope"); + assert(getParentRegion() == Next->getParentRegion() && + "Must be siblings"); + assert(getExitBlock() == Next->getEntryBlock() && + "Must be adjacent"); + for (RegInfo &RI : Next->RegInfos) + RegInfos.push_back(RI); + for (CHRScope *Sub : Next->Subs) + Subs.push_back(Sub); + } + + void addSub(CHRScope *SubIn) { +#ifndef NDEBUG + bool IsChild = false; + for (RegInfo &RI : RegInfos) + if (RI.R == SubIn->getParentRegion()) { + IsChild = true; + break; + } + assert(IsChild && "Must be a child"); +#endif + Subs.push_back(SubIn); + } + + // Split this scope at the boundary region into two, which will belong to the + // tail and returns the tail. + CHRScope *split(Region *Boundary) { + assert(Boundary && "Boundary null"); + assert(RegInfos.begin()->R != Boundary && + "Can't be split at beginning"); + auto BoundaryIt = std::find_if(RegInfos.begin(), RegInfos.end(), + [&Boundary](const RegInfo& RI) { + return Boundary == RI.R; + }); + if (BoundaryIt == RegInfos.end()) + return nullptr; + SmallVector<RegInfo, 8> TailRegInfos; + SmallVector<CHRScope *, 8> TailSubs; + TailRegInfos.insert(TailRegInfos.begin(), BoundaryIt, RegInfos.end()); + RegInfos.resize(BoundaryIt - RegInfos.begin()); + DenseSet<Region *> TailRegionSet; + for (RegInfo &RI : TailRegInfos) + TailRegionSet.insert(RI.R); + for (auto It = Subs.begin(); It != Subs.end(); ) { + CHRScope *Sub = *It; + assert(Sub && "null Sub"); + Region *Parent = Sub->getParentRegion(); + if (TailRegionSet.count(Parent)) { + TailSubs.push_back(Sub); + It = Subs.erase(It); + } else { + assert(std::find_if(RegInfos.begin(), RegInfos.end(), + [&Parent](const RegInfo& RI) { + return Parent == RI.R; + }) != RegInfos.end() && + "Must be in head"); + ++It; + } + } + assert(HoistStopMap.empty() && "MapHoistStops must be empty"); + return new CHRScope(TailRegInfos, TailSubs); + } + + bool contains(Instruction *I) const { + BasicBlock *Parent = I->getParent(); + for (const RegInfo &RI : RegInfos) + if (RI.R->contains(Parent)) + return true; + return false; + } + + void print(raw_ostream &OS) const; + + SmallVector<RegInfo, 8> RegInfos; // Regions that belong to this scope + SmallVector<CHRScope *, 8> Subs; // Subscopes. + + // The instruction at which to insert the CHR conditional branch (and hoist + // the dependent condition values). + Instruction *BranchInsertPoint; + + // True-biased and false-biased regions (conditional blocks), + // respectively. Used only for the outermost scope and includes regions in + // subscopes. The rest are unbiased. + DenseSet<Region *> TrueBiasedRegions; + DenseSet<Region *> FalseBiasedRegions; + // Among the biased regions, the regions that get CHRed. + SmallVector<RegInfo, 8> CHRRegions; + + // True-biased and false-biased selects, respectively. Used only for the + // outermost scope and includes ones in subscopes. + DenseSet<SelectInst *> TrueBiasedSelects; + DenseSet<SelectInst *> FalseBiasedSelects; + + // Map from one of the above regions to the instructions to stop + // hoisting instructions at through use-def chains. 
+ HoistStopMapTy HoistStopMap; + + private: + CHRScope(SmallVector<RegInfo, 8> &RegInfosIn, + SmallVector<CHRScope *, 8> &SubsIn) + : RegInfos(RegInfosIn), Subs(SubsIn), BranchInsertPoint(nullptr) {} +}; + +class CHR { + public: + CHR(Function &Fin, BlockFrequencyInfo &BFIin, DominatorTree &DTin, + ProfileSummaryInfo &PSIin, RegionInfo &RIin, + OptimizationRemarkEmitter &OREin) + : F(Fin), BFI(BFIin), DT(DTin), PSI(PSIin), RI(RIin), ORE(OREin) {} + + ~CHR() { + for (CHRScope *Scope : Scopes) { + delete Scope; + } + } + + bool run(); + + private: + // See the comments in CHR::run() for the high level flow of the algorithm and + // what the following functions do. + + void findScopes(SmallVectorImpl<CHRScope *> &Output) { + Region *R = RI.getTopLevelRegion(); + CHRScope *Scope = findScopes(R, nullptr, nullptr, Output); + if (Scope) { + Output.push_back(Scope); + } + } + CHRScope *findScopes(Region *R, Region *NextRegion, Region *ParentRegion, + SmallVectorImpl<CHRScope *> &Scopes); + CHRScope *findScope(Region *R); + void checkScopeHoistable(CHRScope *Scope); + + void splitScopes(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output); + SmallVector<CHRScope *, 8> splitScope(CHRScope *Scope, + CHRScope *Outer, + DenseSet<Value *> *OuterConditionValues, + Instruction *OuterInsertPoint, + SmallVectorImpl<CHRScope *> &Output, + DenseSet<Instruction *> &Unhoistables); + + void classifyBiasedScopes(SmallVectorImpl<CHRScope *> &Scopes); + void classifyBiasedScopes(CHRScope *Scope, CHRScope *OutermostScope); + + void filterScopes(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output); + + void setCHRRegions(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output); + void setCHRRegions(CHRScope *Scope, CHRScope *OutermostScope); + + void sortScopes(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output); + + void transformScopes(SmallVectorImpl<CHRScope *> &CHRScopes); + void transformScopes(CHRScope *Scope, DenseSet<PHINode *> &TrivialPHIs); + void cloneScopeBlocks(CHRScope *Scope, + BasicBlock *PreEntryBlock, + BasicBlock *ExitBlock, + Region *LastRegion, + ValueToValueMapTy &VMap); + BranchInst *createMergedBranch(BasicBlock *PreEntryBlock, + BasicBlock *EntryBlock, + BasicBlock *NewEntryBlock, + ValueToValueMapTy &VMap); + void fixupBranchesAndSelects(CHRScope *Scope, + BasicBlock *PreEntryBlock, + BranchInst *MergedBR, + uint64_t ProfileCount); + void fixupBranch(Region *R, + CHRScope *Scope, + IRBuilder<> &IRB, + Value *&MergedCondition, BranchProbability &CHRBranchBias); + void fixupSelect(SelectInst* SI, + CHRScope *Scope, + IRBuilder<> &IRB, + Value *&MergedCondition, BranchProbability &CHRBranchBias); + void addToMergedCondition(bool IsTrueBiased, Value *Cond, + Instruction *BranchOrSelect, + CHRScope *Scope, + IRBuilder<> &IRB, + Value *&MergedCondition); + + Function &F; + BlockFrequencyInfo &BFI; + DominatorTree &DT; + ProfileSummaryInfo &PSI; + RegionInfo &RI; + OptimizationRemarkEmitter &ORE; + CHRStats Stats; + + // All the true-biased regions in the function + DenseSet<Region *> TrueBiasedRegionsGlobal; + // All the false-biased regions in the function + DenseSet<Region *> FalseBiasedRegionsGlobal; + // All the true-biased selects in the function + DenseSet<SelectInst *> TrueBiasedSelectsGlobal; + // All the false-biased selects in the function + DenseSet<SelectInst *> FalseBiasedSelectsGlobal; + // A map from biased regions to their branch bias + DenseMap<Region *, BranchProbability> BranchBiasMap; + // A 
map from biased selects to their branch bias + DenseMap<SelectInst *, BranchProbability> SelectBiasMap; + // All the scopes. + DenseSet<CHRScope *> Scopes; +}; + +} // end anonymous namespace + +static inline +raw_ostream LLVM_ATTRIBUTE_UNUSED &operator<<(raw_ostream &OS, + const CHRStats &Stats) { + Stats.print(OS); + return OS; +} + +static inline +raw_ostream &operator<<(raw_ostream &OS, const CHRScope &Scope) { + Scope.print(OS); + return OS; +} + +static bool shouldApply(Function &F, ProfileSummaryInfo& PSI) { + if (ForceCHR) + return true; + + if (!CHRModuleList.empty() || !CHRFunctionList.empty()) { + if (CHRModules.count(F.getParent()->getName())) + return true; + return CHRFunctions.count(F.getName()); + } + + assert(PSI.hasProfileSummary() && "Empty PSI?"); + return PSI.isFunctionEntryHot(&F); +} + +static void LLVM_ATTRIBUTE_UNUSED dumpIR(Function &F, const char *Label, + CHRStats *Stats) { + StringRef FuncName = F.getName(); + StringRef ModuleName = F.getParent()->getName(); + (void)(FuncName); // Unused in release build. + (void)(ModuleName); // Unused in release build. + CHR_DEBUG(dbgs() << "CHR IR dump " << Label << " " << ModuleName << " " + << FuncName); + if (Stats) + CHR_DEBUG(dbgs() << " " << *Stats); + CHR_DEBUG(dbgs() << "\n"); + CHR_DEBUG(F.dump()); +} + +void CHRScope::print(raw_ostream &OS) const { + assert(RegInfos.size() > 0 && "Empty CHRScope"); + OS << "CHRScope["; + OS << RegInfos.size() << ", Regions["; + for (const RegInfo &RI : RegInfos) { + OS << RI.R->getNameStr(); + if (RI.HasBranch) + OS << " B"; + if (RI.Selects.size() > 0) + OS << " S" << RI.Selects.size(); + OS << ", "; + } + if (RegInfos[0].R->getParent()) { + OS << "], Parent " << RegInfos[0].R->getParent()->getNameStr(); + } else { + // top level region + OS << "]"; + } + OS << ", Subs["; + for (CHRScope *Sub : Subs) { + OS << *Sub << ", "; + } + OS << "]]"; +} + +// Return true if the given instruction type can be hoisted by CHR. +static bool isHoistableInstructionType(Instruction *I) { + return isa<BinaryOperator>(I) || isa<CastInst>(I) || isa<SelectInst>(I) || + isa<GetElementPtrInst>(I) || isa<CmpInst>(I) || + isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || + isa<ShuffleVectorInst>(I) || isa<ExtractValueInst>(I) || + isa<InsertValueInst>(I); +} + +// Return true if the given instruction can be hoisted by CHR. +static bool isHoistable(Instruction *I, DominatorTree &DT) { + if (!isHoistableInstructionType(I)) + return false; + return isSafeToSpeculativelyExecute(I, nullptr, &DT); +} + +// Recursively traverse the use-def chains of the given value and return a set +// of the unhoistable base values defined within the scope (excluding the +// first-region entry block) or the (hoistable or unhoistable) base values that +// are defined outside (including the first-region entry block) of the +// scope. The returned set doesn't include constants. +static std::set<Value *> getBaseValues(Value *V, + DominatorTree &DT) { + std::set<Value *> Result; + if (auto *I = dyn_cast<Instruction>(V)) { + // We don't stop at a block that's not in the Scope because we would miss some + // instructions that are based on the same base values if we stop there. + if (!isHoistable(I, DT)) { + Result.insert(I); + return Result; + } + // I is hoistable above the Scope. 
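The recursion that continues below can be pictured with a tiny stand-in type, walking through hoistable nodes and collecting the leaves or unhoistable nodes they bottom out at. This is a hedged sketch; the Expr type is invented and is not LLVM's Value/Instruction.

#include <set>
#include <vector>

struct Expr {
  bool Hoistable = true;              // e.g. false for a call result
  std::vector<const Expr *> Operands; // empty for leaves (arguments, loads, ...)
};

void collectBaseValues(const Expr *E, std::set<const Expr *> &Bases) {
  if (!E->Hoistable || E->Operands.empty()) {
    Bases.insert(E);                  // stop at unhoistable nodes and leaves
    return;
  }
  for (const Expr *Op : E->Operands)  // look through hoistable nodes
    collectBaseValues(Op, Bases);
}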
+ for (Value *Op : I->operands()) { + std::set<Value *> OpResult = getBaseValues(Op, DT); + Result.insert(OpResult.begin(), OpResult.end()); + } + return Result; + } + if (isa<Argument>(V)) { + Result.insert(V); + return Result; + } + // We don't include others like constants because those won't lead to any + // chance of folding of conditions (eg two bit checks merged into one check) + // after CHR. + return Result; // empty +} + +// Return true if V is already hoisted or can be hoisted (along with its +// operands) above the insert point. When it returns true and HoistStops is +// non-null, the instructions to stop hoisting at through the use-def chains are +// inserted into HoistStops. +static bool +checkHoistValue(Value *V, Instruction *InsertPoint, DominatorTree &DT, + DenseSet<Instruction *> &Unhoistables, + DenseSet<Instruction *> *HoistStops) { + assert(InsertPoint && "Null InsertPoint"); + if (auto *I = dyn_cast<Instruction>(V)) { + assert(DT.getNode(I->getParent()) && "DT must contain I's parent block"); + assert(DT.getNode(InsertPoint->getParent()) && "DT must contain Destination"); + if (Unhoistables.count(I)) { + // Don't hoist if they are not to be hoisted. + return false; + } + if (DT.dominates(I, InsertPoint)) { + // We are already above the insert point. Stop here. + if (HoistStops) + HoistStops->insert(I); + return true; + } + // We aren't not above the insert point, check if we can hoist it above the + // insert point. + if (isHoistable(I, DT)) { + // Check operands first. + DenseSet<Instruction *> OpsHoistStops; + bool AllOpsHoisted = true; + for (Value *Op : I->operands()) { + if (!checkHoistValue(Op, InsertPoint, DT, Unhoistables, &OpsHoistStops)) { + AllOpsHoisted = false; + break; + } + } + if (AllOpsHoisted) { + CHR_DEBUG(dbgs() << "checkHoistValue " << *I << "\n"); + if (HoistStops) + HoistStops->insert(OpsHoistStops.begin(), OpsHoistStops.end()); + return true; + } + } + return false; + } + // Non-instructions are considered hoistable. + return true; +} + +// Returns true and sets the true probability and false probability of an +// MD_prof metadata if it's well-formed. +static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb, + BranchProbability &FalseProb) { + if (!MD) return false; + MDString *MDName = cast<MDString>(MD->getOperand(0)); + if (MDName->getString() != "branch_weights" || + MD->getNumOperands() != 3) + return false; + ConstantInt *TrueWeight = mdconst::extract<ConstantInt>(MD->getOperand(1)); + ConstantInt *FalseWeight = mdconst::extract<ConstantInt>(MD->getOperand(2)); + if (!TrueWeight || !FalseWeight) + return false; + uint64_t TrueWt = TrueWeight->getValue().getZExtValue(); + uint64_t FalseWt = FalseWeight->getValue().getZExtValue(); + uint64_t SumWt = TrueWt + FalseWt; + + assert(SumWt >= TrueWt && SumWt >= FalseWt && + "Overflow calculating branch probabilities."); + + TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); + FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + return true; +} + +static BranchProbability getCHRBiasThreshold() { + return BranchProbability::getBranchProbability( + static_cast<uint64_t>(CHRBiasThreshold * 1000000), 1000000); +} + +// A helper for CheckBiasedBranch and CheckBiasedSelect. If TrueProb >= +// CHRBiasThreshold, put Key into TrueSet and return true. If FalseProb >= +// CHRBiasThreshold, put Key into FalseSet and return true. Otherwise, return +// false. 
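Before the helper below, a quick numeric illustration of the threshold test it performs on the true/false weights taken from !prof branch_weights metadata. This is plain arithmetic with made-up weights; the pass itself uses BranchProbability, and 0.99 matches the default -chr-bias-threshold.

#include <cstdint>
#include <cstdio>

bool isBiased(uint64_t TrueWt, uint64_t FalseWt, double Threshold = 0.99) {
  double Sum = double(TrueWt) + double(FalseWt);
  return TrueWt / Sum >= Threshold || FalseWt / Sum >= Threshold;
}

int main() {
  std::printf("%d\n", isBiased(1000, 5)); // 1000/1005 ~= 0.995 -> biased
  std::printf("%d\n", isBiased(90, 10));  // 0.90 < 0.99        -> not biased
}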
+template <typename K, typename S, typename M> +static bool checkBias(K *Key, BranchProbability TrueProb, + BranchProbability FalseProb, S &TrueSet, S &FalseSet, + M &BiasMap) { + BranchProbability Threshold = getCHRBiasThreshold(); + if (TrueProb >= Threshold) { + TrueSet.insert(Key); + BiasMap[Key] = TrueProb; + return true; + } else if (FalseProb >= Threshold) { + FalseSet.insert(Key); + BiasMap[Key] = FalseProb; + return true; + } + return false; +} + +// Returns true and insert a region into the right biased set and the map if the +// branch of the region is biased. +static bool checkBiasedBranch(BranchInst *BI, Region *R, + DenseSet<Region *> &TrueBiasedRegionsGlobal, + DenseSet<Region *> &FalseBiasedRegionsGlobal, + DenseMap<Region *, BranchProbability> &BranchBiasMap) { + if (!BI->isConditional()) + return false; + BranchProbability ThenProb, ElseProb; + if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof), + ThenProb, ElseProb)) + return false; + BasicBlock *IfThen = BI->getSuccessor(0); + BasicBlock *IfElse = BI->getSuccessor(1); + assert((IfThen == R->getExit() || IfElse == R->getExit()) && + IfThen != IfElse && + "Invariant from findScopes"); + if (IfThen == R->getExit()) { + // Swap them so that IfThen/ThenProb means going into the conditional code + // and IfElse/ElseProb means skipping it. + std::swap(IfThen, IfElse); + std::swap(ThenProb, ElseProb); + } + CHR_DEBUG(dbgs() << "BI " << *BI << " "); + CHR_DEBUG(dbgs() << "ThenProb " << ThenProb << " "); + CHR_DEBUG(dbgs() << "ElseProb " << ElseProb << "\n"); + return checkBias(R, ThenProb, ElseProb, + TrueBiasedRegionsGlobal, FalseBiasedRegionsGlobal, + BranchBiasMap); +} + +// Returns true and insert a select into the right biased set and the map if the +// select is biased. +static bool checkBiasedSelect( + SelectInst *SI, Region *R, + DenseSet<SelectInst *> &TrueBiasedSelectsGlobal, + DenseSet<SelectInst *> &FalseBiasedSelectsGlobal, + DenseMap<SelectInst *, BranchProbability> &SelectBiasMap) { + BranchProbability TrueProb, FalseProb; + if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof), + TrueProb, FalseProb)) + return false; + CHR_DEBUG(dbgs() << "SI " << *SI << " "); + CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " "); + CHR_DEBUG(dbgs() << "FalseProb " << FalseProb << "\n"); + return checkBias(SI, TrueProb, FalseProb, + TrueBiasedSelectsGlobal, FalseBiasedSelectsGlobal, + SelectBiasMap); +} + +// Returns the instruction at which to hoist the dependent condition values and +// insert the CHR branch for a region. This is the terminator branch in the +// entry block or the first select in the entry block, if any. +static Instruction* getBranchInsertPoint(RegInfo &RI) { + Region *R = RI.R; + BasicBlock *EntryBB = R->getEntry(); + // The hoist point is by default the terminator of the entry block, which is + // the same as the branch instruction if RI.HasBranch is true. + Instruction *HoistPoint = EntryBB->getTerminator(); + for (SelectInst *SI : RI.Selects) { + if (SI->getParent() == EntryBB) { + // Pick the first select in Selects in the entry block. Note Selects is + // sorted in the instruction order within a block (asserted below). + HoistPoint = SI; + break; + } + } + assert(HoistPoint && "Null HoistPoint"); +#ifndef NDEBUG + // Check that HoistPoint is the first one in Selects in the entry block, + // if any. 
+ DenseSet<Instruction *> EntryBlockSelectSet; + for (SelectInst *SI : RI.Selects) { + if (SI->getParent() == EntryBB) { + EntryBlockSelectSet.insert(SI); + } + } + for (Instruction &I : *EntryBB) { + if (EntryBlockSelectSet.count(&I) > 0) { + assert(&I == HoistPoint && + "HoistPoint must be the first one in Selects"); + break; + } + } +#endif + return HoistPoint; +} + +// Find a CHR scope in the given region. +CHRScope * CHR::findScope(Region *R) { + CHRScope *Result = nullptr; + BasicBlock *Entry = R->getEntry(); + BasicBlock *Exit = R->getExit(); // null if top level. + assert(Entry && "Entry must not be null"); + assert((Exit == nullptr) == (R->isTopLevelRegion()) && + "Only top level region has a null exit"); + if (Entry) + CHR_DEBUG(dbgs() << "Entry " << Entry->getName() << "\n"); + else + CHR_DEBUG(dbgs() << "Entry null\n"); + if (Exit) + CHR_DEBUG(dbgs() << "Exit " << Exit->getName() << "\n"); + else + CHR_DEBUG(dbgs() << "Exit null\n"); + // Exclude cases where Entry is part of a subregion (hence it doesn't belong + // to this region). + bool EntryInSubregion = RI.getRegionFor(Entry) != R; + if (EntryInSubregion) + return nullptr; + // Exclude loops + for (BasicBlock *Pred : predecessors(Entry)) + if (R->contains(Pred)) + return nullptr; + if (Exit) { + // Try to find an if-then block (check if R is an if-then). + // if (cond) { + // ... + // } + auto *BI = dyn_cast<BranchInst>(Entry->getTerminator()); + if (BI) + CHR_DEBUG(dbgs() << "BI.isConditional " << BI->isConditional() << "\n"); + else + CHR_DEBUG(dbgs() << "BI null\n"); + if (BI && BI->isConditional()) { + BasicBlock *S0 = BI->getSuccessor(0); + BasicBlock *S1 = BI->getSuccessor(1); + CHR_DEBUG(dbgs() << "S0 " << S0->getName() << "\n"); + CHR_DEBUG(dbgs() << "S1 " << S1->getName() << "\n"); + if (S0 != S1 && (S0 == Exit || S1 == Exit)) { + RegInfo RI(R); + RI.HasBranch = checkBiasedBranch( + BI, R, TrueBiasedRegionsGlobal, FalseBiasedRegionsGlobal, + BranchBiasMap); + Result = new CHRScope(RI); + Scopes.insert(Result); + CHR_DEBUG(dbgs() << "Found a region with a branch\n"); + ++Stats.NumBranches; + if (!RI.HasBranch) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "BranchNotBiased", BI) + << "Branch not biased"; + }); + } + } + } + } + { + // Try to look for selects in the direct child blocks (as opposed to in + // subregions) of R. + // ... + // if (..) { // Some subregion + // ... + // } + // if (..) { // Some subregion + // ... + // } + // ... + // a = cond ? b : c; + // ... + SmallVector<SelectInst *, 8> Selects; + for (RegionNode *E : R->elements()) { + if (E->isSubRegion()) + continue; + // This returns the basic block of E if E is a direct child of R (not a + // subregion.) + BasicBlock *BB = E->getEntry(); + // Need to push in the order to make it easier to find the first Select + // later. 
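A hedged, source-level illustration of what findScope collects from one region: a select in the region's entry block plus the conditional branch that ends that block, with the untaken side of the branch going straight to the region exit. The names are invented, and whether CHR actually forms a scope here depends on the profile bias of the select and the branch.

int work(int);

int ifThenRegion(int cond, int sel, int x) {
  int y = sel ? x + 1 : x - 1; // a select in the region's entry block
  if (cond)                    // the branch terminating the entry block
    y = work(y);               // "then" body; the other successor is the exit
  return y;                    // region exit
}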
+ for (Instruction &I : *BB) { + if (auto *SI = dyn_cast<SelectInst>(&I)) { + Selects.push_back(SI); + ++Stats.NumBranches; + } + } + } + if (Selects.size() > 0) { + auto AddSelects = [&](RegInfo &RI) { + for (auto *SI : Selects) + if (checkBiasedSelect(SI, RI.R, + TrueBiasedSelectsGlobal, + FalseBiasedSelectsGlobal, + SelectBiasMap)) + RI.Selects.push_back(SI); + else + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "SelectNotBiased", SI) + << "Select not biased"; + }); + }; + if (!Result) { + CHR_DEBUG(dbgs() << "Found a select-only region\n"); + RegInfo RI(R); + AddSelects(RI); + Result = new CHRScope(RI); + Scopes.insert(Result); + } else { + CHR_DEBUG(dbgs() << "Found select(s) in a region with a branch\n"); + AddSelects(Result->RegInfos[0]); + } + } + } + + if (Result) { + checkScopeHoistable(Result); + } + return Result; +} + +// Check that any of the branch and the selects in the region could be +// hoisted above the the CHR branch insert point (the most dominating of +// them, either the branch (at the end of the first block) or the first +// select in the first block). If the branch can't be hoisted, drop the +// selects in the first blocks. +// +// For example, for the following scope/region with selects, we want to insert +// the merged branch right before the first select in the first/entry block by +// hoisting c1, c2, c3, and c4. +// +// // Branch insert point here. +// a = c1 ? b : c; // Select 1 +// d = c2 ? e : f; // Select 2 +// if (c3) { // Branch +// ... +// c4 = foo() // A call. +// g = c4 ? h : i; // Select 3 +// } +// +// But suppose we can't hoist c4 because it's dependent on the preceding +// call. Then, we drop Select 3. Furthermore, if we can't hoist c2, we also drop +// Select 2. If we can't hoist c3, we drop Selects 1 & 2. +void CHR::checkScopeHoistable(CHRScope *Scope) { + RegInfo &RI = Scope->RegInfos[0]; + Region *R = RI.R; + BasicBlock *EntryBB = R->getEntry(); + auto *Branch = RI.HasBranch ? + cast<BranchInst>(EntryBB->getTerminator()) : nullptr; + SmallVector<SelectInst *, 8> &Selects = RI.Selects; + if (RI.HasBranch || !Selects.empty()) { + Instruction *InsertPoint = getBranchInsertPoint(RI); + CHR_DEBUG(dbgs() << "InsertPoint " << *InsertPoint << "\n"); + // Avoid a data dependence from a select or a branch to a(nother) + // select. Note no instruction can't data-depend on a branch (a branch + // instruction doesn't produce a value). + DenseSet<Instruction *> Unhoistables; + // Initialize Unhoistables with the selects. + for (SelectInst *SI : Selects) { + Unhoistables.insert(SI); + } + // Remove Selects that can't be hoisted. + for (auto it = Selects.begin(); it != Selects.end(); ) { + SelectInst *SI = *it; + if (SI == InsertPoint) { + ++it; + continue; + } + bool IsHoistable = checkHoistValue(SI->getCondition(), InsertPoint, + DT, Unhoistables, nullptr); + if (!IsHoistable) { + CHR_DEBUG(dbgs() << "Dropping select " << *SI << "\n"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "DropUnhoistableSelect", SI) + << "Dropped unhoistable select"; + }); + it = Selects.erase(it); + // Since we are dropping the select here, we also drop it from + // Unhoistables. + Unhoistables.erase(SI); + } else + ++it; + } + // Update InsertPoint after potentially removing selects. 
+ InsertPoint = getBranchInsertPoint(RI); + CHR_DEBUG(dbgs() << "InsertPoint " << *InsertPoint << "\n"); + if (RI.HasBranch && InsertPoint != Branch) { + bool IsHoistable = checkHoistValue(Branch->getCondition(), InsertPoint, + DT, Unhoistables, nullptr); + if (!IsHoistable) { + // If the branch isn't hoistable, drop the selects in the entry + // block, preferring the branch, which makes the branch the hoist + // point. + assert(InsertPoint != Branch && "Branch must not be the hoist point"); + CHR_DEBUG(dbgs() << "Dropping selects in entry block \n"); + CHR_DEBUG( + for (SelectInst *SI : Selects) { + dbgs() << "SI " << *SI << "\n"; + }); + for (SelectInst *SI : Selects) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "DropSelectUnhoistableBranch", SI) + << "Dropped select due to unhoistable branch"; + }); + } + Selects.erase(std::remove_if(Selects.begin(), Selects.end(), + [EntryBB](SelectInst *SI) { + return SI->getParent() == EntryBB; + }), Selects.end()); + Unhoistables.clear(); + InsertPoint = Branch; + } + } + CHR_DEBUG(dbgs() << "InsertPoint " << *InsertPoint << "\n"); +#ifndef NDEBUG + if (RI.HasBranch) { + assert(!DT.dominates(Branch, InsertPoint) && + "Branch can't be already above the hoist point"); + assert(checkHoistValue(Branch->getCondition(), InsertPoint, + DT, Unhoistables, nullptr) && + "checkHoistValue for branch"); + } + for (auto *SI : Selects) { + assert(!DT.dominates(SI, InsertPoint) && + "SI can't be already above the hoist point"); + assert(checkHoistValue(SI->getCondition(), InsertPoint, DT, + Unhoistables, nullptr) && + "checkHoistValue for selects"); + } + CHR_DEBUG(dbgs() << "Result\n"); + if (RI.HasBranch) { + CHR_DEBUG(dbgs() << "BI " << *Branch << "\n"); + } + for (auto *SI : Selects) { + CHR_DEBUG(dbgs() << "SI " << *SI << "\n"); + } +#endif + } +} + +// Traverse the region tree, find all nested scopes and merge them if possible. +CHRScope * CHR::findScopes(Region *R, Region *NextRegion, Region *ParentRegion, + SmallVectorImpl<CHRScope *> &Scopes) { + CHR_DEBUG(dbgs() << "findScopes " << R->getNameStr() << "\n"); + CHRScope *Result = findScope(R); + // Visit subscopes. + CHRScope *ConsecutiveSubscope = nullptr; + SmallVector<CHRScope *, 8> Subscopes; + for (auto It = R->begin(); It != R->end(); ++It) { + const std::unique_ptr<Region> &SubR = *It; + auto NextIt = std::next(It); + Region *NextSubR = NextIt != R->end() ? NextIt->get() : nullptr; + CHR_DEBUG(dbgs() << "Looking at subregion " << SubR.get()->getNameStr() + << "\n"); + CHRScope *SubCHRScope = findScopes(SubR.get(), NextSubR, R, Scopes); + if (SubCHRScope) { + CHR_DEBUG(dbgs() << "Subregion Scope " << *SubCHRScope << "\n"); + } else { + CHR_DEBUG(dbgs() << "Subregion Scope null\n"); + } + if (SubCHRScope) { + if (!ConsecutiveSubscope) + ConsecutiveSubscope = SubCHRScope; + else if (!ConsecutiveSubscope->appendable(SubCHRScope)) { + Subscopes.push_back(ConsecutiveSubscope); + ConsecutiveSubscope = SubCHRScope; + } else + ConsecutiveSubscope->append(SubCHRScope); + } else { + if (ConsecutiveSubscope) { + Subscopes.push_back(ConsecutiveSubscope); + } + ConsecutiveSubscope = nullptr; + } + } + if (ConsecutiveSubscope) { + Subscopes.push_back(ConsecutiveSubscope); + } + for (CHRScope *Sub : Subscopes) { + if (Result) { + // Combine it with the parent. + Result->addSub(Sub); + } else { + // Push Subscopes as they won't be combined with the parent. 
+ Scopes.push_back(Sub); + } + } + return Result; +} + +static DenseSet<Value *> getCHRConditionValuesForRegion(RegInfo &RI) { + DenseSet<Value *> ConditionValues; + if (RI.HasBranch) { + auto *BI = cast<BranchInst>(RI.R->getEntry()->getTerminator()); + ConditionValues.insert(BI->getCondition()); + } + for (SelectInst *SI : RI.Selects) { + ConditionValues.insert(SI->getCondition()); + } + return ConditionValues; +} + + +// Determine whether to split a scope depending on the sets of the branch +// condition values of the previous region and the current region. We split +// (return true) it if 1) the condition values of the inner/lower scope can't be +// hoisted up to the outer/upper scope, or 2) the two sets of the condition +// values have an empty intersection (because the combined branch conditions +// won't probably lead to a simpler combined condition). +static bool shouldSplit(Instruction *InsertPoint, + DenseSet<Value *> &PrevConditionValues, + DenseSet<Value *> &ConditionValues, + DominatorTree &DT, + DenseSet<Instruction *> &Unhoistables) { + CHR_DEBUG( + dbgs() << "shouldSplit " << *InsertPoint << " PrevConditionValues "; + for (Value *V : PrevConditionValues) { + dbgs() << *V << ", "; + } + dbgs() << " ConditionValues "; + for (Value *V : ConditionValues) { + dbgs() << *V << ", "; + } + dbgs() << "\n"); + assert(InsertPoint && "Null InsertPoint"); + // If any of Bases isn't hoistable to the hoist point, split. + for (Value *V : ConditionValues) { + if (!checkHoistValue(V, InsertPoint, DT, Unhoistables, nullptr)) { + CHR_DEBUG(dbgs() << "Split. checkHoistValue false " << *V << "\n"); + return true; // Not hoistable, split. + } + } + // If PrevConditionValues or ConditionValues is empty, don't split to avoid + // unnecessary splits at scopes with no branch/selects. If + // PrevConditionValues and ConditionValues don't intersect at all, split. + if (!PrevConditionValues.empty() && !ConditionValues.empty()) { + // Use std::set as DenseSet doesn't work with set_intersection. + std::set<Value *> PrevBases, Bases; + for (Value *V : PrevConditionValues) { + std::set<Value *> BaseValues = getBaseValues(V, DT); + PrevBases.insert(BaseValues.begin(), BaseValues.end()); + } + for (Value *V : ConditionValues) { + std::set<Value *> BaseValues = getBaseValues(V, DT); + Bases.insert(BaseValues.begin(), BaseValues.end()); + } + CHR_DEBUG( + dbgs() << "PrevBases "; + for (Value *V : PrevBases) { + dbgs() << *V << ", "; + } + dbgs() << " Bases "; + for (Value *V : Bases) { + dbgs() << *V << ", "; + } + dbgs() << "\n"); + std::set<Value *> Intersection; + std::set_intersection(PrevBases.begin(), PrevBases.end(), + Bases.begin(), Bases.end(), + std::inserter(Intersection, Intersection.begin())); + if (Intersection.empty()) { + // Empty intersection, split. + CHR_DEBUG(dbgs() << "Split. Intersection empty\n"); + return true; + } + } + CHR_DEBUG(dbgs() << "No split\n"); + return false; // Don't split. 
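The empty-intersection rule above exists because folding opportunities come from shared base values; a small, invented illustration:

// Two conditions over the same base value x: once merged, the combined test
// can fold ((x & 1) != 0 && (x & 2) != 0 is the same as (x & 3) == 3), so
// keeping them in one scope is worthwhile.
bool sharedBase(unsigned x) { return (x & 1) != 0 && (x & 2) != 0; }

// Conditions over unrelated bases x and y share nothing, so the merged
// condition cannot simplify; shouldSplit separates such scopes.
bool disjointBases(unsigned x, unsigned y) {
  return (x & 1) != 0 && (y & 1) != 0;
}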
+} + +static void getSelectsInScope(CHRScope *Scope, + DenseSet<Instruction *> &Output) { + for (RegInfo &RI : Scope->RegInfos) + for (SelectInst *SI : RI.Selects) + Output.insert(SI); + for (CHRScope *Sub : Scope->Subs) + getSelectsInScope(Sub, Output); +} + +void CHR::splitScopes(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output) { + for (CHRScope *Scope : Input) { + assert(!Scope->BranchInsertPoint && + "BranchInsertPoint must not be set"); + DenseSet<Instruction *> Unhoistables; + getSelectsInScope(Scope, Unhoistables); + splitScope(Scope, nullptr, nullptr, nullptr, Output, Unhoistables); + } +#ifndef NDEBUG + for (CHRScope *Scope : Output) { + assert(Scope->BranchInsertPoint && "BranchInsertPoint must be set"); + } +#endif +} + +SmallVector<CHRScope *, 8> CHR::splitScope( + CHRScope *Scope, + CHRScope *Outer, + DenseSet<Value *> *OuterConditionValues, + Instruction *OuterInsertPoint, + SmallVectorImpl<CHRScope *> &Output, + DenseSet<Instruction *> &Unhoistables) { + if (Outer) { + assert(OuterConditionValues && "Null OuterConditionValues"); + assert(OuterInsertPoint && "Null OuterInsertPoint"); + } + bool PrevSplitFromOuter = true; + DenseSet<Value *> PrevConditionValues; + Instruction *PrevInsertPoint = nullptr; + SmallVector<CHRScope *, 8> Splits; + SmallVector<bool, 8> SplitsSplitFromOuter; + SmallVector<DenseSet<Value *>, 8> SplitsConditionValues; + SmallVector<Instruction *, 8> SplitsInsertPoints; + SmallVector<RegInfo, 8> RegInfos(Scope->RegInfos); // Copy + for (RegInfo &RI : RegInfos) { + Instruction *InsertPoint = getBranchInsertPoint(RI); + DenseSet<Value *> ConditionValues = getCHRConditionValuesForRegion(RI); + CHR_DEBUG( + dbgs() << "ConditionValues "; + for (Value *V : ConditionValues) { + dbgs() << *V << ", "; + } + dbgs() << "\n"); + if (RI.R == RegInfos[0].R) { + // First iteration. Check to see if we should split from the outer. + if (Outer) { + CHR_DEBUG(dbgs() << "Outer " << *Outer << "\n"); + CHR_DEBUG(dbgs() << "Should split from outer at " + << RI.R->getNameStr() << "\n"); + if (shouldSplit(OuterInsertPoint, *OuterConditionValues, + ConditionValues, DT, Unhoistables)) { + PrevConditionValues = ConditionValues; + PrevInsertPoint = InsertPoint; + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "SplitScopeFromOuter", + RI.R->getEntry()->getTerminator()) + << "Split scope from outer due to unhoistable branch/select " + << "and/or lack of common condition values"; + }); + } else { + // Not splitting from the outer. Use the outer bases and insert + // point. Union the bases. 
+ PrevSplitFromOuter = false; + PrevConditionValues = *OuterConditionValues; + PrevConditionValues.insert(ConditionValues.begin(), + ConditionValues.end()); + PrevInsertPoint = OuterInsertPoint; + } + } else { + CHR_DEBUG(dbgs() << "Outer null\n"); + PrevConditionValues = ConditionValues; + PrevInsertPoint = InsertPoint; + } + } else { + CHR_DEBUG(dbgs() << "Should split from prev at " + << RI.R->getNameStr() << "\n"); + if (shouldSplit(PrevInsertPoint, PrevConditionValues, ConditionValues, + DT, Unhoistables)) { + CHRScope *Tail = Scope->split(RI.R); + Scopes.insert(Tail); + Splits.push_back(Scope); + SplitsSplitFromOuter.push_back(PrevSplitFromOuter); + SplitsConditionValues.push_back(PrevConditionValues); + SplitsInsertPoints.push_back(PrevInsertPoint); + Scope = Tail; + PrevConditionValues = ConditionValues; + PrevInsertPoint = InsertPoint; + PrevSplitFromOuter = true; + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "SplitScopeFromPrev", + RI.R->getEntry()->getTerminator()) + << "Split scope from previous due to unhoistable branch/select " + << "and/or lack of common condition values"; + }); + } else { + // Not splitting. Union the bases. Keep the hoist point. + PrevConditionValues.insert(ConditionValues.begin(), ConditionValues.end()); + } + } + } + Splits.push_back(Scope); + SplitsSplitFromOuter.push_back(PrevSplitFromOuter); + SplitsConditionValues.push_back(PrevConditionValues); + assert(PrevInsertPoint && "Null PrevInsertPoint"); + SplitsInsertPoints.push_back(PrevInsertPoint); + assert(Splits.size() == SplitsConditionValues.size() && + Splits.size() == SplitsSplitFromOuter.size() && + Splits.size() == SplitsInsertPoints.size() && "Mismatching sizes"); + for (size_t I = 0; I < Splits.size(); ++I) { + CHRScope *Split = Splits[I]; + DenseSet<Value *> &SplitConditionValues = SplitsConditionValues[I]; + Instruction *SplitInsertPoint = SplitsInsertPoints[I]; + SmallVector<CHRScope *, 8> NewSubs; + DenseSet<Instruction *> SplitUnhoistables; + getSelectsInScope(Split, SplitUnhoistables); + for (CHRScope *Sub : Split->Subs) { + SmallVector<CHRScope *, 8> SubSplits = splitScope( + Sub, Split, &SplitConditionValues, SplitInsertPoint, Output, + SplitUnhoistables); + NewSubs.insert(NewSubs.end(), SubSplits.begin(), SubSplits.end()); + } + Split->Subs = NewSubs; + } + SmallVector<CHRScope *, 8> Result; + for (size_t I = 0; I < Splits.size(); ++I) { + CHRScope *Split = Splits[I]; + if (SplitsSplitFromOuter[I]) { + // Split from the outer. + Output.push_back(Split); + Split->BranchInsertPoint = SplitsInsertPoints[I]; + CHR_DEBUG(dbgs() << "BranchInsertPoint " << *SplitsInsertPoints[I] + << "\n"); + } else { + // Connected to the outer. 
+ Result.push_back(Split); + } + } + if (!Outer) + assert(Result.empty() && + "If no outer (top-level), must return no nested ones"); + return Result; +} + +void CHR::classifyBiasedScopes(SmallVectorImpl<CHRScope *> &Scopes) { + for (CHRScope *Scope : Scopes) { + assert(Scope->TrueBiasedRegions.empty() && Scope->FalseBiasedRegions.empty() && "Empty"); + classifyBiasedScopes(Scope, Scope); + CHR_DEBUG( + dbgs() << "classifyBiasedScopes " << *Scope << "\n"; + dbgs() << "TrueBiasedRegions "; + for (Region *R : Scope->TrueBiasedRegions) { + dbgs() << R->getNameStr() << ", "; + } + dbgs() << "\n"; + dbgs() << "FalseBiasedRegions "; + for (Region *R : Scope->FalseBiasedRegions) { + dbgs() << R->getNameStr() << ", "; + } + dbgs() << "\n"; + dbgs() << "TrueBiasedSelects "; + for (SelectInst *SI : Scope->TrueBiasedSelects) { + dbgs() << *SI << ", "; + } + dbgs() << "\n"; + dbgs() << "FalseBiasedSelects "; + for (SelectInst *SI : Scope->FalseBiasedSelects) { + dbgs() << *SI << ", "; + } + dbgs() << "\n";); + } +} + +void CHR::classifyBiasedScopes(CHRScope *Scope, CHRScope *OutermostScope) { + for (RegInfo &RI : Scope->RegInfos) { + if (RI.HasBranch) { + Region *R = RI.R; + if (TrueBiasedRegionsGlobal.count(R) > 0) + OutermostScope->TrueBiasedRegions.insert(R); + else if (FalseBiasedRegionsGlobal.count(R) > 0) + OutermostScope->FalseBiasedRegions.insert(R); + else + llvm_unreachable("Must be biased"); + } + for (SelectInst *SI : RI.Selects) { + if (TrueBiasedSelectsGlobal.count(SI) > 0) + OutermostScope->TrueBiasedSelects.insert(SI); + else if (FalseBiasedSelectsGlobal.count(SI) > 0) + OutermostScope->FalseBiasedSelects.insert(SI); + else + llvm_unreachable("Must be biased"); + } + } + for (CHRScope *Sub : Scope->Subs) { + classifyBiasedScopes(Sub, OutermostScope); + } +} + +static bool hasAtLeastTwoBiasedBranches(CHRScope *Scope) { + unsigned NumBiased = Scope->TrueBiasedRegions.size() + + Scope->FalseBiasedRegions.size() + + Scope->TrueBiasedSelects.size() + + Scope->FalseBiasedSelects.size(); + return NumBiased >= CHRMergeThreshold; +} + +void CHR::filterScopes(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output) { + for (CHRScope *Scope : Input) { + // Filter out the ones with only one region and no subs. 
+ if (!hasAtLeastTwoBiasedBranches(Scope)) { + CHR_DEBUG(dbgs() << "Filtered out by biased branches truthy-regions " + << Scope->TrueBiasedRegions.size() + << " falsy-regions " << Scope->FalseBiasedRegions.size() + << " true-selects " << Scope->TrueBiasedSelects.size() + << " false-selects " << Scope->FalseBiasedSelects.size() << "\n"); + ORE.emit([&]() { + return OptimizationRemarkMissed( + DEBUG_TYPE, + "DropScopeWithOneBranchOrSelect", + Scope->RegInfos[0].R->getEntry()->getTerminator()) + << "Drop scope with < " + << ore::NV("CHRMergeThreshold", CHRMergeThreshold) + << " biased branch(es) or select(s)"; + }); + continue; + } + Output.push_back(Scope); + } +} + +void CHR::setCHRRegions(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output) { + for (CHRScope *Scope : Input) { + assert(Scope->HoistStopMap.empty() && Scope->CHRRegions.empty() && + "Empty"); + setCHRRegions(Scope, Scope); + Output.push_back(Scope); + CHR_DEBUG( + dbgs() << "setCHRRegions HoistStopMap " << *Scope << "\n"; + for (auto pair : Scope->HoistStopMap) { + Region *R = pair.first; + dbgs() << "Region " << R->getNameStr() << "\n"; + for (Instruction *I : pair.second) { + dbgs() << "HoistStop " << *I << "\n"; + } + } + dbgs() << "CHRRegions" << "\n"; + for (RegInfo &RI : Scope->CHRRegions) { + dbgs() << RI.R->getNameStr() << "\n"; + }); + } +} + +void CHR::setCHRRegions(CHRScope *Scope, CHRScope *OutermostScope) { + DenseSet<Instruction *> Unhoistables; + // Put the biased selects in Unhoistables because they should stay where they + // are and constant-folded after CHR (in case one biased select or a branch + // can depend on another biased select.) + for (RegInfo &RI : Scope->RegInfos) { + for (SelectInst *SI : RI.Selects) { + Unhoistables.insert(SI); + } + } + Instruction *InsertPoint = OutermostScope->BranchInsertPoint; + for (RegInfo &RI : Scope->RegInfos) { + Region *R = RI.R; + DenseSet<Instruction *> HoistStops; + bool IsHoisted = false; + if (RI.HasBranch) { + assert((OutermostScope->TrueBiasedRegions.count(R) > 0 || + OutermostScope->FalseBiasedRegions.count(R) > 0) && + "Must be truthy or falsy"); + auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); + // Note checkHoistValue fills in HoistStops. + bool IsHoistable = checkHoistValue(BI->getCondition(), InsertPoint, DT, + Unhoistables, &HoistStops); + assert(IsHoistable && "Must be hoistable"); + (void)(IsHoistable); // Unused in release build + IsHoisted = true; + } + for (SelectInst *SI : RI.Selects) { + assert((OutermostScope->TrueBiasedSelects.count(SI) > 0 || + OutermostScope->FalseBiasedSelects.count(SI) > 0) && + "Must be true or false biased"); + // Note checkHoistValue fills in HoistStops. 
+ bool IsHoistable = checkHoistValue(SI->getCondition(), InsertPoint, DT, + Unhoistables, &HoistStops); + assert(IsHoistable && "Must be hoistable"); + (void)(IsHoistable); // Unused in release build + IsHoisted = true; + } + if (IsHoisted) { + OutermostScope->CHRRegions.push_back(RI); + OutermostScope->HoistStopMap[R] = HoistStops; + } + } + for (CHRScope *Sub : Scope->Subs) + setCHRRegions(Sub, OutermostScope); +} + +bool CHRScopeSorter(CHRScope *Scope1, CHRScope *Scope2) { + return Scope1->RegInfos[0].R->getDepth() < Scope2->RegInfos[0].R->getDepth(); +} + +void CHR::sortScopes(SmallVectorImpl<CHRScope *> &Input, + SmallVectorImpl<CHRScope *> &Output) { + Output.resize(Input.size()); + llvm::copy(Input, Output.begin()); + std::stable_sort(Output.begin(), Output.end(), CHRScopeSorter); +} + +// Return true if V is already hoisted or was hoisted (along with its operands) +// to the insert point. +static void hoistValue(Value *V, Instruction *HoistPoint, Region *R, + HoistStopMapTy &HoistStopMap, + DenseSet<Instruction *> &HoistedSet, + DenseSet<PHINode *> &TrivialPHIs) { + auto IT = HoistStopMap.find(R); + assert(IT != HoistStopMap.end() && "Region must be in hoist stop map"); + DenseSet<Instruction *> &HoistStops = IT->second; + if (auto *I = dyn_cast<Instruction>(V)) { + if (I == HoistPoint) + return; + if (HoistStops.count(I)) + return; + if (auto *PN = dyn_cast<PHINode>(I)) + if (TrivialPHIs.count(PN)) + // The trivial phi inserted by the previous CHR scope could replace a + // non-phi in HoistStops. Note that since this phi is at the exit of a + // previous CHR scope, which dominates this scope, it's safe to stop + // hoisting there. + return; + if (HoistedSet.count(I)) + // Already hoisted, return. + return; + assert(isHoistableInstructionType(I) && "Unhoistable instruction type"); + for (Value *Op : I->operands()) { + hoistValue(Op, HoistPoint, R, HoistStopMap, HoistedSet, TrivialPHIs); + } + I->moveBefore(HoistPoint); + HoistedSet.insert(I); + CHR_DEBUG(dbgs() << "hoistValue " << *I << "\n"); + } +} + +// Hoist the dependent condition values of the branches and the selects in the +// scope to the insert point. +static void hoistScopeConditions(CHRScope *Scope, Instruction *HoistPoint, + DenseSet<PHINode *> &TrivialPHIs) { + DenseSet<Instruction *> HoistedSet; + for (const RegInfo &RI : Scope->CHRRegions) { + Region *R = RI.R; + bool IsTrueBiased = Scope->TrueBiasedRegions.count(R); + bool IsFalseBiased = Scope->FalseBiasedRegions.count(R); + if (RI.HasBranch && (IsTrueBiased || IsFalseBiased)) { + auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); + hoistValue(BI->getCondition(), HoistPoint, R, Scope->HoistStopMap, + HoistedSet, TrivialPHIs); + } + for (SelectInst *SI : RI.Selects) { + bool IsTrueBiased = Scope->TrueBiasedSelects.count(SI); + bool IsFalseBiased = Scope->FalseBiasedSelects.count(SI); + if (!(IsTrueBiased || IsFalseBiased)) + continue; + hoistValue(SI->getCondition(), HoistPoint, R, Scope->HoistStopMap, + HoistedSet, TrivialPHIs); + } + } +} + +// Negate the predicate if an ICmp if it's used only by branches or selects by +// swapping the operands of the branches or the selects. Returns true if success. 
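The rewrite implemented below rests on a simple identity; a self-contained illustration with invented functions:

// Inverting a compare is sound as long as every user swaps accordingly:
//   cond ? a : b     ==  !cond ? b : a
//   br cond, T, F    ==  br !cond, F, T   (successors swapped)
int selectForm(int a, int b, int x, int y) {
  return (a < b) ? x : y;
}
int negatedForm(int a, int b, int x, int y) {
  return (a >= b) ? y : x; // inverse predicate, operands swapped; same result
}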
+static bool negateICmpIfUsedByBranchOrSelectOnly(ICmpInst *ICmp, + Instruction *ExcludedUser, + CHRScope *Scope) { + for (User *U : ICmp->users()) { + if (U == ExcludedUser) + continue; + if (isa<BranchInst>(U) && cast<BranchInst>(U)->isConditional()) + continue; + if (isa<SelectInst>(U) && cast<SelectInst>(U)->getCondition() == ICmp) + continue; + return false; + } + for (User *U : ICmp->users()) { + if (U == ExcludedUser) + continue; + if (auto *BI = dyn_cast<BranchInst>(U)) { + assert(BI->isConditional() && "Must be conditional"); + BI->swapSuccessors(); + // Don't need to swap this in terms of + // TrueBiasedRegions/FalseBiasedRegions because true-based/false-based + // mean whehter the branch is likely go into the if-then rather than + // successor0/successor1 and because we can tell which edge is the then or + // the else one by comparing the destination to the region exit block. + continue; + } + if (auto *SI = dyn_cast<SelectInst>(U)) { + // Swap operands + Value *TrueValue = SI->getTrueValue(); + Value *FalseValue = SI->getFalseValue(); + SI->setTrueValue(FalseValue); + SI->setFalseValue(TrueValue); + SI->swapProfMetadata(); + if (Scope->TrueBiasedSelects.count(SI)) { + assert(Scope->FalseBiasedSelects.count(SI) == 0 && + "Must not be already in"); + Scope->FalseBiasedSelects.insert(SI); + } else if (Scope->FalseBiasedSelects.count(SI)) { + assert(Scope->TrueBiasedSelects.count(SI) == 0 && + "Must not be already in"); + Scope->TrueBiasedSelects.insert(SI); + } + continue; + } + llvm_unreachable("Must be a branch or a select"); + } + ICmp->setPredicate(CmpInst::getInversePredicate(ICmp->getPredicate())); + return true; +} + +// A helper for transformScopes. Insert a trivial phi at the scope exit block +// for a value that's defined in the scope but used outside it (meaning it's +// alive at the exit block). +static void insertTrivialPHIs(CHRScope *Scope, + BasicBlock *EntryBlock, BasicBlock *ExitBlock, + DenseSet<PHINode *> &TrivialPHIs) { + DenseSet<BasicBlock *> BlocksInScopeSet; + SmallVector<BasicBlock *, 8> BlocksInScopeVec; + for (RegInfo &RI : Scope->RegInfos) { + for (BasicBlock *BB : RI.R->blocks()) { // This includes the blocks in the + // sub-Scopes. + BlocksInScopeSet.insert(BB); + BlocksInScopeVec.push_back(BB); + } + } + CHR_DEBUG( + dbgs() << "Inserting redudant phis\n"; + for (BasicBlock *BB : BlocksInScopeVec) { + dbgs() << "BlockInScope " << BB->getName() << "\n"; + }); + for (BasicBlock *BB : BlocksInScopeVec) { + for (Instruction &I : *BB) { + SmallVector<Instruction *, 8> Users; + for (User *U : I.users()) { + if (auto *UI = dyn_cast<Instruction>(U)) { + if (BlocksInScopeSet.count(UI->getParent()) == 0 && + // Unless there's already a phi for I at the exit block. + !(isa<PHINode>(UI) && UI->getParent() == ExitBlock)) { + CHR_DEBUG(dbgs() << "V " << I << "\n"); + CHR_DEBUG(dbgs() << "Used outside scope by user " << *UI << "\n"); + Users.push_back(UI); + } else if (UI->getParent() == EntryBlock && isa<PHINode>(UI)) { + // There's a loop backedge from a block that's dominated by this + // scope to the entry block. + CHR_DEBUG(dbgs() << "V " << I << "\n"); + CHR_DEBUG(dbgs() + << "Used at entry block (for a back edge) by a phi user " + << *UI << "\n"); + Users.push_back(UI); + } + } + } + if (Users.size() > 0) { + // Insert a trivial phi for I (phi [&I, P0], [&I, P1], ...) at + // ExitBlock. Replace I with the new phi in UI unless UI is another + // phi at ExitBlock. 
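Why a value that is live out of the scope needs a merge point at the exit once the scope is duplicated; a hedged, source-level picture (the code is invented):

// v is defined inside the scope and used after it. After CHR clones the
// scope into a hot path and a cold path, two definitions of v reach the use,
// and in SSA form they must be merged by a phi in the exit block.
int liveOutOfScope(bool takeHotPath, int x) {
  int v;
  if (takeHotPath)
    v = x + 1;  // definition in the hot (CHR) copy
  else
    v = x + 1;  // definition in the cold (.nonchr) clone
  return v * 2; // the single use after the scope sees the merged value
}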
+ unsigned PredCount = std::distance(pred_begin(ExitBlock), + pred_end(ExitBlock)); + PHINode *PN = PHINode::Create(I.getType(), PredCount, "", + &ExitBlock->front()); + for (BasicBlock *Pred : predecessors(ExitBlock)) { + PN->addIncoming(&I, Pred); + } + TrivialPHIs.insert(PN); + CHR_DEBUG(dbgs() << "Insert phi " << *PN << "\n"); + for (Instruction *UI : Users) { + for (unsigned J = 0, NumOps = UI->getNumOperands(); J < NumOps; ++J) { + if (UI->getOperand(J) == &I) { + UI->setOperand(J, PN); + } + } + CHR_DEBUG(dbgs() << "Updated user " << *UI << "\n"); + } + } + } + } +} + +// Assert that all the CHR regions of the scope have a biased branch or select. +static void LLVM_ATTRIBUTE_UNUSED +assertCHRRegionsHaveBiasedBranchOrSelect(CHRScope *Scope) { +#ifndef NDEBUG + auto HasBiasedBranchOrSelect = [](RegInfo &RI, CHRScope *Scope) { + if (Scope->TrueBiasedRegions.count(RI.R) || + Scope->FalseBiasedRegions.count(RI.R)) + return true; + for (SelectInst *SI : RI.Selects) + if (Scope->TrueBiasedSelects.count(SI) || + Scope->FalseBiasedSelects.count(SI)) + return true; + return false; + }; + for (RegInfo &RI : Scope->CHRRegions) { + assert(HasBiasedBranchOrSelect(RI, Scope) && + "Must have biased branch or select"); + } +#endif +} + +// Assert that all the condition values of the biased branches and selects have +// been hoisted to the pre-entry block or outside of the scope. +static void LLVM_ATTRIBUTE_UNUSED assertBranchOrSelectConditionHoisted( + CHRScope *Scope, BasicBlock *PreEntryBlock) { + CHR_DEBUG(dbgs() << "Biased regions condition values \n"); + for (RegInfo &RI : Scope->CHRRegions) { + Region *R = RI.R; + bool IsTrueBiased = Scope->TrueBiasedRegions.count(R); + bool IsFalseBiased = Scope->FalseBiasedRegions.count(R); + if (RI.HasBranch && (IsTrueBiased || IsFalseBiased)) { + auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); + Value *V = BI->getCondition(); + CHR_DEBUG(dbgs() << *V << "\n"); + if (auto *I = dyn_cast<Instruction>(V)) { + (void)(I); // Unused in release build. + assert((I->getParent() == PreEntryBlock || + !Scope->contains(I)) && + "Must have been hoisted to PreEntryBlock or outside the scope"); + } + } + for (SelectInst *SI : RI.Selects) { + bool IsTrueBiased = Scope->TrueBiasedSelects.count(SI); + bool IsFalseBiased = Scope->FalseBiasedSelects.count(SI); + if (!(IsTrueBiased || IsFalseBiased)) + continue; + Value *V = SI->getCondition(); + CHR_DEBUG(dbgs() << *V << "\n"); + if (auto *I = dyn_cast<Instruction>(V)) { + (void)(I); // Unused in release build. + assert((I->getParent() == PreEntryBlock || + !Scope->contains(I)) && + "Must have been hoisted to PreEntryBlock or outside the scope"); + } + } + } +} + +void CHR::transformScopes(CHRScope *Scope, DenseSet<PHINode *> &TrivialPHIs) { + CHR_DEBUG(dbgs() << "transformScopes " << *Scope << "\n"); + + assert(Scope->RegInfos.size() >= 1 && "Should have at least one Region"); + Region *FirstRegion = Scope->RegInfos[0].R; + BasicBlock *EntryBlock = FirstRegion->getEntry(); + Region *LastRegion = Scope->RegInfos[Scope->RegInfos.size() - 1].R; + BasicBlock *ExitBlock = LastRegion->getExit(); + Optional<uint64_t> ProfileCount = BFI.getBlockProfileCount(EntryBlock); + + if (ExitBlock) { + // Insert a trivial phi at the exit block (where the CHR hot path and the + // cold path merges) for a value that's defined in the scope but used + // outside it (meaning it's alive at the exit block). We will add the + // incoming values for the CHR cold paths to it below. 
Without this, we'd
+  // miss updating phi's for such values unless there happens to already be a
+  // phi for that value there.
+    insertTrivialPHIs(Scope, EntryBlock, ExitBlock, TrivialPHIs);
+  }
+
+  // Split the entry block of the first region. The new block becomes the new
+  // entry block of the first region. The old entry block becomes the block to
+  // insert the CHR branch into. Note that DT gets updated through the split,
+  // so we update the entry of the first region after the split. Because
+  // Region only points to its entry and exit blocks, rather than keeping
+  // everything in a list or set, the region's block membership and its
+  // entry/exit blocks are still valid after the split.
+  CHR_DEBUG(dbgs() << "Splitting entry block " << EntryBlock->getName()
+            << " at " << *Scope->BranchInsertPoint << "\n");
+  BasicBlock *NewEntryBlock =
+      SplitBlock(EntryBlock, Scope->BranchInsertPoint, &DT);
+  assert(NewEntryBlock->getSinglePredecessor() == EntryBlock &&
+         "NewEntryBlock's only pred must be EntryBlock");
+  FirstRegion->replaceEntryRecursive(NewEntryBlock);
+  BasicBlock *PreEntryBlock = EntryBlock;
+
+  ValueToValueMapTy VMap;
+  // Clone the blocks in the scope (excluding the PreEntryBlock) to split into
+  // a hot path (originals) and a cold path (clones) and update the PHIs at
+  // the exit block.
+  cloneScopeBlocks(Scope, PreEntryBlock, ExitBlock, LastRegion, VMap);
+
+  // Replace the old (placeholder) branch with the new (merged) conditional
+  // branch.
+  BranchInst *MergedBr = createMergedBranch(PreEntryBlock, EntryBlock,
+                                            NewEntryBlock, VMap);
+
+#ifndef NDEBUG
+  assertCHRRegionsHaveBiasedBranchOrSelect(Scope);
+#endif
+
+  // Hoist the conditional values of the branches/selects.
+  hoistScopeConditions(Scope, PreEntryBlock->getTerminator(), TrivialPHIs);
+
+#ifndef NDEBUG
+  assertBranchOrSelectConditionHoisted(Scope, PreEntryBlock);
+#endif
+
+  // Create the combined branch condition and constant-fold the
+  // branches/selects in the hot path.
+  fixupBranchesAndSelects(Scope, PreEntryBlock, MergedBr,
+                          ProfileCount ? ProfileCount.getValue() : 0);
+}
+
+// A helper for transformScopes. Clone the blocks in the scope (excluding the
+// PreEntryBlock) to split into a hot path and a cold path and update the PHIs
+// at the exit block.
+void CHR::cloneScopeBlocks(CHRScope *Scope,
+                           BasicBlock *PreEntryBlock,
+                           BasicBlock *ExitBlock,
+                           Region *LastRegion,
+                           ValueToValueMapTy &VMap) {
+  // Clone all the blocks. The original blocks will be the hot-path
+  // CHR-optimized code and the cloned blocks will be the original unoptimized
+  // code. This is so that the block pointers from the
+  // CHRScope/Region/RegionInfo can stay valid in pointing to the hot-path code
+  // which CHR should apply to.
+  SmallVector<BasicBlock*, 8> NewBlocks;
+  for (RegInfo &RI : Scope->RegInfos)
+    for (BasicBlock *BB : RI.R->blocks()) { // This includes the blocks in the
+                                            // sub-Scopes.
+      assert(BB != PreEntryBlock && "Don't copy the pre-entry block");
+      BasicBlock *NewBB = CloneBasicBlock(BB, VMap, ".nonchr", &F);
+      NewBlocks.push_back(NewBB);
+      VMap[BB] = NewBB;
+    }
+
+  // Place the cloned blocks right after the original blocks (right before the
+  // exit block).
+  if (ExitBlock)
+    F.getBasicBlockList().splice(ExitBlock->getIterator(),
+                                 F.getBasicBlockList(),
+                                 NewBlocks[0]->getIterator(), F.end());
+
+  // Update the cloned blocks/instructions to refer to themselves.
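To illustrate what this remapping achieves (names are hypothetical), a freshly cloned block still ends in something like

  br i1 %cond, label %if.then, label %if.end

which refers to the hot-path originals, and RemapInstruction rewrites it through VMap into roughly

  br i1 %cond.nonchr, label %if.then.nonchr, label %if.end.nonchr

so the cold-path clones form a self-contained copy of the scope instead of branching back into the hot path.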
+ for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) + for (Instruction &I : *NewBlocks[i]) + RemapInstruction(&I, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + + // Add the cloned blocks to the PHIs of the exit blocks. ExitBlock is null for + // the top-level region but we don't need to add PHIs. The trivial PHIs + // inserted above will be updated here. + if (ExitBlock) + for (PHINode &PN : ExitBlock->phis()) + for (unsigned I = 0, NumOps = PN.getNumIncomingValues(); I < NumOps; + ++I) { + BasicBlock *Pred = PN.getIncomingBlock(I); + if (LastRegion->contains(Pred)) { + Value *V = PN.getIncomingValue(I); + auto It = VMap.find(V); + if (It != VMap.end()) V = It->second; + assert(VMap.find(Pred) != VMap.end() && "Pred must have been cloned"); + PN.addIncoming(V, cast<BasicBlock>(VMap[Pred])); + } + } +} + +// A helper for transformScope. Replace the old (placeholder) branch with the +// new (merged) conditional branch. +BranchInst *CHR::createMergedBranch(BasicBlock *PreEntryBlock, + BasicBlock *EntryBlock, + BasicBlock *NewEntryBlock, + ValueToValueMapTy &VMap) { + BranchInst *OldBR = cast<BranchInst>(PreEntryBlock->getTerminator()); + assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == NewEntryBlock && + "SplitBlock did not work correctly!"); + assert(NewEntryBlock->getSinglePredecessor() == EntryBlock && + "NewEntryBlock's only pred must be EntryBlock"); + assert(VMap.find(NewEntryBlock) != VMap.end() && + "NewEntryBlock must have been copied"); + OldBR->dropAllReferences(); + OldBR->eraseFromParent(); + // The true predicate is a placeholder. It will be replaced later in + // fixupBranchesAndSelects(). + BranchInst *NewBR = BranchInst::Create(NewEntryBlock, + cast<BasicBlock>(VMap[NewEntryBlock]), + ConstantInt::getTrue(F.getContext())); + PreEntryBlock->getInstList().push_back(NewBR); + assert(NewEntryBlock->getSinglePredecessor() == EntryBlock && + "NewEntryBlock's only pred must be EntryBlock"); + return NewBR; +} + +// A helper for transformScopes. Create the combined branch condition and +// constant-fold the branches/selects in the hot path. 
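As a rough sketch of the end result (condition names, block names, and weights are illustrative only): for a scope with two true-biased branches guarded by %c1 and %c2, PreEntryBlock ends up terminating in approximately

  %chr.cond = and i1 %c1, %c2
  br i1 %chr.cond, label %entry.split, label %entry.split.nonchr, !prof !N

while each guarded branch inside the hot path has its condition replaced by a constant, so later cleanup can delete the cold edges that have become dead there.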
+void CHR::fixupBranchesAndSelects(CHRScope *Scope, + BasicBlock *PreEntryBlock, + BranchInst *MergedBR, + uint64_t ProfileCount) { + Value *MergedCondition = ConstantInt::getTrue(F.getContext()); + BranchProbability CHRBranchBias(1, 1); + uint64_t NumCHRedBranches = 0; + IRBuilder<> IRB(PreEntryBlock->getTerminator()); + for (RegInfo &RI : Scope->CHRRegions) { + Region *R = RI.R; + if (RI.HasBranch) { + fixupBranch(R, Scope, IRB, MergedCondition, CHRBranchBias); + ++NumCHRedBranches; + } + for (SelectInst *SI : RI.Selects) { + fixupSelect(SI, Scope, IRB, MergedCondition, CHRBranchBias); + ++NumCHRedBranches; + } + } + Stats.NumBranchesDelta += NumCHRedBranches - 1; + Stats.WeightedNumBranchesDelta += (NumCHRedBranches - 1) * ProfileCount; + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, + "CHR", + // Refer to the hot (original) path + MergedBR->getSuccessor(0)->getTerminator()) + << "Merged " << ore::NV("NumCHRedBranches", NumCHRedBranches) + << " branches or selects"; + }); + MergedBR->setCondition(MergedCondition); + SmallVector<uint32_t, 2> Weights; + Weights.push_back(static_cast<uint32_t>(CHRBranchBias.scale(1000))); + Weights.push_back(static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000))); + MDBuilder MDB(F.getContext()); + MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1] + << "\n"); +} + +// A helper for fixupBranchesAndSelects. Add to the combined branch condition +// and constant-fold a branch in the hot path. +void CHR::fixupBranch(Region *R, CHRScope *Scope, + IRBuilder<> &IRB, + Value *&MergedCondition, + BranchProbability &CHRBranchBias) { + bool IsTrueBiased = Scope->TrueBiasedRegions.count(R); + assert((IsTrueBiased || Scope->FalseBiasedRegions.count(R)) && + "Must be truthy or falsy"); + auto *BI = cast<BranchInst>(R->getEntry()->getTerminator()); + assert(BranchBiasMap.find(R) != BranchBiasMap.end() && + "Must be in the bias map"); + BranchProbability Bias = BranchBiasMap[R]; + assert(Bias >= getCHRBiasThreshold() && "Must be highly biased"); + // Take the min. + if (CHRBranchBias > Bias) + CHRBranchBias = Bias; + BasicBlock *IfThen = BI->getSuccessor(1); + BasicBlock *IfElse = BI->getSuccessor(0); + BasicBlock *RegionExitBlock = R->getExit(); + assert(RegionExitBlock && "Null ExitBlock"); + assert((IfThen == RegionExitBlock || IfElse == RegionExitBlock) && + IfThen != IfElse && "Invariant from findScopes"); + if (IfThen == RegionExitBlock) { + // Swap them so that IfThen means going into it and IfElse means skipping + // it. + std::swap(IfThen, IfElse); + } + CHR_DEBUG(dbgs() << "IfThen " << IfThen->getName() + << " IfElse " << IfElse->getName() << "\n"); + Value *Cond = BI->getCondition(); + BasicBlock *HotTarget = IsTrueBiased ? IfThen : IfElse; + bool ConditionTrue = HotTarget == BI->getSuccessor(0); + addToMergedCondition(ConditionTrue, Cond, BI, Scope, IRB, + MergedCondition); + // Constant-fold the branch at ClonedEntryBlock. + assert(ConditionTrue == (HotTarget == BI->getSuccessor(0)) && + "The successor shouldn't change"); + Value *NewCondition = ConditionTrue ? + ConstantInt::getTrue(F.getContext()) : + ConstantInt::getFalse(F.getContext()); + BI->setCondition(NewCondition); +} + +// A helper for fixupBranchesAndSelects. Add to the combined branch condition +// and constant-fold a select in the hot path. 
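A sketch of the select case (operand names are illustrative): a true-biased

  %x = select i1 %c, i32 %a, i32 %b

in the hot path becomes select i1 true, i32 %a, i32 %b, which folds to %a, and %c is ANDed into the merged branch condition; a false-biased select instead gets the constant false, and %c is negated first, by inverting its icmp predicate when that is possible or with an xor otherwise, before being merged.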
+void CHR::fixupSelect(SelectInst *SI, CHRScope *Scope,
+                      IRBuilder<> &IRB,
+                      Value *&MergedCondition,
+                      BranchProbability &CHRBranchBias) {
+  bool IsTrueBiased = Scope->TrueBiasedSelects.count(SI);
+  assert((IsTrueBiased ||
+          Scope->FalseBiasedSelects.count(SI)) && "Must be biased");
+  assert(SelectBiasMap.find(SI) != SelectBiasMap.end() &&
+         "Must be in the bias map");
+  BranchProbability Bias = SelectBiasMap[SI];
+  assert(Bias >= getCHRBiasThreshold() && "Must be highly biased");
+  // Take the min.
+  if (CHRBranchBias > Bias)
+    CHRBranchBias = Bias;
+  Value *Cond = SI->getCondition();
+  addToMergedCondition(IsTrueBiased, Cond, SI, Scope, IRB,
+                       MergedCondition);
+  Value *NewCondition = IsTrueBiased ?
+      ConstantInt::getTrue(F.getContext()) :
+      ConstantInt::getFalse(F.getContext());
+  SI->setCondition(NewCondition);
+}
+
+// A helper for fixupBranch/fixupSelect. Add a branch condition to the merged
+// condition.
+void CHR::addToMergedCondition(bool IsTrueBiased, Value *Cond,
+                               Instruction *BranchOrSelect,
+                               CHRScope *Scope,
+                               IRBuilder<> &IRB,
+                               Value *&MergedCondition) {
+  if (IsTrueBiased) {
+    MergedCondition = IRB.CreateAnd(MergedCondition, Cond);
+  } else {
+    // If Cond is an icmp and all of its users other than BranchOrSelect are
+    // conditional branches or selects on it, negate the icmp predicate and
+    // swap the branch targets (or select operands) to avoid inserting an Xor
+    // to negate Cond.
+    bool Done = false;
+    if (auto *ICmp = dyn_cast<ICmpInst>(Cond))
+      if (negateICmpIfUsedByBranchOrSelectOnly(ICmp, BranchOrSelect, Scope)) {
+        MergedCondition = IRB.CreateAnd(MergedCondition, Cond);
+        Done = true;
+      }
+    if (!Done) {
+      Value *Negate = IRB.CreateXor(
+          ConstantInt::getTrue(F.getContext()), Cond);
+      MergedCondition = IRB.CreateAnd(MergedCondition, Negate);
+    }
+  }
+}
+
+void CHR::transformScopes(SmallVectorImpl<CHRScope *> &CHRScopes) {
+  unsigned I = 0;
+  DenseSet<PHINode *> TrivialPHIs;
+  for (CHRScope *Scope : CHRScopes) {
+    transformScopes(Scope, TrivialPHIs);
+    CHR_DEBUG(
+        std::ostringstream oss;
+        oss << " after transformScopes " << I++;
+        dumpIR(F, oss.str().c_str(), nullptr));
+    (void)I;
+  }
+}
+
+static void LLVM_ATTRIBUTE_UNUSED
+dumpScopes(SmallVectorImpl<CHRScope *> &Scopes, const char *Label) {
+  dbgs() << Label << " " << Scopes.size() << "\n";
+  for (CHRScope *Scope : Scopes) {
+    dbgs() << *Scope << "\n";
+  }
+}
+
+bool CHR::run() {
+  if (!shouldApply(F, PSI))
+    return false;
+
+  CHR_DEBUG(dumpIR(F, "before", nullptr));
+
+  bool Changed = false;
+  {
+    CHR_DEBUG(
+        dbgs() << "RegionInfo:\n";
+        RI.print(dbgs()));
+
+    // Recursively traverse the region tree and find regions that have biased
+    // branches and/or selects and create scopes.
+    SmallVector<CHRScope *, 8> AllScopes;
+    findScopes(AllScopes);
+    CHR_DEBUG(dumpScopes(AllScopes, "All scopes"));
+
+    // Split the scopes if 1) the condition values of the biased
+    // branches/selects of the inner/lower scope can't be hoisted up to the
+    // outermost/uppermost scope entry, or 2) the condition values of the
+    // biased branches/selects in a scope (including subscopes) don't share at
+    // least one common value.
+    SmallVector<CHRScope *, 8> SplitScopes;
+    splitScopes(AllScopes, SplitScopes);
+    CHR_DEBUG(dumpScopes(SplitScopes, "Split scopes"));
+
+    // After splitting, set the biased regions and selects of a scope (a tree
+    // root) that include those of the subscopes.
+    classifyBiasedScopes(SplitScopes);
+    CHR_DEBUG(dbgs() << "Set per-scope bias " << SplitScopes.size() << "\n");
+
+    // Filter out the scopes that have only one biased region or select (CHR
+    // isn't useful in such a case).
+    SmallVector<CHRScope *, 8> FilteredScopes;
+    filterScopes(SplitScopes, FilteredScopes);
+    CHR_DEBUG(dumpScopes(FilteredScopes, "Filtered scopes"));
+
+    // Set the regions to be CHR'ed and their hoist stops for each scope.
+    SmallVector<CHRScope *, 8> SetScopes;
+    setCHRRegions(FilteredScopes, SetScopes);
+    CHR_DEBUG(dumpScopes(SetScopes, "Set CHR regions"));
+
+    // Sort CHRScopes by the depth so that outer CHRScopes come before inner
+    // ones. We need to apply CHR from outer to inner so that we apply CHR only
+    // to the hot path, rather than both hot and cold paths.
+    SmallVector<CHRScope *, 8> SortedScopes;
+    sortScopes(SetScopes, SortedScopes);
+    CHR_DEBUG(dumpScopes(SortedScopes, "Sorted scopes"));
+
+    CHR_DEBUG(
+        dbgs() << "RegionInfo:\n";
+        RI.print(dbgs()));
+
+    // Apply the CHR transformation.
+    if (!SortedScopes.empty()) {
+      transformScopes(SortedScopes);
+      Changed = true;
+    }
+  }
+
+  if (Changed) {
+    CHR_DEBUG(dumpIR(F, "after", &Stats));
+    ORE.emit([&]() {
+      return OptimizationRemark(DEBUG_TYPE, "Stats", &F)
+          << ore::NV("Function", &F) << " "
+          << "Reduced the number of branches in hot paths by "
+          << ore::NV("NumBranchesDelta", Stats.NumBranchesDelta)
+          << " (static) and "
+          << ore::NV("WeightedNumBranchesDelta", Stats.WeightedNumBranchesDelta)
+          << " (weighted by PGO count)";
+    });
+  }
+
+  return Changed;
+}
+
+bool ControlHeightReductionLegacyPass::runOnFunction(Function &F) {
+  BlockFrequencyInfo &BFI =
+      getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI();
+  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  ProfileSummaryInfo &PSI =
+      getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
+  RegionInfo &RI = getAnalysis<RegionInfoPass>().getRegionInfo();
+  std::unique_ptr<OptimizationRemarkEmitter> OwnedORE =
+      llvm::make_unique<OptimizationRemarkEmitter>(&F);
+  return CHR(F, BFI, DT, PSI, RI, *OwnedORE.get()).run();
+}
+
+namespace llvm {
+
+ControlHeightReductionPass::ControlHeightReductionPass() {
+  parseCHRFilterFiles();
+}
+
+PreservedAnalyses ControlHeightReductionPass::run(
+    Function &F,
+    FunctionAnalysisManager &FAM) {
+  auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
+  auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+  auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+  auto &MAM = MAMProxy.getManager();
+  auto &PSI = *MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+  auto &RI = FAM.getResult<RegionInfoAnalysis>(F);
+  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+  bool Changed = CHR(F, BFI, DT, PSI, RI, ORE).run();
+  if (!Changed)
+    return PreservedAnalyses::all();
+  auto PA = PreservedAnalyses();
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
+
+} // namespace llvm
diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index bb0e4379d1a8..4c3c6c9added 100644
--- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -231,17 +231,17 @@ struct TransformedFunction {
   TransformedFunction& operator=(TransformedFunction&&) = default;
 
   /// Type of the function before the transformation.
-  FunctionType* const OriginalType;
+  FunctionType *OriginalType;
 
   /// Type of the function after the transformation.
- FunctionType* const TransformedType; + FunctionType *TransformedType; /// Transforming a function may change the position of arguments. This /// member records the mapping from each argument's old position to its new /// position. Argument positions are zero-indexed. If the transformation /// from F to F' made the first argument of F into the third argument of F', /// then ArgumentIndexMapping[0] will equal 2. - const std::vector<unsigned> ArgumentIndexMapping; + std::vector<unsigned> ArgumentIndexMapping; }; /// Given function attributes from a call site for the original function, @@ -645,8 +645,8 @@ DataFlowSanitizer::buildWrapperFunction(Function *F, StringRef NewFName, GlobalValue::LinkageTypes NewFLink, FunctionType *NewFT) { FunctionType *FT = F->getFunctionType(); - Function *NewF = Function::Create(NewFT, NewFLink, NewFName, - F->getParent()); + Function *NewF = Function::Create(NewFT, NewFLink, F->getAddressSpace(), + NewFName, F->getParent()); NewF->copyAttributesFrom(F); NewF->removeAttributes( AttributeList::ReturnIndex, @@ -819,7 +819,8 @@ bool DataFlowSanitizer::runOnModule(Module &M) { // easily identify cases of mismatching ABIs. if (getInstrumentedABI() == IA_Args && !IsZeroArgsVoidRet) { FunctionType *NewFT = getArgsFunctionType(FT); - Function *NewF = Function::Create(NewFT, F.getLinkage(), "", &M); + Function *NewF = Function::Create(NewFT, F.getLinkage(), + F.getAddressSpace(), "", &M); NewF->copyAttributesFrom(&F); NewF->removeAttributes( AttributeList::ReturnIndex, @@ -924,7 +925,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) { Instruction *Next = Inst->getNextNode(); // DFSanVisitor may delete Inst, so keep track of whether it was a // terminator. - bool IsTerminator = isa<TerminatorInst>(Inst); + bool IsTerminator = Inst->isTerminator(); if (!DFSF.SkipInsts.count(Inst)) DFSanVisitor(DFSF).visit(Inst); if (IsTerminator) diff --git a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp index 33f220a893df..db438e78ded9 100644 --- a/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp +++ b/lib/Transforms/Instrumentation/EfficiencySanitizer.cpp @@ -144,21 +144,6 @@ OverrideOptionsFromCL(EfficiencySanitizerOptions Options) { return Options; } -// Create a constant for Str so that we can pass it to the run-time lib. -static GlobalVariable *createPrivateGlobalForString(Module &M, StringRef Str, - bool AllowMerging) { - Constant *StrConst = ConstantDataArray::getString(M.getContext(), Str); - // We use private linkage for module-local strings. If they can be merged - // with another one, we set the unnamed_addr attribute. - GlobalVariable *GV = - new GlobalVariable(M, StrConst->getType(), true, - GlobalValue::PrivateLinkage, StrConst, ""); - if (AllowMerging) - GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - GV->setAlignment(1); // Strings may not be merged w/o setting align 1. - return GV; -} - /// EfficiencySanitizer: instrument each module to find performance issues. class EfficiencySanitizer : public ModulePass { public: @@ -902,7 +887,7 @@ bool EfficiencySanitizer::instrumentFastpathWorkingSet( Value *OldValue = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowPtr, ShadowPtrTy)); // The AND and CMP will be turned into a TEST instruction by the compiler. 
Value *Cmp = IRB.CreateICmpNE(IRB.CreateAnd(OldValue, ValueMask), ValueMask); - TerminatorInst *CmpTerm = SplitBlockAndInsertIfThen(Cmp, I, false); + Instruction *CmpTerm = SplitBlockAndInsertIfThen(Cmp, I, false); // FIXME: do I need to call SetCurrentDebugLocation? IRB.SetInsertPoint(CmpTerm); // We use OR to set the shadow bits to avoid corrupting the middle 6 bits, diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 132e8089fe3b..9af64ed332cd 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -21,9 +21,9 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" -#include "llvm/ADT/UniqueVector.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/IRBuilder.h" @@ -36,6 +36,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/Path.h" +#include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/GCOVProfiler.h" @@ -96,30 +97,25 @@ private: // profiling runtime to emit .gcda files when run. bool emitProfileArcs(); + bool isFunctionInstrumented(const Function &F); + std::vector<Regex> createRegexesFromString(StringRef RegexesStr); + static bool doesFilenameMatchARegex(StringRef Filename, + std::vector<Regex> &Regexes); + // Get pointers to the functions in the runtime library. Constant *getStartFileFunc(); - Constant *getIncrementIndirectCounterFunc(); Constant *getEmitFunctionFunc(); Constant *getEmitArcsFunc(); Constant *getSummaryInfoFunc(); Constant *getEndFileFunc(); - // Create or retrieve an i32 state value that is used to represent the - // pred block number for certain non-trivial edges. - GlobalVariable *getEdgeStateValue(); - - // Produce a table of pointers to counters, by predecessor and successor - // block number. - GlobalVariable *buildEdgeLookupTable(Function *F, GlobalVariable *Counter, - const UniqueVector<BasicBlock *> &Preds, - const UniqueVector<BasicBlock *> &Succs); - // Add the function to write out all our counters to the global destructor // list. Function * insertCounterWriteout(ArrayRef<std::pair<GlobalVariable *, MDNode *>>); Function *insertFlush(ArrayRef<std::pair<GlobalVariable *, MDNode *>>); - void insertIndirectCounterIncrement(); + + void AddFlushBeforeForkAndExec(); enum class GCovFileType { GCNO, GCDA }; std::string mangleName(const DICompileUnit *CU, GCovFileType FileType); @@ -135,6 +131,9 @@ private: const TargetLibraryInfo *TLI; LLVMContext *Ctx; SmallVector<std::unique_ptr<GCOVFunction>, 16> Funcs; + std::vector<Regex> FilterRe; + std::vector<Regex> ExcludeRe; + StringMap<bool> InstrumentedFiles; }; class GCOVProfilerLegacyPass : public ModulePass { @@ -181,6 +180,21 @@ static StringRef getFunctionName(const DISubprogram *SP) { return SP->getName(); } +/// Extract a filename for a DISubprogram. +/// +/// Prefer relative paths in the coverage notes. Clang also may split +/// up absolute paths into a directory and filename component. When +/// the relative path doesn't exist, reconstruct the absolute path. 
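For illustration (these paths are hypothetical): given a DISubprogram whose directory is "/src/proj" and whose filename is "lib/Foo.cpp", the helper below returns

  "lib/Foo.cpp"            when ./lib/Foo.cpp exists relative to the current directory
  "/src/proj/lib/Foo.cpp"  otherwise (directory and filename joined via sys::path::append)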
+static SmallString<128> getFilename(const DISubprogram *SP) { + SmallString<128> Path; + StringRef RelPath = SP->getFilename(); + if (sys::fs::exists(RelPath)) + Path = RelPath; + else + sys::path::append(Path, SP->getDirectory(), SP->getFilename()); + return Path; +} + namespace { class GCOVRecord { protected: @@ -257,7 +271,7 @@ namespace { } private: - StringRef Filename; + std::string Filename; SmallVector<uint32_t, 32> Lines; }; @@ -287,11 +301,10 @@ namespace { write(Len); write(Number); - llvm::sort( - SortedLinesByFile.begin(), SortedLinesByFile.end(), - [](StringMapEntry<GCOVLines> *LHS, StringMapEntry<GCOVLines> *RHS) { - return LHS->getKey() < RHS->getKey(); - }); + llvm::sort(SortedLinesByFile, [](StringMapEntry<GCOVLines> *LHS, + StringMapEntry<GCOVLines> *RHS) { + return LHS->getKey() < RHS->getKey(); + }); for (auto &I : SortedLinesByFile) I->getValue().writeOut(); write(0); @@ -379,8 +392,9 @@ namespace { void writeOut() { writeBytes(FunctionTag, 4); + SmallString<128> Filename = getFilename(SP); uint32_t BlockLen = 1 + 1 + 1 + lengthOfGCOVString(getFunctionName(SP)) + - 1 + lengthOfGCOVString(SP->getFilename()) + 1; + 1 + lengthOfGCOVString(Filename) + 1; if (UseCfgChecksum) ++BlockLen; write(BlockLen); @@ -389,7 +403,7 @@ namespace { if (UseCfgChecksum) write(CfgChecksum); writeGCOVString(getFunctionName(SP)); - writeGCOVString(SP->getFilename()); + writeGCOVString(Filename); write(SP->getLine()); // Emit count of blocks. @@ -434,6 +448,72 @@ namespace { }; } +// RegexesStr is a string containing differents regex separated by a semi-colon. +// For example "foo\..*$;bar\..*$". +std::vector<Regex> GCOVProfiler::createRegexesFromString(StringRef RegexesStr) { + std::vector<Regex> Regexes; + while (!RegexesStr.empty()) { + std::pair<StringRef, StringRef> HeadTail = RegexesStr.split(';'); + if (!HeadTail.first.empty()) { + Regex Re(HeadTail.first); + std::string Err; + if (!Re.isValid(Err)) { + Ctx->emitError(Twine("Regex ") + HeadTail.first + + " is not valid: " + Err); + } + Regexes.emplace_back(std::move(Re)); + } + RegexesStr = HeadTail.second; + } + return Regexes; +} + +bool GCOVProfiler::doesFilenameMatchARegex(StringRef Filename, + std::vector<Regex> &Regexes) { + for (Regex &Re : Regexes) { + if (Re.match(Filename)) { + return true; + } + } + return false; +} + +bool GCOVProfiler::isFunctionInstrumented(const Function &F) { + if (FilterRe.empty() && ExcludeRe.empty()) { + return true; + } + SmallString<128> Filename = getFilename(F.getSubprogram()); + auto It = InstrumentedFiles.find(Filename); + if (It != InstrumentedFiles.end()) { + return It->second; + } + + SmallString<256> RealPath; + StringRef RealFilename; + + // Path can be + // /usr/lib/gcc/x86_64-linux-gnu/8/../../../../include/c++/8/bits/*.h so for + // such a case we must get the real_path. + if (sys::fs::real_path(Filename, RealPath)) { + // real_path can fail with path like "foo.c". 
+ RealFilename = Filename; + } else { + RealFilename = RealPath; + } + + bool ShouldInstrument; + if (FilterRe.empty()) { + ShouldInstrument = !doesFilenameMatchARegex(RealFilename, ExcludeRe); + } else if (ExcludeRe.empty()) { + ShouldInstrument = doesFilenameMatchARegex(RealFilename, FilterRe); + } else { + ShouldInstrument = doesFilenameMatchARegex(RealFilename, FilterRe) && + !doesFilenameMatchARegex(RealFilename, ExcludeRe); + } + InstrumentedFiles[Filename] = ShouldInstrument; + return ShouldInstrument; +} + std::string GCOVProfiler::mangleName(const DICompileUnit *CU, GCovFileType OutputType) { bool Notes = OutputType == GCovFileType::GCNO; @@ -481,6 +561,11 @@ bool GCOVProfiler::runOnModule(Module &M, const TargetLibraryInfo &TLI) { this->TLI = &TLI; Ctx = &M.getContext(); + AddFlushBeforeForkAndExec(); + + FilterRe = createRegexesFromString(Options.Filter); + ExcludeRe = createRegexesFromString(Options.Exclude); + if (Options.EmitNotes) emitProfileNotes(); if (Options.EmitData) return emitProfileArcs(); return false; @@ -537,6 +622,38 @@ static bool shouldKeepInEntry(BasicBlock::iterator It) { return false; } +void GCOVProfiler::AddFlushBeforeForkAndExec() { + SmallVector<Instruction *, 2> ForkAndExecs; + for (auto &F : M->functions()) { + for (auto &I : instructions(F)) { + if (CallInst *CI = dyn_cast<CallInst>(&I)) { + if (Function *Callee = CI->getCalledFunction()) { + LibFunc LF; + if (TLI->getLibFunc(*Callee, LF) && + (LF == LibFunc_fork || LF == LibFunc_execl || + LF == LibFunc_execle || LF == LibFunc_execlp || + LF == LibFunc_execv || LF == LibFunc_execvp || + LF == LibFunc_execve || LF == LibFunc_execvpe || + LF == LibFunc_execvP)) { + ForkAndExecs.push_back(&I); + } + } + } + } + } + + // We need to split the block after the fork/exec call + // because else the counters for the lines after will be + // the same as before the call. + for (auto I : ForkAndExecs) { + IRBuilder<> Builder(I); + FunctionType *FTy = FunctionType::get(Builder.getVoidTy(), {}, false); + Constant *GCOVFlush = M->getOrInsertFunction("__gcov_flush", FTy); + Builder.CreateCall(GCOVFlush); + I->getParent()->splitBasicBlock(I); + } +} + void GCOVProfiler::emitProfileNotes() { NamedMDNode *CU_Nodes = M->getNamedMetadata("llvm.dbg.cu"); if (!CU_Nodes) return; @@ -566,7 +683,8 @@ void GCOVProfiler::emitProfileNotes() { for (auto &F : M->functions()) { DISubprogram *SP = F.getSubprogram(); if (!SP) continue; - if (!functionHasLines(F)) continue; + if (!functionHasLines(F) || !isFunctionInstrumented(F)) + continue; // TODO: Functions using scope-based EH are currently not supported. if (isUsingScopeBasedEH(F)) continue; @@ -583,9 +701,15 @@ void GCOVProfiler::emitProfileNotes() { Options.ExitBlockBeforeBody)); GCOVFunction &Func = *Funcs.back(); + // Add the function line number to the lines of the entry block + // to have a counter for the function definition. 
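For instance (hypothetical source, to show the intent):

  42: int f(int x) {
  43:   return g(x);
  44: }

The entry block of f now records line 42, so the definition line gets a nonzero count in the coverage output even when the first debug location inside the block belongs to line 43.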
+ uint32_t Line = SP->getLine(); + auto Filename = getFilename(SP); + Func.getBlock(&EntryBlock).getFile(Filename).addLine(Line); + for (auto &BB : F) { GCOVBlock &Block = Func.getBlock(&BB); - TerminatorInst *TI = BB.getTerminator(); + Instruction *TI = BB.getTerminator(); if (int successors = TI->getNumSuccessors()) { for (int i = 0; i != successors; ++i) { Block.addEdge(Func.getBlock(TI->getSuccessor(i))); @@ -594,7 +718,6 @@ void GCOVProfiler::emitProfileNotes() { Block.addEdge(Func.getReturnBlock()); } - uint32_t Line = 0; for (auto &I : BB) { // Debug intrinsic locations correspond to the location of the // declaration, not necessarily any statements or expressions. @@ -605,16 +728,18 @@ void GCOVProfiler::emitProfileNotes() { continue; // Artificial lines such as calls to the global constructors. - if (Loc.getLine() == 0) continue; + if (Loc.getLine() == 0 || Loc.isImplicitCode()) + continue; if (Line == Loc.getLine()) continue; Line = Loc.getLine(); if (SP != getDISubprogram(Loc.getScope())) continue; - GCOVLines &Lines = Block.getFile(SP->getFilename()); + GCOVLines &Lines = Block.getFile(Filename); Lines.addLine(Loc.getLine()); } + Line = 0; } EdgeDestinations += Func.getEdgeDestinations(); } @@ -639,24 +764,28 @@ bool GCOVProfiler::emitProfileArcs() { if (!CU_Nodes) return false; bool Result = false; - bool InsertIndCounterIncrCode = false; for (unsigned i = 0, e = CU_Nodes->getNumOperands(); i != e; ++i) { SmallVector<std::pair<GlobalVariable *, MDNode *>, 8> CountersBySP; for (auto &F : M->functions()) { DISubprogram *SP = F.getSubprogram(); if (!SP) continue; - if (!functionHasLines(F)) continue; + if (!functionHasLines(F) || !isFunctionInstrumented(F)) + continue; // TODO: Functions using scope-based EH are currently not supported. if (isUsingScopeBasedEH(F)) continue; if (!Result) Result = true; + DenseMap<std::pair<BasicBlock *, BasicBlock *>, unsigned> EdgeToCounter; unsigned Edges = 0; for (auto &BB : F) { - TerminatorInst *TI = BB.getTerminator(); - if (isa<ReturnInst>(TI)) - ++Edges; - else - Edges += TI->getNumSuccessors(); + Instruction *TI = BB.getTerminator(); + if (isa<ReturnInst>(TI)) { + EdgeToCounter[{&BB, nullptr}] = Edges++; + } else { + for (BasicBlock *Succ : successors(TI)) { + EdgeToCounter[{&BB, Succ}] = Edges++; + } + } } ArrayType *CounterTy = @@ -668,63 +797,42 @@ bool GCOVProfiler::emitProfileArcs() { "__llvm_gcov_ctr"); CountersBySP.push_back(std::make_pair(Counters, SP)); - UniqueVector<BasicBlock *> ComplexEdgePreds; - UniqueVector<BasicBlock *> ComplexEdgeSuccs; - - unsigned Edge = 0; + // If a BB has several predecessors, use a PHINode to select + // the correct counter. for (auto &BB : F) { - TerminatorInst *TI = BB.getTerminator(); - int Successors = isa<ReturnInst>(TI) ? 
1 : TI->getNumSuccessors(); - if (Successors) { - if (Successors == 1) { - IRBuilder<> Builder(&*BB.getFirstInsertionPt()); - Value *Counter = Builder.CreateConstInBoundsGEP2_64(Counters, 0, - Edge); - Value *Count = Builder.CreateLoad(Counter); - Count = Builder.CreateAdd(Count, Builder.getInt64(1)); - Builder.CreateStore(Count, Counter); - } else if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { - IRBuilder<> Builder(BI); - Value *Sel = Builder.CreateSelect(BI->getCondition(), - Builder.getInt64(Edge), - Builder.getInt64(Edge + 1)); - Value *Counter = Builder.CreateInBoundsGEP( - Counters->getValueType(), Counters, {Builder.getInt64(0), Sel}); + const unsigned EdgeCount = + std::distance(pred_begin(&BB), pred_end(&BB)); + if (EdgeCount) { + // The phi node must be at the begin of the BB. + IRBuilder<> BuilderForPhi(&*BB.begin()); + Type *Int64PtrTy = Type::getInt64PtrTy(*Ctx); + PHINode *Phi = BuilderForPhi.CreatePHI(Int64PtrTy, EdgeCount); + for (BasicBlock *Pred : predecessors(&BB)) { + auto It = EdgeToCounter.find({Pred, &BB}); + assert(It != EdgeToCounter.end()); + const unsigned Edge = It->second; + Value *EdgeCounter = + BuilderForPhi.CreateConstInBoundsGEP2_64(Counters, 0, Edge); + Phi->addIncoming(EdgeCounter, Pred); + } + + // Skip phis, landingpads. + IRBuilder<> Builder(&*BB.getFirstInsertionPt()); + Value *Count = Builder.CreateLoad(Phi); + Count = Builder.CreateAdd(Count, Builder.getInt64(1)); + Builder.CreateStore(Count, Phi); + + Instruction *TI = BB.getTerminator(); + if (isa<ReturnInst>(TI)) { + auto It = EdgeToCounter.find({&BB, nullptr}); + assert(It != EdgeToCounter.end()); + const unsigned Edge = It->second; + Value *Counter = + Builder.CreateConstInBoundsGEP2_64(Counters, 0, Edge); Value *Count = Builder.CreateLoad(Counter); Count = Builder.CreateAdd(Count, Builder.getInt64(1)); Builder.CreateStore(Count, Counter); - } else { - ComplexEdgePreds.insert(&BB); - for (int i = 0; i != Successors; ++i) - ComplexEdgeSuccs.insert(TI->getSuccessor(i)); } - - Edge += Successors; - } - } - - if (!ComplexEdgePreds.empty()) { - GlobalVariable *EdgeTable = - buildEdgeLookupTable(&F, Counters, - ComplexEdgePreds, ComplexEdgeSuccs); - GlobalVariable *EdgeState = getEdgeStateValue(); - - for (int i = 0, e = ComplexEdgePreds.size(); i != e; ++i) { - IRBuilder<> Builder(&*ComplexEdgePreds[i + 1]->getFirstInsertionPt()); - Builder.CreateStore(Builder.getInt32(i), EdgeState); - } - - for (int i = 0, e = ComplexEdgeSuccs.size(); i != e; ++i) { - // Call runtime to perform increment. - IRBuilder<> Builder(&*ComplexEdgeSuccs[i + 1]->getFirstInsertionPt()); - Value *CounterPtrArray = - Builder.CreateConstInBoundsGEP2_64(EdgeTable, 0, - i * ComplexEdgePreds.size()); - - // Build code to increment the counter. - InsertIndCounterIncrCode = true; - Builder.CreateCall(getIncrementIndirectCounterFunc(), - {EdgeState, CounterPtrArray}); } } } @@ -763,60 +871,9 @@ bool GCOVProfiler::emitProfileArcs() { appendToGlobalCtors(*M, F, 0); } - if (InsertIndCounterIncrCode) - insertIndirectCounterIncrement(); - return Result; } -// All edges with successors that aren't branches are "complex", because it -// requires complex logic to pick which counter to update. -GlobalVariable *GCOVProfiler::buildEdgeLookupTable( - Function *F, - GlobalVariable *Counters, - const UniqueVector<BasicBlock *> &Preds, - const UniqueVector<BasicBlock *> &Succs) { - // TODO: support invoke, threads. 
We rely on the fact that nothing can modify - // the whole-Module pred edge# between the time we set it and the time we next - // read it. Threads and invoke make this untrue. - - // emit [(succs * preds) x i64*], logically [succ x [pred x i64*]]. - size_t TableSize = Succs.size() * Preds.size(); - Type *Int64PtrTy = Type::getInt64PtrTy(*Ctx); - ArrayType *EdgeTableTy = ArrayType::get(Int64PtrTy, TableSize); - - std::unique_ptr<Constant * []> EdgeTable(new Constant *[TableSize]); - Constant *NullValue = Constant::getNullValue(Int64PtrTy); - for (size_t i = 0; i != TableSize; ++i) - EdgeTable[i] = NullValue; - - unsigned Edge = 0; - for (BasicBlock &BB : *F) { - TerminatorInst *TI = BB.getTerminator(); - int Successors = isa<ReturnInst>(TI) ? 1 : TI->getNumSuccessors(); - if (Successors > 1 && !isa<BranchInst>(TI) && !isa<ReturnInst>(TI)) { - for (int i = 0; i != Successors; ++i) { - BasicBlock *Succ = TI->getSuccessor(i); - IRBuilder<> Builder(Succ); - Value *Counter = Builder.CreateConstInBoundsGEP2_64(Counters, 0, - Edge + i); - EdgeTable[((Succs.idFor(Succ) - 1) * Preds.size()) + - (Preds.idFor(&BB) - 1)] = cast<Constant>(Counter); - } - } - Edge += Successors; - } - - GlobalVariable *EdgeTableGV = - new GlobalVariable( - *M, EdgeTableTy, true, GlobalValue::InternalLinkage, - ConstantArray::get(EdgeTableTy, - makeArrayRef(&EdgeTable[0],TableSize)), - "__llvm_gcda_edge_table"); - EdgeTableGV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - return EdgeTableGV; -} - Constant *GCOVProfiler::getStartFileFunc() { Type *Args[] = { Type::getInt8PtrTy(*Ctx), // const char *orig_filename @@ -832,17 +889,6 @@ Constant *GCOVProfiler::getStartFileFunc() { } -Constant *GCOVProfiler::getIncrementIndirectCounterFunc() { - Type *Int32Ty = Type::getInt32Ty(*Ctx); - Type *Int64Ty = Type::getInt64Ty(*Ctx); - Type *Args[] = { - Int32Ty->getPointerTo(), // uint32_t *predecessor - Int64Ty->getPointerTo()->getPointerTo() // uint64_t **counters - }; - FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), Args, false); - return M->getOrInsertFunction("__llvm_gcov_indirect_counter_increment", FTy); -} - Constant *GCOVProfiler::getEmitFunctionFunc() { Type *Args[] = { Type::getInt32Ty(*Ctx), // uint32_t ident @@ -886,19 +932,6 @@ Constant *GCOVProfiler::getEndFileFunc() { return M->getOrInsertFunction("llvm_gcda_end_file", FTy); } -GlobalVariable *GCOVProfiler::getEdgeStateValue() { - GlobalVariable *GV = M->getGlobalVariable("__llvm_gcov_global_state_pred"); - if (!GV) { - GV = new GlobalVariable(*M, Type::getInt32Ty(*Ctx), false, - GlobalValue::InternalLinkage, - ConstantInt::get(Type::getInt32Ty(*Ctx), - 0xffffffff), - "__llvm_gcov_global_state_pred"); - GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - } - return GV; -} - Function *GCOVProfiler::insertCounterWriteout( ArrayRef<std::pair<GlobalVariable *, MDNode *> > CountersBySP) { FunctionType *WriteoutFTy = FunctionType::get(Type::getVoidTy(*Ctx), false); @@ -1122,57 +1155,6 @@ Function *GCOVProfiler::insertCounterWriteout( return WriteoutF; } -void GCOVProfiler::insertIndirectCounterIncrement() { - Function *Fn = - cast<Function>(GCOVProfiler::getIncrementIndirectCounterFunc()); - Fn->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); - Fn->setLinkage(GlobalValue::InternalLinkage); - Fn->addFnAttr(Attribute::NoInline); - if (Options.NoRedZone) - Fn->addFnAttr(Attribute::NoRedZone); - - // Create basic blocks for function. 
- BasicBlock *BB = BasicBlock::Create(*Ctx, "entry", Fn); - IRBuilder<> Builder(BB); - - BasicBlock *PredNotNegOne = BasicBlock::Create(*Ctx, "", Fn); - BasicBlock *CounterEnd = BasicBlock::Create(*Ctx, "", Fn); - BasicBlock *Exit = BasicBlock::Create(*Ctx, "exit", Fn); - - // uint32_t pred = *predecessor; - // if (pred == 0xffffffff) return; - Argument *Arg = &*Fn->arg_begin(); - Arg->setName("predecessor"); - Value *Pred = Builder.CreateLoad(Arg, "pred"); - Value *Cond = Builder.CreateICmpEQ(Pred, Builder.getInt32(0xffffffff)); - BranchInst::Create(Exit, PredNotNegOne, Cond, BB); - - Builder.SetInsertPoint(PredNotNegOne); - - // uint64_t *counter = counters[pred]; - // if (!counter) return; - Value *ZExtPred = Builder.CreateZExt(Pred, Builder.getInt64Ty()); - Arg = &*std::next(Fn->arg_begin()); - Arg->setName("counters"); - Value *GEP = Builder.CreateGEP(Type::getInt64PtrTy(*Ctx), Arg, ZExtPred); - Value *Counter = Builder.CreateLoad(GEP, "counter"); - Cond = Builder.CreateICmpEQ(Counter, - Constant::getNullValue( - Builder.getInt64Ty()->getPointerTo())); - Builder.CreateCondBr(Cond, Exit, CounterEnd); - - // ++*counter; - Builder.SetInsertPoint(CounterEnd); - Value *Add = Builder.CreateAdd(Builder.CreateLoad(Counter), - Builder.getInt64(1)); - Builder.CreateStore(Add, Counter); - Builder.CreateBr(Exit); - - // Fill in the exit block. - Builder.SetInsertPoint(Exit); - Builder.CreateRetVoid(); -} - Function *GCOVProfiler:: insertFlush(ArrayRef<std::pair<GlobalVariable*, MDNode*> > CountersBySP) { FunctionType *FTy = FunctionType::get(Type::getVoidTy(*Ctx), false); diff --git a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp index d62598bb5d4f..d04c2b76288f 100644 --- a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -44,6 +44,7 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include <sstream> using namespace llvm; @@ -63,6 +64,8 @@ static const uint64_t kDynamicShadowSentinel = std::numeric_limits<uint64_t>::max(); static const unsigned kPointerTagShift = 56; +static const unsigned kShadowBaseAlignment = 32; + static cl::opt<std::string> ClMemoryAccessCallbackPrefix( "hwasan-memory-access-callback-prefix", cl::desc("Prefix for memory access callbacks"), cl::Hidden, @@ -127,6 +130,32 @@ static cl::opt<unsigned long long> ClMappingOffset( cl::desc("HWASan shadow mapping offset [EXPERIMENTAL]"), cl::Hidden, cl::init(0)); +static cl::opt<bool> + ClWithIfunc("hwasan-with-ifunc", + cl::desc("Access dynamic shadow through an ifunc global on " + "platforms that support this"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> ClWithTls( + "hwasan-with-tls", + cl::desc("Access dynamic shadow through an thread-local pointer on " + "platforms that support this"), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> + ClRecordStackHistory("hwasan-record-stack-history", + cl::desc("Record stack frames with tagged allocations " + "in a thread-local ring buffer"), + cl::Hidden, cl::init(true)); +static cl::opt<bool> + ClCreateFrameDescriptions("hwasan-create-frame-descriptions", + cl::desc("create static frame descriptions"), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> + ClInstrumentMemIntrinsics("hwasan-instrument-mem-intrinsics", + cl::desc("instrument memory intrinsics"), + cl::Hidden, cl::init(true)); namespace { /// An instrumentation pass 
implementing detection of addressability bugs @@ -150,13 +179,14 @@ public: void initializeCallbacks(Module &M); - void maybeInsertDynamicShadowAtFunctionEntry(Function &F); + Value *getDynamicShadowNonTls(IRBuilder<> &IRB); void untagPointerOperand(Instruction *I, Value *Addr); Value *memToShadow(Value *Shadow, Type *Ty, IRBuilder<> &IRB); void instrumentMemAccessInline(Value *PtrLong, bool IsWrite, unsigned AccessSizeIndex, Instruction *InsertBefore); + void instrumentMemIntrinsic(MemIntrinsic *MI); bool instrumentMemAccess(Instruction *I); Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite, uint64_t *TypeSize, unsigned *Alignment, @@ -167,26 +197,53 @@ public: Value *tagPointer(IRBuilder<> &IRB, Type *Ty, Value *PtrLong, Value *Tag); Value *untagPointer(IRBuilder<> &IRB, Value *PtrLong); bool instrumentStack(SmallVectorImpl<AllocaInst *> &Allocas, - SmallVectorImpl<Instruction *> &RetVec); + SmallVectorImpl<Instruction *> &RetVec, Value *StackTag); Value *getNextTagWithCall(IRBuilder<> &IRB); Value *getStackBaseTag(IRBuilder<> &IRB); Value *getAllocaTag(IRBuilder<> &IRB, Value *StackTag, AllocaInst *AI, unsigned AllocaNo); Value *getUARTag(IRBuilder<> &IRB, Value *StackTag); + Value *getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty); + Value *emitPrologue(IRBuilder<> &IRB, bool WithFrameRecord); + private: LLVMContext *C; + std::string CurModuleUniqueId; Triple TargetTriple; + Function *HWAsanMemmove, *HWAsanMemcpy, *HWAsanMemset; + + // Frame description is a way to pass names/sizes of local variables + // to the run-time w/o adding extra executable code in every function. + // We do this by creating a separate section with {PC,Descr} pairs and passing + // the section beg/end to __hwasan_init_frames() at module init time. + std::string createFrameString(ArrayRef<AllocaInst*> Allocas); + void createFrameGlobal(Function &F, const std::string &FrameString); + // Get the section name for frame descriptions. Currently ELF-only. + const char *getFrameSection() { return "__hwasan_frames"; } + const char *getFrameSectionBeg() { return "__start___hwasan_frames"; } + const char *getFrameSectionEnd() { return "__stop___hwasan_frames"; } + GlobalVariable *createFrameSectionBound(Module &M, Type *Ty, + const char *Name) { + auto GV = new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, + nullptr, Name); + GV->setVisibility(GlobalValue::HiddenVisibility); + return GV; + } /// This struct defines the shadow mapping using the rule: /// shadow = (mem >> Scale) + Offset. 
/// If InGlobal is true, then /// extern char __hwasan_shadow[]; /// shadow = (mem >> Scale) + &__hwasan_shadow + /// If InTls is true, then + /// extern char *__hwasan_tls; + /// shadow = (mem>>Scale) + align_up(__hwasan_shadow, kShadowBaseAlignment) struct ShadowMapping { int Scale; uint64_t Offset; bool InGlobal; + bool InTls; void init(Triple &TargetTriple); unsigned getAllocaAlignment() const { return 1U << Scale; } @@ -194,6 +251,7 @@ private: ShadowMapping Mapping; Type *IntptrTy; + Type *Int8PtrTy; Type *Int8Ty; bool CompileKernel; @@ -206,10 +264,12 @@ private: Function *HwasanTagMemoryFunc; Function *HwasanGenerateTagFunc; + Function *HwasanThreadEnterFunc; Constant *ShadowGlobal; Value *LocalDynamicShadow = nullptr; + GlobalValue *ThreadPtrGlobal = nullptr; }; } // end anonymous namespace @@ -243,8 +303,10 @@ bool HWAddressSanitizer::doInitialization(Module &M) { Mapping.init(TargetTriple); C = &(M.getContext()); + CurModuleUniqueId = getUniqueModuleId(&M); IRBuilder<> IRB(*C); IntptrTy = IRB.getIntPtrTy(DL); + Int8PtrTy = IRB.getInt8PtrTy(); Int8Ty = IRB.getInt8Ty(); HwasanCtorFunction = nullptr; @@ -254,8 +316,38 @@ bool HWAddressSanitizer::doInitialization(Module &M) { kHwasanInitName, /*InitArgTypes=*/{}, /*InitArgs=*/{}); - appendToGlobalCtors(M, HwasanCtorFunction, 0); + Comdat *CtorComdat = M.getOrInsertComdat(kHwasanModuleCtorName); + HwasanCtorFunction->setComdat(CtorComdat); + appendToGlobalCtors(M, HwasanCtorFunction, 0, HwasanCtorFunction); + + // Create a zero-length global in __hwasan_frame so that the linker will + // always create start and stop symbols. + // + // N.B. If we ever start creating associated metadata in this pass this + // global will need to be associated with the ctor. + Type *Int8Arr0Ty = ArrayType::get(Int8Ty, 0); + auto GV = + new GlobalVariable(M, Int8Arr0Ty, /*isConstantGlobal*/ true, + GlobalVariable::PrivateLinkage, + Constant::getNullValue(Int8Arr0Ty), "__hwasan"); + GV->setSection(getFrameSection()); + GV->setComdat(CtorComdat); + appendToCompilerUsed(M, GV); + + IRBuilder<> IRBCtor(HwasanCtorFunction->getEntryBlock().getTerminator()); + IRBCtor.CreateCall( + declareSanitizerInitFunction(M, "__hwasan_init_frames", + {Int8PtrTy, Int8PtrTy}), + {createFrameSectionBound(M, Int8Ty, getFrameSectionBeg()), + createFrameSectionBound(M, Int8Ty, getFrameSectionEnd())}); } + + if (!TargetTriple.isAndroid()) + appendToCompilerUsed( + M, ThreadPtrGlobal = new GlobalVariable( + M, IntptrTy, false, GlobalVariable::ExternalLinkage, nullptr, + "__hwasan_tls", nullptr, GlobalVariable::InitialExecTLSModel)); + return true; } @@ -281,21 +373,35 @@ void HWAddressSanitizer::initializeCallbacks(Module &M) { } HwasanTagMemoryFunc = checkSanitizerInterfaceFunction(M.getOrInsertFunction( - "__hwasan_tag_memory", IRB.getVoidTy(), IntptrTy, Int8Ty, IntptrTy)); + "__hwasan_tag_memory", IRB.getVoidTy(), Int8PtrTy, Int8Ty, IntptrTy)); HwasanGenerateTagFunc = checkSanitizerInterfaceFunction( M.getOrInsertFunction("__hwasan_generate_tag", Int8Ty)); if (Mapping.InGlobal) ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow", ArrayType::get(IRB.getInt8Ty(), 0)); + + const std::string MemIntrinCallbackPrefix = + CompileKernel ? 
std::string("") : ClMemoryAccessCallbackPrefix; + HWAsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); + HWAsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy)); + HWAsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(), + IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy)); + + HwasanThreadEnterFunc = checkSanitizerInterfaceFunction( + M.getOrInsertFunction("__hwasan_thread_enter", IRB.getVoidTy())); } -void HWAddressSanitizer::maybeInsertDynamicShadowAtFunctionEntry(Function &F) { +Value *HWAddressSanitizer::getDynamicShadowNonTls(IRBuilder<> &IRB) { // Generate code only when dynamic addressing is needed. if (Mapping.Offset != kDynamicShadowSentinel) - return; + return nullptr; - IRBuilder<> IRB(&F.front().front()); if (Mapping.InGlobal) { // An empty inline asm with input reg == output reg. // An opaque pointer-to-int cast, basically. @@ -303,11 +409,12 @@ void HWAddressSanitizer::maybeInsertDynamicShadowAtFunctionEntry(Function &F) { FunctionType::get(IntptrTy, {ShadowGlobal->getType()}, false), StringRef(""), StringRef("=r,0"), /*hasSideEffects=*/false); - LocalDynamicShadow = IRB.CreateCall(Asm, {ShadowGlobal}, ".hwasan.shadow"); + return IRB.CreateCall(Asm, {ShadowGlobal}, ".hwasan.shadow"); } else { - Value *GlobalDynamicAddress = F.getParent()->getOrInsertGlobal( - kHwasanShadowMemoryDynamicAddress, IntptrTy); - LocalDynamicShadow = IRB.CreateLoad(GlobalDynamicAddress); + Value *GlobalDynamicAddress = + IRB.GetInsertBlock()->getParent()->getParent()->getOrInsertGlobal( + kHwasanShadowMemoryDynamicAddress, IntptrTy); + return IRB.CreateLoad(GlobalDynamicAddress); } } @@ -421,8 +528,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, IRB.getInt8Ty()); Value *AddrLong = untagPointer(IRB, PtrLong); Value *ShadowLong = memToShadow(AddrLong, PtrLong->getType(), IRB); - Value *MemTag = - IRB.CreateLoad(IRB.CreateIntToPtr(ShadowLong, IRB.getInt8PtrTy())); + Value *MemTag = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowLong, Int8PtrTy)); Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag); int matchAllTag = ClMatchAllTag.getNumOccurrences() > 0 ? @@ -433,7 +539,7 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, TagMismatch = IRB.CreateAnd(TagMismatch, TagNotIgnored); } - TerminatorInst *CheckTerm = + Instruction *CheckTerm = SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, !Recover, MDBuilder(*C).createBranchWeights(1, 100000)); @@ -464,12 +570,36 @@ void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, IRB.CreateCall(Asm, PtrLong); } +void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) { + IRBuilder<> IRB(MI); + if (isa<MemTransferInst>(MI)) { + IRB.CreateCall( + isa<MemMoveInst>(MI) ? 
HWAsanMemmove : HWAsanMemcpy, + {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + } else if (isa<MemSetInst>(MI)) { + IRB.CreateCall( + HWAsanMemset, + {IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy()), + IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false), + IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false)}); + } + MI->eraseFromParent(); +} + bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { LLVM_DEBUG(dbgs() << "Instrumenting: " << *I << "\n"); bool IsWrite = false; unsigned Alignment = 0; uint64_t TypeSize = 0; Value *MaybeMask = nullptr; + + if (ClInstrumentMemIntrinsics && isa<MemIntrinsic>(I)) { + instrumentMemIntrinsic(cast<MemIntrinsic>(I)); + return true; + } + Value *Addr = isInterestingMemoryAccess(I, &IsWrite, &TypeSize, &Alignment, &MaybeMask); @@ -521,13 +651,13 @@ bool HWAddressSanitizer::tagAlloca(IRBuilder<> &IRB, AllocaInst *AI, Value *JustTag = IRB.CreateTrunc(Tag, IRB.getInt8Ty()); if (ClInstrumentWithCalls) { IRB.CreateCall(HwasanTagMemoryFunc, - {IRB.CreatePointerCast(AI, IntptrTy), JustTag, + {IRB.CreatePointerCast(AI, Int8PtrTy), JustTag, ConstantInt::get(IntptrTy, Size)}); } else { size_t ShadowSize = Size >> Mapping.Scale; Value *ShadowPtr = IRB.CreateIntToPtr( memToShadow(IRB.CreatePointerCast(AI, IntptrTy), AI->getType(), IRB), - IRB.getInt8PtrTy()); + Int8PtrTy); // If this memset is not inlined, it will be intercepted in the hwasan // runtime library. That's OK, because the interceptor skips the checks if // the address is in the shadow region. @@ -557,7 +687,7 @@ Value *HWAddressSanitizer::getNextTagWithCall(IRBuilder<> &IRB) { Value *HWAddressSanitizer::getStackBaseTag(IRBuilder<> &IRB) { if (ClGenerateTagsWithCalls) - return nullptr; + return getNextTagWithCall(IRB); // FIXME: use addressofreturnaddress (but implement it in aarch64 backend // first). Module *M = IRB.GetInsertBlock()->getParent()->getParent(); @@ -625,15 +755,141 @@ Value *HWAddressSanitizer::untagPointer(IRBuilder<> &IRB, Value *PtrLong) { return UntaggedPtrLong; } -bool HWAddressSanitizer::instrumentStack( - SmallVectorImpl<AllocaInst *> &Allocas, - SmallVectorImpl<Instruction *> &RetVec) { - Function *F = Allocas[0]->getParent()->getParent(); - Instruction *InsertPt = &*F->getEntryBlock().begin(); - IRBuilder<> IRB(InsertPt); +Value *HWAddressSanitizer::getHwasanThreadSlotPtr(IRBuilder<> &IRB, Type *Ty) { + Module *M = IRB.GetInsertBlock()->getParent()->getParent(); + if (TargetTriple.isAArch64() && TargetTriple.isAndroid()) { + // Android provides a fixed TLS slot for sanitizers. See TLS_SLOT_SANITIZER + // in Bionic's libc/private/bionic_tls.h. + Function *ThreadPointerFunc = + Intrinsic::getDeclaration(M, Intrinsic::thread_pointer); + Value *SlotPtr = IRB.CreatePointerCast( + IRB.CreateConstGEP1_32(IRB.CreateCall(ThreadPointerFunc), 0x30), + Ty->getPointerTo(0)); + return SlotPtr; + } + if (ThreadPtrGlobal) + return ThreadPtrGlobal; + + + return nullptr; +} + +// Creates a string with a description of the stack frame (set of Allocas). +// The string is intended to be human readable. +// The current form is: Size1 Name1; Size2 Name2; ... 
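For example (made-up locals, assuming a 4-byte int): a frame holding `int buf[16]; char tag;` would be described by the string "64 buf; 1 tag; ", and createFrameGlobal below pairs that string with the function address as one {PC, Descr} entry in the __hwasan_frames section.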
+std::string +HWAddressSanitizer::createFrameString(ArrayRef<AllocaInst *> Allocas) { + std::ostringstream Descr; + for (auto AI : Allocas) + Descr << getAllocaSizeInBytes(*AI) << " " << AI->getName().str() << "; "; + return Descr.str(); +} - Value *StackTag = getStackBaseTag(IRB); +// Creates a global in the frame section which consists of two pointers: +// the function PC and the frame string constant. +void HWAddressSanitizer::createFrameGlobal(Function &F, + const std::string &FrameString) { + Module &M = *F.getParent(); + auto DescrGV = createPrivateGlobalForString(M, FrameString, true); + auto PtrPairTy = StructType::get(F.getType(), DescrGV->getType()); + auto GV = new GlobalVariable( + M, PtrPairTy, /*isConstantGlobal*/ true, GlobalVariable::PrivateLinkage, + ConstantStruct::get(PtrPairTy, (Constant *)&F, (Constant *)DescrGV), + "__hwasan"); + GV->setSection(getFrameSection()); + appendToCompilerUsed(M, GV); + // Put GV into the F's Comadat so that if F is deleted GV can be deleted too. + if (auto Comdat = + GetOrCreateFunctionComdat(F, TargetTriple, CurModuleUniqueId)) + GV->setComdat(Comdat); +} + +Value *HWAddressSanitizer::emitPrologue(IRBuilder<> &IRB, + bool WithFrameRecord) { + if (!Mapping.InTls) + return getDynamicShadowNonTls(IRB); + + Value *SlotPtr = getHwasanThreadSlotPtr(IRB, IntptrTy); + assert(SlotPtr); + + Instruction *ThreadLong = IRB.CreateLoad(SlotPtr); + + Function *F = IRB.GetInsertBlock()->getParent(); + if (F->getFnAttribute("hwasan-abi").getValueAsString() == "interceptor") { + Value *ThreadLongEqZero = + IRB.CreateICmpEQ(ThreadLong, ConstantInt::get(IntptrTy, 0)); + auto *Br = cast<BranchInst>(SplitBlockAndInsertIfThen( + ThreadLongEqZero, cast<Instruction>(ThreadLongEqZero)->getNextNode(), + false, MDBuilder(*C).createBranchWeights(1, 100000))); + + IRB.SetInsertPoint(Br); + // FIXME: This should call a new runtime function with a custom calling + // convention to avoid needing to spill all arguments here. + IRB.CreateCall(HwasanThreadEnterFunc); + LoadInst *ReloadThreadLong = IRB.CreateLoad(SlotPtr); + + IRB.SetInsertPoint(&*Br->getSuccessor(0)->begin()); + PHINode *ThreadLongPhi = IRB.CreatePHI(IntptrTy, 2); + ThreadLongPhi->addIncoming(ThreadLong, ThreadLong->getParent()); + ThreadLongPhi->addIncoming(ReloadThreadLong, ReloadThreadLong->getParent()); + ThreadLong = ThreadLongPhi; + } + + // Extract the address field from ThreadLong. Unnecessary on AArch64 with TBI. + Value *ThreadLongMaybeUntagged = + TargetTriple.isAArch64() ? ThreadLong : untagPointer(IRB, ThreadLong); + + if (WithFrameRecord) { + // Prepare ring buffer data. + auto PC = IRB.CreatePtrToInt(F, IntptrTy); + auto GetStackPointerFn = + Intrinsic::getDeclaration(F->getParent(), Intrinsic::frameaddress); + Value *SP = IRB.CreatePtrToInt( + IRB.CreateCall(GetStackPointerFn, + {Constant::getNullValue(IRB.getInt32Ty())}), + IntptrTy); + // Mix SP and PC. TODO: also add the tag to the mix. + // Assumptions: + // PC is 0x0000PPPPPPPPPPPP (48 bits are meaningful, others are zero) + // SP is 0xsssssssssssSSSS0 (4 lower bits are zero) + // We only really need ~20 lower non-zero bits (SSSS), so we mix like this: + // 0xSSSSPPPPPPPPPPPP + SP = IRB.CreateShl(SP, 44); + + // Store data to ring buffer. + Value *RecordPtr = + IRB.CreateIntToPtr(ThreadLongMaybeUntagged, IntptrTy->getPointerTo(0)); + IRB.CreateStore(IRB.CreateOr(PC, SP), RecordPtr); + + // Update the ring buffer. 
Top byte of ThreadLong defines the size of the + // buffer in pages, it must be a power of two, and the start of the buffer + // must be aligned by twice that much. Therefore wrap around of the ring + // buffer is simply Addr &= ~((ThreadLong >> 56) << 12). + // The use of AShr instead of LShr is due to + // https://bugs.llvm.org/show_bug.cgi?id=39030 + // Runtime library makes sure not to use the highest bit. + Value *WrapMask = IRB.CreateXor( + IRB.CreateShl(IRB.CreateAShr(ThreadLong, 56), 12, "", true, true), + ConstantInt::get(IntptrTy, (uint64_t)-1)); + Value *ThreadLongNew = IRB.CreateAnd( + IRB.CreateAdd(ThreadLong, ConstantInt::get(IntptrTy, 8)), WrapMask); + IRB.CreateStore(ThreadLongNew, SlotPtr); + } + // Get shadow base address by aligning RecordPtr up. + // Note: this is not correct if the pointer is already aligned. + // Runtime library will make sure this never happens. + Value *ShadowBase = IRB.CreateAdd( + IRB.CreateOr( + ThreadLongMaybeUntagged, + ConstantInt::get(IntptrTy, (1ULL << kShadowBaseAlignment) - 1)), + ConstantInt::get(IntptrTy, 1), "hwasan.shadow"); + return ShadowBase; +} + +bool HWAddressSanitizer::instrumentStack( + SmallVectorImpl<AllocaInst *> &Allocas, + SmallVectorImpl<Instruction *> &RetVec, Value *StackTag) { // Ideally, we want to calculate tagged stack base pointer, and rewrite all // alloca addresses using that. Unfortunately, offsets are not known yet // (unless we use ASan-style mega-alloca). Instead we keep the base tag in a @@ -641,7 +897,7 @@ bool HWAddressSanitizer::instrumentStack( // This generates one extra instruction per alloca use. for (unsigned N = 0; N < Allocas.size(); ++N) { auto *AI = Allocas[N]; - IRB.SetInsertPoint(AI->getNextNode()); + IRBuilder<> IRB(AI->getNextNode()); // Replace uses of the alloca with tagged address. Value *Tag = getAllocaTag(IRB, StackTag, AI, N); @@ -696,12 +952,6 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { LLVM_DEBUG(dbgs() << "Function: " << F.getName() << "\n"); - initializeCallbacks(*F.getParent()); - - assert(!LocalDynamicShadow); - maybeInsertDynamicShadowAtFunctionEntry(F); - - bool Changed = false; SmallVector<Instruction*, 16> ToInstrument; SmallVector<AllocaInst*, 8> AllocasToInstrument; SmallVector<Instruction*, 8> RetVec; @@ -734,8 +984,28 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { } } - if (!AllocasToInstrument.empty()) - Changed |= instrumentStack(AllocasToInstrument, RetVec); + if (AllocasToInstrument.empty() && ToInstrument.empty()) + return false; + + if (ClCreateFrameDescriptions && !AllocasToInstrument.empty()) + createFrameGlobal(F, createFrameString(AllocasToInstrument)); + + initializeCallbacks(*F.getParent()); + + assert(!LocalDynamicShadow); + + Instruction *InsertPt = &*F.getEntryBlock().begin(); + IRBuilder<> EntryIRB(InsertPt); + LocalDynamicShadow = emitPrologue(EntryIRB, + /*WithFrameRecord*/ ClRecordStackHistory && + !AllocasToInstrument.empty()); + + bool Changed = false; + if (!AllocasToInstrument.empty()) { + Value *StackTag = + ClGenerateTagsWithCalls ? 
nullptr : getStackBaseTag(EntryIRB); + Changed |= instrumentStack(AllocasToInstrument, RetVec, StackTag); + } for (auto Inst : ToInstrument) Changed |= instrumentMemAccess(Inst); @@ -746,18 +1016,26 @@ bool HWAddressSanitizer::runOnFunction(Function &F) { } void HWAddressSanitizer::ShadowMapping::init(Triple &TargetTriple) { - const bool IsAndroid = TargetTriple.isAndroid(); - const bool IsAndroidWithIfuncSupport = - IsAndroid && !TargetTriple.isAndroidVersionLT(21); - Scale = kDefaultShadowScale; - - if (ClEnableKhwasan || ClInstrumentWithCalls || !IsAndroidWithIfuncSupport) + if (ClMappingOffset.getNumOccurrences() > 0) { + InGlobal = false; + InTls = false; + Offset = ClMappingOffset; + } else if (ClEnableKhwasan || ClInstrumentWithCalls) { + InGlobal = false; + InTls = false; Offset = 0; - else + } else if (ClWithIfunc) { + InGlobal = true; + InTls = false; Offset = kDynamicShadowSentinel; - if (ClMappingOffset.getNumOccurrences() > 0) - Offset = ClMappingOffset; - - InGlobal = IsAndroidWithIfuncSupport; + } else if (ClWithTls) { + InGlobal = false; + InTls = true; + Offset = kDynamicShadowSentinel; + } else { + InGlobal = false; + InTls = false; + Offset = kDynamicShadowSentinel; + } } diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 27fb0e4393af..58436c8560ad 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -19,7 +19,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" -#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/IR/Attributes.h" @@ -41,8 +41,8 @@ #include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Error.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" @@ -269,7 +269,8 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( LLVM_DEBUG(dbgs() << " Not promote: Cannot find the target\n"); ORE.emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "UnableToFindTarget", Inst) - << "Cannot promote indirect call: target not found"; + << "Cannot promote indirect call: target with md5sum " + << ore::NV("target md5sum", Target) << " not found"; }); break; } @@ -351,7 +352,7 @@ uint32_t ICallPromotionFunc::tryToPromote( bool ICallPromotionFunc::processFunction(ProfileSummaryInfo *PSI) { bool Changed = false; ICallPromotionAnalysis ICallAnalysis; - for (auto &I : findIndirectCallSites(F)) { + for (auto &I : findIndirectCalls(F)) { uint32_t NumVals, NumCandidates; uint64_t TotalCount; auto ICallProfDataRef = ICallAnalysis.getPromotionCandidatesForInstruction( @@ -426,7 +427,7 @@ static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) { ProfileSummaryInfo *PSI = - getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); // Command-line option has the priority for InLTO. 
return promoteIndirectCalls(M, PSI, InLTO | ICPLTOMode, diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index 4d5dfb0aa66b..15b94388cbe5 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -96,6 +96,11 @@ cl::opt<double> NumCountersPerValueSite( // is usually smaller than 2. cl::init(1.0)); +cl::opt<bool> AtomicCounterUpdateAll( + "instrprof-atomic-counter-update-all", cl::ZeroOrMore, + cl::desc("Make all profile counter updates atomic (for testing only)"), + cl::init(false)); + cl::opt<bool> AtomicCounterUpdatePromoted( "atomic-counter-update-promoted", cl::ZeroOrMore, cl::desc("Do counter update using atomic fetch add " @@ -597,12 +602,17 @@ void InstrProfiling::lowerIncrement(InstrProfIncrementInst *Inc) { IRBuilder<> Builder(Inc); uint64_t Index = Inc->getIndex()->getZExtValue(); Value *Addr = Builder.CreateConstInBoundsGEP2_64(Counters, 0, Index); - Value *Load = Builder.CreateLoad(Addr, "pgocount"); - auto *Count = Builder.CreateAdd(Load, Inc->getStep()); - auto *Store = Builder.CreateStore(Count, Addr); - Inc->replaceAllUsesWith(Store); - if (isCounterPromotionEnabled()) - PromotionCandidates.emplace_back(cast<Instruction>(Load), Store); + + if (Options.Atomic || AtomicCounterUpdateAll) { + Builder.CreateAtomicRMW(AtomicRMWInst::Add, Addr, Inc->getStep(), + AtomicOrdering::Monotonic); + } else { + Value *Load = Builder.CreateLoad(Addr, "pgocount"); + auto *Count = Builder.CreateAdd(Load, Inc->getStep()); + auto *Store = Builder.CreateStore(Count, Addr); + if (isCounterPromotionEnabled()) + PromotionCandidates.emplace_back(cast<Instruction>(Load), Store); + } Inc->eraseFromParent(); } @@ -691,6 +701,7 @@ static bool needsRuntimeRegistrationOfSectionRange(const Module &M) { // Use linker script magic to get data/cnts/name start/end. if (Triple(M.getTargetTriple()).isOSLinux() || Triple(M.getTargetTriple()).isOSFreeBSD() || + Triple(M.getTargetTriple()).isOSNetBSD() || Triple(M.getTargetTriple()).isOSFuchsia() || Triple(M.getTargetTriple()).isPS4CPU()) return false; diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index 8e9eea96ced7..c3e323613c70 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -14,7 +14,9 @@ #include "llvm/Transforms/Instrumentation.h" #include "llvm-c/Initialization.h" +#include "llvm/ADT/Triple.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" #include "llvm/InitializePasses.h" #include "llvm/PassRegistry.h" @@ -53,21 +55,65 @@ BasicBlock::iterator llvm::PrepareToSplitEntryBlock(BasicBlock &BB, return IP; } +// Create a constant for Str so that we can pass it to the run-time lib. +GlobalVariable *llvm::createPrivateGlobalForString(Module &M, StringRef Str, + bool AllowMerging, + const char *NamePrefix) { + Constant *StrConst = ConstantDataArray::getString(M.getContext(), Str); + // We use private linkage for module-local strings. If they can be merged + // with another one, we set the unnamed_addr attribute. + GlobalVariable *GV = + new GlobalVariable(M, StrConst->getType(), true, + GlobalValue::PrivateLinkage, StrConst, NamePrefix); + if (AllowMerging) + GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + GV->setAlignment(1); // Strings may not be merged w/o setting align 1. 
+ return GV; +} + +Comdat *llvm::GetOrCreateFunctionComdat(Function &F, Triple &T, + const std::string &ModuleId) { + if (auto Comdat = F.getComdat()) return Comdat; + assert(F.hasName()); + Module *M = F.getParent(); + std::string Name = F.getName(); + + // Make a unique comdat name for internal linkage things on ELF. On COFF, the + // name of the comdat group identifies the leader symbol of the comdat group. + // The linkage of the leader symbol is considered during comdat resolution, + // and internal symbols with the same name from different objects will not be + // merged. + if (T.isOSBinFormatELF() && F.hasLocalLinkage()) { + if (ModuleId.empty()) + return nullptr; + Name += ModuleId; + } + + // Make a new comdat for the function. Use the "no duplicates" selection kind + // for non-weak symbols if the object file format supports it. + Comdat *C = M->getOrInsertComdat(Name); + if (T.isOSBinFormatCOFF() && !F.isWeakForLinker()) + C->setSelectionKind(Comdat::NoDuplicates); + F.setComdat(C); + return C; +} + /// initializeInstrumentation - Initialize all passes in the TransformUtils /// library. void llvm::initializeInstrumentation(PassRegistry &Registry) { initializeAddressSanitizerPass(Registry); initializeAddressSanitizerModulePass(Registry); initializeBoundsCheckingLegacyPassPass(Registry); + initializeControlHeightReductionLegacyPassPass(Registry); initializeGCOVProfilerLegacyPassPass(Registry); initializePGOInstrumentationGenLegacyPassPass(Registry); initializePGOInstrumentationUseLegacyPassPass(Registry); initializePGOIndirectCallPromotionLegacyPassPass(Registry); initializePGOMemOPSizeOptLegacyPassPass(Registry); initializeInstrProfilingLegacyPassPass(Registry); - initializeMemorySanitizerPass(Registry); + initializeMemorySanitizerLegacyPassPass(Registry); initializeHWAddressSanitizerPass(Registry); - initializeThreadSanitizerPass(Registry); + initializeThreadSanitizerLegacyPassPass(Registry); initializeSanitizerCoverageModulePass(Registry); initializeDataFlowSanitizerPass(Registry); initializeEfficiencySanitizerPass(Registry); diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 4bcef6972786..e6573af2077d 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -89,9 +89,58 @@ /// implementation ignores the load aspect of CAS/RMW, always returning a clean /// value. It implements the store part as a simple atomic store by storing a /// clean shadow. -// +/// +/// Instrumenting inline assembly. +/// +/// For inline assembly code LLVM has little idea about which memory locations +/// become initialized depending on the arguments. It can be possible to figure +/// out which arguments are meant to point to inputs and outputs, but the +/// actual semantics can be only visible at runtime. In the Linux kernel it's +/// also possible that the arguments only indicate the offset for a base taken +/// from a segment register, so it's dangerous to treat any asm() arguments as +/// pointers. We take a conservative approach generating calls to +/// __msan_instrument_asm_store(ptr, size) +/// , which defer the memory unpoisoning to the runtime library. +/// The latter can perform more complex address checks to figure out whether +/// it's safe to touch the shadow memory. 
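For illustration, a hedged sketch (not code emitted verbatim by the pass) of how an asm() statement with a pointer output is expected to be bracketed; the variable and the asm body are invented:

    // Original:
    //   int x;
    //   asm volatile("movl $42, %0" : "=m"(x));
    //
    // After instrumentation, the pointed-to memory is handed to the runtime
    // before the asm runs, so the runtime can decide whether it is safe to
    // unpoison it:
    //   __msan_instrument_asm_store(&x, sizeof(x));
    //   asm volatile("movl $42, %0" : "=m"(x));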
+/// Like with atomic operations, we call __msan_instrument_asm_store() before +/// the assembly call, so that changes to the shadow memory will be seen by +/// other threads together with main memory initialization. +/// +/// KernelMemorySanitizer (KMSAN) implementation. +/// +/// The major differences between KMSAN and MSan instrumentation are: +/// - KMSAN always tracks the origins and implies msan-keep-going=true; +/// - KMSAN allocates shadow and origin memory for each page separately, so +/// there are no explicit accesses to shadow and origin in the +/// instrumentation. +/// Shadow and origin values for a particular X-byte memory location +/// (X=1,2,4,8) are accessed through pointers obtained via the +/// __msan_metadata_ptr_for_load_X(ptr) +/// __msan_metadata_ptr_for_store_X(ptr) +/// functions. The corresponding functions check that the X-byte accesses +/// are possible and returns the pointers to shadow and origin memory. +/// Arbitrary sized accesses are handled with: +/// __msan_metadata_ptr_for_load_n(ptr, size) +/// __msan_metadata_ptr_for_store_n(ptr, size); +/// - TLS variables are stored in a single per-task struct. A call to a +/// function __msan_get_context_state() returning a pointer to that struct +/// is inserted into every instrumented function before the entry block; +/// - __msan_warning() takes a 32-bit origin parameter; +/// - local variables are poisoned with __msan_poison_alloca() upon function +/// entry and unpoisoned with __msan_unpoison_alloca() before leaving the +/// function; +/// - the pass doesn't declare any global variables or add global constructors +/// to the translation unit. +/// +/// Also, KMSAN currently ignores uninitialized memory passed into inline asm +/// calls, making sure we're on the safe side wrt. possible false positives. +/// +/// KernelMemorySanitizer only supports X86_64 at the moment. +/// //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Instrumentation/MemorySanitizer.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DepthFirstIterator.h" @@ -101,7 +150,6 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -139,6 +187,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include <algorithm> #include <cassert> @@ -206,10 +255,13 @@ static cl::opt<bool> ClHandleICmpExact("msan-handle-icmp-exact", // passed into an assembly call. Note that this may cause false positives. // Because it's impossible to figure out the array sizes, we can only unpoison // the first sizeof(type) bytes for each type* pointer. +// The instrumentation is only enabled in KMSAN builds, and only if +// -msan-handle-asm-conservative is on. This is done because we may want to +// quickly disable assembly instrumentation when it breaks. static cl::opt<bool> ClHandleAsmConservative( "msan-handle-asm-conservative", cl::desc("conservative handling of inline assembly"), cl::Hidden, - cl::init(false)); + cl::init(true)); // This flag controls whether we check the shadow of the address // operand of load or store. 
Such bugs are very rare, since load from @@ -233,6 +285,11 @@ static cl::opt<int> ClInstrumentationWithCallThreshold( "inline checks (-1 means never use callbacks)."), cl::Hidden, cl::init(3500)); +static cl::opt<bool> + ClEnableKmsan("msan-kernel", + cl::desc("Enable KernelMemorySanitizer instrumentation"), + cl::Hidden, cl::init(false)); + // This is an experiment to enable handling of cases where shadow is a non-zero // compile-time constant. For some unexplainable reason they were silently // ignored in the instrumentation. @@ -264,7 +321,6 @@ static cl::opt<unsigned long long> ClOriginBase("msan-origin-base", cl::desc("Define custom MSan OriginBase"), cl::Hidden, cl::init(0)); -static const char *const kMsanModuleCtorName = "msan.module_ctor"; static const char *const kMsanInitName = "__msan_init"; namespace { @@ -390,29 +446,35 @@ static const PlatformMemoryMapParams NetBSD_X86_MemoryMapParams = { namespace { -/// An instrumentation pass implementing detection of uninitialized -/// reads. +/// Instrument functions of a module to detect uninitialized reads. /// -/// MemorySanitizer: instrument the code in module to find -/// uninitialized reads. -class MemorySanitizer : public FunctionPass { +/// Instantiating MemorySanitizer inserts the msan runtime library API function +/// declarations into the module if they don't exist already. Instantiating +/// ensures the __msan_init function is in the list of global constructors for +/// the module. +class MemorySanitizer { public: - // Pass identification, replacement for typeid. - static char ID; - - MemorySanitizer(int TrackOrigins = 0, bool Recover = false) - : FunctionPass(ID), - TrackOrigins(std::max(TrackOrigins, (int)ClTrackOrigins)), - Recover(Recover || ClKeepGoing) {} - - StringRef getPassName() const override { return "MemorySanitizer"; } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); + MemorySanitizer(Module &M, int TrackOrigins = 0, bool Recover = false, + bool EnableKmsan = false) { + this->CompileKernel = + ClEnableKmsan.getNumOccurrences() > 0 ? ClEnableKmsan : EnableKmsan; + if (ClTrackOrigins.getNumOccurrences() > 0) + this->TrackOrigins = ClTrackOrigins; + else + this->TrackOrigins = this->CompileKernel ? 2 : TrackOrigins; + this->Recover = ClKeepGoing.getNumOccurrences() > 0 + ? ClKeepGoing + : (this->CompileKernel | Recover); + initializeModule(M); } - bool runOnFunction(Function &F) override; - bool doInitialization(Module &M) override; + // MSan cannot be moved or copied because of MapParams. + MemorySanitizer(MemorySanitizer &&) = delete; + MemorySanitizer &operator=(MemorySanitizer &&) = delete; + MemorySanitizer(const MemorySanitizer &) = delete; + MemorySanitizer &operator=(const MemorySanitizer &) = delete; + + bool sanitizeFunction(Function &F, TargetLibraryInfo &TLI); private: friend struct MemorySanitizerVisitor; @@ -421,9 +483,13 @@ private: friend struct VarArgAArch64Helper; friend struct VarArgPowerPC64Helper; + void initializeModule(Module &M); void initializeCallbacks(Module &M); + void createKernelApi(Module &M); void createUserspaceApi(Module &M); + /// True if we're compiling the Linux kernel. + bool CompileKernel; /// Track origins (allocation points) of uninitialized values. int TrackOrigins; bool Recover; @@ -432,29 +498,39 @@ private: Type *IntptrTy; Type *OriginTy; + // XxxTLS variables represent the per-thread state in MSan and per-task state + // in KMSAN. + // For the userspace these point to thread-local globals. 
In the kernel land + // they point to the members of a per-task struct obtained via a call to + // __msan_get_context_state(). + /// Thread-local shadow storage for function parameters. - GlobalVariable *ParamTLS; + Value *ParamTLS; /// Thread-local origin storage for function parameters. - GlobalVariable *ParamOriginTLS; + Value *ParamOriginTLS; /// Thread-local shadow storage for function return value. - GlobalVariable *RetvalTLS; + Value *RetvalTLS; /// Thread-local origin storage for function return value. - GlobalVariable *RetvalOriginTLS; + Value *RetvalOriginTLS; /// Thread-local shadow storage for in-register va_arg function /// parameters (x86_64-specific). - GlobalVariable *VAArgTLS; + Value *VAArgTLS; + + /// Thread-local shadow storage for in-register va_arg function + /// parameters (x86_64-specific). + Value *VAArgOriginTLS; /// Thread-local shadow storage for va_arg overflow area /// (x86_64-specific). - GlobalVariable *VAArgOverflowSizeTLS; + Value *VAArgOverflowSizeTLS; /// Thread-local space used to pass origin value to the UMR reporting /// function. - GlobalVariable *OriginTLS; + Value *OriginTLS; /// Are the instrumentation callbacks set up? bool CallbacksInitialized = false; @@ -480,6 +556,22 @@ private: /// MSan runtime replacements for memmove, memcpy and memset. Value *MemmoveFn, *MemcpyFn, *MemsetFn; + /// KMSAN callback for task-local function argument shadow. + Value *MsanGetContextStateFn; + + /// Functions for poisoning/unpoisoning local variables + Value *MsanPoisonAllocaFn, *MsanUnpoisonAllocaFn; + + /// Each of the MsanMetadataPtrXxx functions returns a pair of shadow/origin + /// pointers. + Value *MsanMetadataPtrForLoadN, *MsanMetadataPtrForStoreN; + Value *MsanMetadataPtrForLoad_1_8[4]; + Value *MsanMetadataPtrForStore_1_8[4]; + Value *MsanInstrumentAsmStoreFn; + + /// Helper to choose between different MsanMetadataPtrXxx(). + Value *getKmsanShadowOriginAccessFn(bool isStore, int size); + /// Memory map parameters used in application-to-shadow calculation. const MemoryMapParams *MapParams; @@ -494,24 +586,61 @@ private: /// An empty volatile inline asm that prevents callback merge. InlineAsm *EmptyAsm; +}; + +/// A legacy function pass for msan instrumentation. +/// +/// Instruments functions to detect unitialized reads. +struct MemorySanitizerLegacyPass : public FunctionPass { + // Pass identification, replacement for typeid. 
+ static char ID; - Function *MsanCtorFunction; + MemorySanitizerLegacyPass(int TrackOrigins = 0, bool Recover = false, + bool EnableKmsan = false) + : FunctionPass(ID), TrackOrigins(TrackOrigins), Recover(Recover), + EnableKmsan(EnableKmsan) {} + StringRef getPassName() const override { return "MemorySanitizerLegacyPass"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } + + bool runOnFunction(Function &F) override { + return MSan->sanitizeFunction( + F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()); + } + bool doInitialization(Module &M) override; + + Optional<MemorySanitizer> MSan; + int TrackOrigins; + bool Recover; + bool EnableKmsan; }; } // end anonymous namespace -char MemorySanitizer::ID = 0; +PreservedAnalyses MemorySanitizerPass::run(Function &F, + FunctionAnalysisManager &FAM) { + MemorySanitizer Msan(*F.getParent(), TrackOrigins, Recover, EnableKmsan); + if (Msan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F))) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} -INITIALIZE_PASS_BEGIN( - MemorySanitizer, "msan", - "MemorySanitizer: detects uninitialized reads.", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END( - MemorySanitizer, "msan", - "MemorySanitizer: detects uninitialized reads.", false, false) +char MemorySanitizerLegacyPass::ID = 0; -FunctionPass *llvm::createMemorySanitizerPass(int TrackOrigins, bool Recover) { - return new MemorySanitizer(TrackOrigins, Recover); +INITIALIZE_PASS_BEGIN(MemorySanitizerLegacyPass, "msan", + "MemorySanitizer: detects uninitialized reads.", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(MemorySanitizerLegacyPass, "msan", + "MemorySanitizer: detects uninitialized reads.", false, + false) + +FunctionPass *llvm::createMemorySanitizerLegacyPassPass(int TrackOrigins, + bool Recover, + bool CompileKernel) { + return new MemorySanitizerLegacyPass(TrackOrigins, Recover, CompileKernel); } /// Create a non-const global initialized with the given string. @@ -526,6 +655,76 @@ static GlobalVariable *createPrivateNonConstGlobalForString(Module &M, GlobalValue::PrivateLinkage, StrConst, ""); } +/// Create KMSAN API callbacks. +void MemorySanitizer::createKernelApi(Module &M) { + IRBuilder<> IRB(*C); + + // These will be initialized in insertKmsanPrologue(). + RetvalTLS = nullptr; + RetvalOriginTLS = nullptr; + ParamTLS = nullptr; + ParamOriginTLS = nullptr; + VAArgTLS = nullptr; + VAArgOriginTLS = nullptr; + VAArgOverflowSizeTLS = nullptr; + // OriginTLS is unused in the kernel. + OriginTLS = nullptr; + + // __msan_warning() in the kernel takes an origin. + WarningFn = M.getOrInsertFunction("__msan_warning", IRB.getVoidTy(), + IRB.getInt32Ty()); + // Requests the per-task context state (kmsan_context_state*) from the + // runtime library. 
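A hedged sketch of the per-task state this getter is assumed to return: field names are invented, the sizes are assumed to match the pass's kParamTLSSize/kRetvalTLSSize constants, and the field order mirrors the StructType built just below.

    #include <cstdint>

    // Sketch only; the authoritative layout lives in the KMSAN runtime.
    constexpr unsigned kParamTLSSizeSketch = 800;  // assumed value
    constexpr unsigned kRetvalTLSSizeSketch = 800; // assumed value
    struct kmsan_context_state_sketch {
      uint64_t param_shadow[kParamTLSSizeSketch / 8];
      uint64_t retval_shadow[kRetvalTLSSizeSketch / 8];
      uint64_t va_arg_shadow[kParamTLSSizeSketch / 8];
      uint64_t va_arg_origin[kParamTLSSizeSketch / 8];
      uint64_t va_arg_overflow_size;
      uint32_t param_origin[kParamTLSSizeSketch / 4];
      uint32_t retval_origin;
      uint32_t origin; // not referenced by the prologue code in this patch
    };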
+ MsanGetContextStateFn = M.getOrInsertFunction( + "__msan_get_context_state", + PointerType::get( + StructType::get(ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), + ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), + ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), + ArrayType::get(IRB.getInt64Ty(), + kParamTLSSize / 8), /* va_arg_origin */ + IRB.getInt64Ty(), + ArrayType::get(OriginTy, kParamTLSSize / 4), OriginTy, + OriginTy), + 0)); + + Type *RetTy = StructType::get(PointerType::get(IRB.getInt8Ty(), 0), + PointerType::get(IRB.getInt32Ty(), 0)); + + for (int ind = 0, size = 1; ind < 4; ind++, size <<= 1) { + std::string name_load = + "__msan_metadata_ptr_for_load_" + std::to_string(size); + std::string name_store = + "__msan_metadata_ptr_for_store_" + std::to_string(size); + MsanMetadataPtrForLoad_1_8[ind] = M.getOrInsertFunction( + name_load, RetTy, PointerType::get(IRB.getInt8Ty(), 0)); + MsanMetadataPtrForStore_1_8[ind] = M.getOrInsertFunction( + name_store, RetTy, PointerType::get(IRB.getInt8Ty(), 0)); + } + + MsanMetadataPtrForLoadN = M.getOrInsertFunction( + "__msan_metadata_ptr_for_load_n", RetTy, + PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty()); + MsanMetadataPtrForStoreN = M.getOrInsertFunction( + "__msan_metadata_ptr_for_store_n", RetTy, + PointerType::get(IRB.getInt8Ty(), 0), IRB.getInt64Ty()); + + // Functions for poisoning and unpoisoning memory. + MsanPoisonAllocaFn = + M.getOrInsertFunction("__msan_poison_alloca", IRB.getVoidTy(), + IRB.getInt8PtrTy(), IntptrTy, IRB.getInt8PtrTy()); + MsanUnpoisonAllocaFn = M.getOrInsertFunction( + "__msan_unpoison_alloca", IRB.getVoidTy(), IRB.getInt8PtrTy(), IntptrTy); +} + +static Constant *getOrInsertGlobal(Module &M, StringRef Name, Type *Ty) { + return M.getOrInsertGlobal(Name, Ty, [&] { + return new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, + nullptr, Name, nullptr, + GlobalVariable::InitialExecTLSModel); + }); +} + /// Insert declarations for userspace-specific functions and globals. void MemorySanitizer::createUserspaceApi(Module &M) { IRBuilder<> IRB(*C); @@ -537,36 +736,31 @@ void MemorySanitizer::createUserspaceApi(Module &M) { WarningFn = M.getOrInsertFunction(WarningFnName, IRB.getVoidTy()); // Create the global TLS variables. 
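Stepping back to the metadata-pointer callbacks declared above: the kernel runtime is assumed to return a shadow/origin pointer pair for each access, roughly with the following shape (a sketch; the real prototypes belong to the KMSAN runtime, not to this patch):

    #include <cstdint>

    // Assumed C-level contract for __msan_metadata_ptr_for_{load,store}_{1,2,4,8,n},
    // matching the {i8*, i32*} return type built above.
    struct ShadowOriginPtrSketch {
      uint8_t *Shadow;
      uint32_t *Origin;
    };
    extern "C" ShadowOriginPtrSketch __msan_metadata_ptr_for_load_4(void *Addr);
    extern "C" ShadowOriginPtrSketch
    __msan_metadata_ptr_for_store_n(void *Addr, uintptr_t Size);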
- RetvalTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_retval_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - - RetvalOriginTLS = new GlobalVariable( - M, OriginTy, false, GlobalVariable::ExternalLinkage, nullptr, - "__msan_retval_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); - - ParamTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_param_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - - ParamOriginTLS = new GlobalVariable( - M, ArrayType::get(OriginTy, kParamTLSSize / 4), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_param_origin_tls", - nullptr, GlobalVariable::InitialExecTLSModel); - - VAArgTLS = new GlobalVariable( - M, ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8), false, - GlobalVariable::ExternalLinkage, nullptr, "__msan_va_arg_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - VAArgOverflowSizeTLS = new GlobalVariable( - M, IRB.getInt64Ty(), false, GlobalVariable::ExternalLinkage, nullptr, - "__msan_va_arg_overflow_size_tls", nullptr, - GlobalVariable::InitialExecTLSModel); - OriginTLS = new GlobalVariable( - M, IRB.getInt32Ty(), false, GlobalVariable::ExternalLinkage, nullptr, - "__msan_origin_tls", nullptr, GlobalVariable::InitialExecTLSModel); + RetvalTLS = + getOrInsertGlobal(M, "__msan_retval_tls", + ArrayType::get(IRB.getInt64Ty(), kRetvalTLSSize / 8)); + + RetvalOriginTLS = getOrInsertGlobal(M, "__msan_retval_origin_tls", OriginTy); + + ParamTLS = + getOrInsertGlobal(M, "__msan_param_tls", + ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8)); + + ParamOriginTLS = + getOrInsertGlobal(M, "__msan_param_origin_tls", + ArrayType::get(OriginTy, kParamTLSSize / 4)); + + VAArgTLS = + getOrInsertGlobal(M, "__msan_va_arg_tls", + ArrayType::get(IRB.getInt64Ty(), kParamTLSSize / 8)); + + VAArgOriginTLS = + getOrInsertGlobal(M, "__msan_va_arg_origin_tls", + ArrayType::get(OriginTy, kParamTLSSize / 4)); + + VAArgOverflowSizeTLS = + getOrInsertGlobal(M, "__msan_va_arg_overflow_size_tls", IRB.getInt64Ty()); + OriginTLS = getOrInsertGlobal(M, "__msan_origin_tls", IRB.getInt32Ty()); for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; AccessSizeIndex++) { @@ -615,14 +809,37 @@ void MemorySanitizer::initializeCallbacks(Module &M) { StringRef(""), StringRef(""), /*hasSideEffects=*/true); - createUserspaceApi(M); + MsanInstrumentAsmStoreFn = + M.getOrInsertFunction("__msan_instrument_asm_store", IRB.getVoidTy(), + PointerType::get(IRB.getInt8Ty(), 0), IntptrTy); + + if (CompileKernel) { + createKernelApi(M); + } else { + createUserspaceApi(M); + } CallbacksInitialized = true; } +Value *MemorySanitizer::getKmsanShadowOriginAccessFn(bool isStore, int size) { + Value **Fns = + isStore ? MsanMetadataPtrForStore_1_8 : MsanMetadataPtrForLoad_1_8; + switch (size) { + case 1: + return Fns[0]; + case 2: + return Fns[1]; + case 4: + return Fns[2]; + case 8: + return Fns[3]; + default: + return nullptr; + } +} + /// Module-level initialization. -/// -/// inserts a call to __msan_init to the module's constructor list. 
-bool MemorySanitizer::doInitialization(Module &M) { +void MemorySanitizer::initializeModule(Module &M) { auto &DL = M.getDataLayout(); bool ShadowPassed = ClShadowBase.getNumOccurrences() > 0; @@ -695,27 +912,27 @@ bool MemorySanitizer::doInitialization(Module &M) { ColdCallWeights = MDBuilder(*C).createBranchWeights(1, 1000); OriginStoreWeights = MDBuilder(*C).createBranchWeights(1, 1000); - std::tie(MsanCtorFunction, std::ignore) = - createSanitizerCtorAndInitFunctions(M, kMsanModuleCtorName, kMsanInitName, - /*InitArgTypes=*/{}, - /*InitArgs=*/{}); - if (ClWithComdat) { - Comdat *MsanCtorComdat = M.getOrInsertComdat(kMsanModuleCtorName); - MsanCtorFunction->setComdat(MsanCtorComdat); - appendToGlobalCtors(M, MsanCtorFunction, 0, MsanCtorFunction); - } else { - appendToGlobalCtors(M, MsanCtorFunction, 0); - } - - - if (TrackOrigins) - new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, - IRB.getInt32(TrackOrigins), "__msan_track_origins"); - - if (Recover) - new GlobalVariable(M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, - IRB.getInt32(Recover), "__msan_keep_going"); + if (!CompileKernel) { + getOrCreateInitFunction(M, kMsanInitName); + + if (TrackOrigins) + M.getOrInsertGlobal("__msan_track_origins", IRB.getInt32Ty(), [&] { + return new GlobalVariable( + M, IRB.getInt32Ty(), true, GlobalValue::WeakODRLinkage, + IRB.getInt32(TrackOrigins), "__msan_track_origins"); + }); + + if (Recover) + M.getOrInsertGlobal("__msan_keep_going", IRB.getInt32Ty(), [&] { + return new GlobalVariable(M, IRB.getInt32Ty(), true, + GlobalValue::WeakODRLinkage, + IRB.getInt32(Recover), "__msan_keep_going"); + }); +} +} +bool MemorySanitizerLegacyPass::doInitialization(Module &M) { + MSan.emplace(M, TrackOrigins, Recover, EnableKmsan); return true; } @@ -796,8 +1013,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList; SmallVector<StoreInst *, 16> StoreList; - MemorySanitizerVisitor(Function &F, MemorySanitizer &MS) - : F(F), MS(MS), VAHelper(CreateVarArgHelper(F, MS, *this)) { + MemorySanitizerVisitor(Function &F, MemorySanitizer &MS, + const TargetLibraryInfo &TLI) + : F(F), MS(MS), VAHelper(CreateVarArgHelper(F, MS, *this)), TLI(&TLI) { bool SanitizeFunction = F.hasFnAttribute(Attribute::SanitizeMemory); InsertChecks = SanitizeFunction; PropagateShadow = SanitizeFunction; @@ -806,10 +1024,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // FIXME: Consider using SpecialCaseList to specify a list of functions that // must always return fully initialized values. For now, we hardcode "main". 
CheckReturnValue = SanitizeFunction && (F.getName() == "main"); - TLI = &MS.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); MS.initializeCallbacks(*F.getParent()); - ActualFnStart = &F.getEntryBlock(); + if (MS.CompileKernel) + ActualFnStart = insertKmsanPrologue(F); + else + ActualFnStart = &F.getEntryBlock(); LLVM_DEBUG(if (!InsertChecks) dbgs() << "MemorySanitizer is not inserting checks into '" @@ -883,7 +1103,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); - if (AsCall && SizeIndex < kNumberOfAccessSizes) { + if (AsCall && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { Value *Fn = MS.MaybeStoreOriginFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt( ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); @@ -932,10 +1152,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void insertWarningFn(IRBuilder<> &IRB, Value *Origin) { if (!Origin) Origin = (Value *)IRB.getInt32(0); - if (MS.TrackOrigins) { - IRB.CreateStore(Origin, MS.OriginTLS); + if (MS.CompileKernel) { + IRB.CreateCall(MS.WarningFn, Origin); + } else { + if (MS.TrackOrigins) { + IRB.CreateStore(Origin, MS.OriginTLS); + } + IRB.CreateCall(MS.WarningFn, {}); } - IRB.CreateCall(MS.WarningFn, {}); IRB.CreateCall(MS.EmptyAsm, {}); // FIXME: Insert UnreachableInst if !MS.Recover? // This may invalidate some of the following checks and needs to be done @@ -961,7 +1185,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { unsigned TypeSizeInBits = DL.getTypeSizeInBits(ConvertedShadow->getType()); unsigned SizeIndex = TypeSizeToSizeIndex(TypeSizeInBits); - if (AsCall && SizeIndex < kNumberOfAccessSizes) { + if (AsCall && SizeIndex < kNumberOfAccessSizes && !MS.CompileKernel) { Value *Fn = MS.MaybeWarningFn[SizeIndex]; Value *ConvertedShadow2 = IRB.CreateZExt(ConvertedShadow, IRB.getIntNTy(8 * (1 << SizeIndex))); @@ -991,6 +1215,29 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { LLVM_DEBUG(dbgs() << "DONE:\n" << F); } + BasicBlock *insertKmsanPrologue(Function &F) { + BasicBlock *ret = + SplitBlock(&F.getEntryBlock(), F.getEntryBlock().getFirstNonPHI()); + IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI()); + Value *ContextState = IRB.CreateCall(MS.MsanGetContextStateFn, {}); + Constant *Zero = IRB.getInt32(0); + MS.ParamTLS = + IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(0)}, "param_shadow"); + MS.RetvalTLS = + IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(1)}, "retval_shadow"); + MS.VAArgTLS = + IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(2)}, "va_arg_shadow"); + MS.VAArgOriginTLS = + IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(3)}, "va_arg_origin"); + MS.VAArgOverflowSizeTLS = IRB.CreateGEP( + ContextState, {Zero, IRB.getInt32(4)}, "va_arg_overflow_size"); + MS.ParamOriginTLS = + IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(5)}, "param_origin"); + MS.RetvalOriginTLS = + IRB.CreateGEP(ContextState, {Zero, IRB.getInt32(6)}, "retval_origin"); + return ret; + } + /// Add MemorySanitizer instrumentation to a function. 
bool runOnFunction() { // In the presence of unreachable blocks, we may see Phi nodes with @@ -1139,12 +1386,40 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return std::make_pair(ShadowPtr, OriginPtr); } + std::pair<Value *, Value *> + getShadowOriginPtrKernel(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, + unsigned Alignment, bool isStore) { + Value *ShadowOriginPtrs; + const DataLayout &DL = F.getParent()->getDataLayout(); + int Size = DL.getTypeStoreSize(ShadowTy); + + Value *Getter = MS.getKmsanShadowOriginAccessFn(isStore, Size); + Value *AddrCast = + IRB.CreatePointerCast(Addr, PointerType::get(IRB.getInt8Ty(), 0)); + if (Getter) { + ShadowOriginPtrs = IRB.CreateCall(Getter, AddrCast); + } else { + Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size); + ShadowOriginPtrs = IRB.CreateCall(isStore ? MS.MsanMetadataPtrForStoreN + : MS.MsanMetadataPtrForLoadN, + {AddrCast, SizeVal}); + } + Value *ShadowPtr = IRB.CreateExtractValue(ShadowOriginPtrs, 0); + ShadowPtr = IRB.CreatePointerCast(ShadowPtr, PointerType::get(ShadowTy, 0)); + Value *OriginPtr = IRB.CreateExtractValue(ShadowOriginPtrs, 1); + + return std::make_pair(ShadowPtr, OriginPtr); + } + std::pair<Value *, Value *> getShadowOriginPtr(Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, unsigned Alignment, bool isStore) { - std::pair<Value *, Value *> ret = - getShadowOriginPtrUserspace(Addr, IRB, ShadowTy, Alignment); + std::pair<Value *, Value *> ret; + if (MS.CompileKernel) + ret = getShadowOriginPtrKernel(Addr, IRB, ShadowTy, Alignment, isStore); + else + ret = getShadowOriginPtrUserspace(Addr, IRB, ShadowTy, Alignment); return ret; } @@ -1163,7 +1438,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Compute the origin address for a given function argument. Value *getOriginPtrForArgument(Value *A, IRBuilder<> &IRB, int ArgOffset) { - if (!MS.TrackOrigins) return nullptr; + if (!MS.TrackOrigins) + return nullptr; Value *Base = IRB.CreatePointerCast(MS.ParamOriginTLS, MS.IntptrTy); if (ArgOffset) Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); @@ -1303,6 +1579,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { getShadowOriginPtr(V, EntryIRB, EntryIRB.getInt8Ty(), ArgAlign, /*isStore*/ true) .first; + // TODO(glider): need to copy origins. if (Overflow) { // ParamTLS overflow. EntryIRB.CreateMemSet( @@ -2850,6 +3127,12 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { handleVectorComparePackedIntrinsic(I); break; + case Intrinsic::is_constant: + // The result of llvm.is.constant() is always defined. + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + break; + default: if (!handleUnknownIntrinsic(I)) visitInstruction(I); @@ -2868,7 +3151,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // outputs as clean. Note that any side effects of the inline asm that are // not immediately visible in its constraints are not handled. 
if (Call->isInlineAsm()) { - if (ClHandleAsmConservative) + if (ClHandleAsmConservative && MS.CompileKernel) visitAsmInstruction(I); else visitInstruction(I); @@ -2921,12 +3204,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ArgOffset + Size > kParamTLSSize) break; unsigned ParamAlignment = CS.getParamAlignment(i); unsigned Alignment = std::min(ParamAlignment, kShadowTLSAlignment); - Value *AShadowPtr = getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), - Alignment, /*isStore*/ false) - .first; + Value *AShadowPtr = + getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), Alignment, + /*isStore*/ false) + .first; Store = IRB.CreateMemCpy(ArgShadowBase, Alignment, AShadowPtr, Alignment, Size); + // TODO(glider): need to copy origins. } else { Size = DL.getTypeAllocSize(A->getType()); if (ArgOffset + Size > kParamTLSSize) break; @@ -2945,8 +3230,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } LLVM_DEBUG(dbgs() << " done with call args\n"); - FunctionType *FT = - cast<FunctionType>(CS.getCalledValue()->getType()->getContainedType(0)); + FunctionType *FT = CS.getFunctionType(); if (FT->isVarArg()) { VAHelper->visitCallSite(CS, IRB); } @@ -3033,40 +3317,34 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { "_msphi_o")); } - void visitAllocaInst(AllocaInst &I) { - setShadow(&I, getCleanShadow(&I)); - setOrigin(&I, getCleanOrigin()); - IRBuilder<> IRB(I.getNextNode()); - const DataLayout &DL = F.getParent()->getDataLayout(); - uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); - Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); - if (I.isArrayAllocation()) - Len = IRB.CreateMul(Len, I.getArraySize()); + Value *getLocalVarDescription(AllocaInst &I) { + SmallString<2048> StackDescriptionStorage; + raw_svector_ostream StackDescription(StackDescriptionStorage); + // We create a string with a description of the stack allocation and + // pass it into __msan_set_alloca_origin. + // It will be printed by the run-time if stack-originated UMR is found. + // The first 4 bytes of the string are set to '----' and will be replaced + // by __msan_va_arg_overflow_size_tls at the first call. + StackDescription << "----" << I.getName() << "@" << F.getName(); + return createPrivateNonConstGlobalForString(*F.getParent(), + StackDescription.str()); + } + + void instrumentAllocaUserspace(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { if (PoisonStack && ClPoisonStackWithCall) { IRB.CreateCall(MS.MsanPoisonStackFn, {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); } else { - Value *ShadowBase = getShadowOriginPtr(&I, IRB, IRB.getInt8Ty(), - I.getAlignment(), /*isStore*/ true) - .first; + Value *ShadowBase, *OriginBase; + std::tie(ShadowBase, OriginBase) = + getShadowOriginPtr(&I, IRB, IRB.getInt8Ty(), 1, /*isStore*/ true); Value *PoisonValue = IRB.getInt8(PoisonStack ? ClPoisonStackPattern : 0); IRB.CreateMemSet(ShadowBase, PoisonValue, Len, I.getAlignment()); } if (PoisonStack && MS.TrackOrigins) { - SmallString<2048> StackDescriptionStorage; - raw_svector_ostream StackDescription(StackDescriptionStorage); - // We create a string with a description of the stack allocation and - // pass it into __msan_set_alloca_origin. - // It will be printed by the run-time if stack-originated UMR is found. - // The first 4 bytes of the string are set to '----' and will be replaced - // by __msan_va_arg_overflow_size_tls at the first call. 
- StackDescription << "----" << I.getName() << "@" << F.getName(); - Value *Descr = - createPrivateNonConstGlobalForString(*F.getParent(), - StackDescription.str()); - + Value *Descr = getLocalVarDescription(I); IRB.CreateCall(MS.MsanSetAllocaOrigin4Fn, {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy()), @@ -3074,6 +3352,34 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } } + void instrumentAllocaKmsan(AllocaInst &I, IRBuilder<> &IRB, Value *Len) { + Value *Descr = getLocalVarDescription(I); + if (PoisonStack) { + IRB.CreateCall(MS.MsanPoisonAllocaFn, + {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len, + IRB.CreatePointerCast(Descr, IRB.getInt8PtrTy())}); + } else { + IRB.CreateCall(MS.MsanUnpoisonAllocaFn, + {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); + } + } + + void visitAllocaInst(AllocaInst &I) { + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); + IRBuilder<> IRB(I.getNextNode()); + const DataLayout &DL = F.getParent()->getDataLayout(); + uint64_t TypeSize = DL.getTypeAllocSize(I.getAllocatedType()); + Value *Len = ConstantInt::get(MS.IntptrTy, TypeSize); + if (I.isArrayAllocation()) + Len = IRB.CreateMul(Len, I.getArraySize()); + + if (MS.CompileKernel) + instrumentAllocaKmsan(I, IRB, Len); + else + instrumentAllocaUserspace(I, IRB, Len); + } + void visitSelectInst(SelectInst& I) { IRBuilder<> IRB(&I); // a = select b, c, d @@ -3196,37 +3502,95 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // Nothing to do here. } + void instrumentAsmArgument(Value *Operand, Instruction &I, IRBuilder<> &IRB, + const DataLayout &DL, bool isOutput) { + // For each assembly argument, we check its value for being initialized. + // If the argument is a pointer, we assume it points to a single element + // of the corresponding type (or to a 8-byte word, if the type is unsized). + // Each such pointer is instrumented with a call to the runtime library. + Type *OpType = Operand->getType(); + // Check the operand value itself. + insertShadowCheck(Operand, &I); + if (!OpType->isPointerTy() || !isOutput) { + assert(!isOutput); + return; + } + Type *ElType = OpType->getPointerElementType(); + if (!ElType->isSized()) + return; + int Size = DL.getTypeStoreSize(ElType); + Value *Ptr = IRB.CreatePointerCast(Operand, IRB.getInt8PtrTy()); + Value *SizeVal = ConstantInt::get(MS.IntptrTy, Size); + IRB.CreateCall(MS.MsanInstrumentAsmStoreFn, {Ptr, SizeVal}); + } + + /// Get the number of output arguments returned by pointers. + int getNumOutputArgs(InlineAsm *IA, CallInst *CI) { + int NumRetOutputs = 0; + int NumOutputs = 0; + Type *RetTy = dyn_cast<Value>(CI)->getType(); + if (!RetTy->isVoidTy()) { + // Register outputs are returned via the CallInst return value. + StructType *ST = dyn_cast_or_null<StructType>(RetTy); + if (ST) + NumRetOutputs = ST->getNumElements(); + else + NumRetOutputs = 1; + } + InlineAsm::ConstraintInfoVector Constraints = IA->ParseConstraints(); + for (size_t i = 0, n = Constraints.size(); i < n; i++) { + InlineAsm::ConstraintInfo Info = Constraints[i]; + switch (Info.Type) { + case InlineAsm::isOutput: + NumOutputs++; + break; + default: + break; + } + } + return NumOutputs - NumRetOutputs; + } + void visitAsmInstruction(Instruction &I) { // Conservative inline assembly handling: check for poisoned shadow of // asm() arguments, then unpoison the result and all the memory locations // pointed to by those arguments. 
+ // An inline asm() statement in C++ contains lists of input and output + // arguments used by the assembly code. These are mapped to operands of the + // CallInst as follows: + // - nR register outputs ("=r) are returned by value in a single structure + // (SSA value of the CallInst); + // - nO other outputs ("=m" and others) are returned by pointer as first + // nO operands of the CallInst; + // - nI inputs ("r", "m" and others) are passed to CallInst as the + // remaining nI operands. + // The total number of asm() arguments in the source is nR+nO+nI, and the + // corresponding CallInst has nO+nI+1 operands (the last operand is the + // function to be called). + const DataLayout &DL = F.getParent()->getDataLayout(); CallInst *CI = dyn_cast<CallInst>(&I); - - for (size_t i = 0, n = CI->getNumOperands(); i < n; i++) { + IRBuilder<> IRB(&I); + InlineAsm *IA = cast<InlineAsm>(CI->getCalledValue()); + int OutputArgs = getNumOutputArgs(IA, CI); + // The last operand of a CallInst is the function itself. + int NumOperands = CI->getNumOperands() - 1; + + // Check input arguments. Doing so before unpoisoning output arguments, so + // that we won't overwrite uninit values before checking them. + for (int i = OutputArgs; i < NumOperands; i++) { Value *Operand = CI->getOperand(i); - if (Operand->getType()->isSized()) - insertShadowCheck(Operand, &I); + instrumentAsmArgument(Operand, I, IRB, DL, /*isOutput*/ false); } - setShadow(&I, getCleanShadow(&I)); - setOrigin(&I, getCleanOrigin()); - IRBuilder<> IRB(&I); - IRB.SetInsertPoint(I.getNextNode()); - for (size_t i = 0, n = CI->getNumOperands(); i < n; i++) { + // Unpoison output arguments. This must happen before the actual InlineAsm + // call, so that the shadow for memory published in the asm() statement + // remains valid. + for (int i = 0; i < OutputArgs; i++) { Value *Operand = CI->getOperand(i); - Type *OpType = Operand->getType(); - if (!OpType->isPointerTy()) - continue; - Type *ElType = OpType->getPointerElementType(); - if (!ElType->isSized()) - continue; - Value *ShadowPtr, *OriginPtr; - std::tie(ShadowPtr, OriginPtr) = getShadowOriginPtr( - Operand, IRB, ElType, /*Alignment*/ 1, /*isStore*/ true); - Value *CShadow = getCleanShadow(ElType); - IRB.CreateStore( - CShadow, - IRB.CreatePointerCast(ShadowPtr, CShadow->getType()->getPointerTo())); + instrumentAsmArgument(Operand, I, IRB, DL, /*isOutput*/ true); } + + setShadow(&I, getCleanShadow(&I)); + setOrigin(&I, getCleanOrigin()); } void visitInstruction(Instruction &I) { @@ -3249,12 +3613,16 @@ struct VarArgAMD64Helper : public VarArgHelper { // An unfortunate workaround for asymmetric lowering of va_arg stuff. // See a comment in visitCallSite for more details. static const unsigned AMD64GpEndOffset = 48; // AMD64 ABI Draft 0.99.6 p3.5.7 - static const unsigned AMD64FpEndOffset = 176; + static const unsigned AMD64FpEndOffsetSSE = 176; + // If SSE is disabled, fp_offset in va_list is zero. 
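The offset constants above follow the System V AMD64 va_list register save area; a quick sketch of where the numbers come from (standard ABI arithmetic, not taken from this patch):

    // Six integer argument registers of 8 bytes each, then eight vector
    // registers of 16 bytes each in the register save area.
    constexpr unsigned GpEnd    = 6 * 8;          // 48, AMD64GpEndOffset
    constexpr unsigned FpEndSSE = GpEnd + 8 * 16; // 176, AMD64FpEndOffsetSSE
    // With "-sse" in the target-features attribute there is no vector save
    // area, so the floating-point end offset collapses onto GpEnd, which is
    // what AMD64FpEndOffsetNoSSE below expresses.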
+ static const unsigned AMD64FpEndOffsetNoSSE = AMD64GpEndOffset; + unsigned AMD64FpEndOffset; Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; Value *VAArgTLSCopy = nullptr; + Value *VAArgTLSOriginCopy = nullptr; Value *VAArgOverflowSize = nullptr; SmallVector<CallInst*, 16> VAStartInstrumentationList; @@ -3262,7 +3630,18 @@ struct VarArgAMD64Helper : public VarArgHelper { enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; VarArgAMD64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} + MemorySanitizerVisitor &MSV) + : F(F), MS(MS), MSV(MSV) { + AMD64FpEndOffset = AMD64FpEndOffsetSSE; + for (const auto &Attr : F.getAttributes().getFnAttributes()) { + if (Attr.isStringAttribute() && + (Attr.getKindAsString() == "target-features")) { + if (Attr.getValueAsString().contains("-sse")) + AMD64FpEndOffset = AMD64FpEndOffsetNoSSE; + break; + } + } + } ArgKind classifyArgument(Value* arg) { // A very rough approximation of X86_64 argument classification rules. @@ -3304,9 +3683,14 @@ struct VarArgAMD64Helper : public VarArgHelper { assert(A->getType()->isPointerTy()); Type *RealTy = A->getType()->getPointerElementType(); uint64_t ArgSize = DL.getTypeAllocSize(RealTy); - Value *ShadowBase = - getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); + Value *ShadowBase = getShadowPtrForVAArgument( + RealTy, IRB, OverflowOffset, alignTo(ArgSize, 8)); + Value *OriginBase = nullptr; + if (MS.TrackOrigins) + OriginBase = getOriginPtrForVAArgument(RealTy, IRB, OverflowOffset); OverflowOffset += alignTo(ArgSize, 8); + if (!ShadowBase) + continue; Value *ShadowPtr, *OriginPtr; std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, @@ -3314,20 +3698,31 @@ struct VarArgAMD64Helper : public VarArgHelper { IRB.CreateMemCpy(ShadowBase, kShadowTLSAlignment, ShadowPtr, kShadowTLSAlignment, ArgSize); + if (MS.TrackOrigins) + IRB.CreateMemCpy(OriginBase, kShadowTLSAlignment, OriginPtr, + kShadowTLSAlignment, ArgSize); } else { ArgKind AK = classifyArgument(A); if (AK == AK_GeneralPurpose && GpOffset >= AMD64GpEndOffset) AK = AK_Memory; if (AK == AK_FloatingPoint && FpOffset >= AMD64FpEndOffset) AK = AK_Memory; - Value *ShadowBase; + Value *ShadowBase, *OriginBase = nullptr; switch (AK) { case AK_GeneralPurpose: - ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset); + ShadowBase = + getShadowPtrForVAArgument(A->getType(), IRB, GpOffset, 8); + if (MS.TrackOrigins) + OriginBase = + getOriginPtrForVAArgument(A->getType(), IRB, GpOffset); GpOffset += 8; break; case AK_FloatingPoint: - ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset); + ShadowBase = + getShadowPtrForVAArgument(A->getType(), IRB, FpOffset, 16); + if (MS.TrackOrigins) + OriginBase = + getOriginPtrForVAArgument(A->getType(), IRB, FpOffset); FpOffset += 16; break; case AK_Memory: @@ -3335,15 +3730,27 @@ struct VarArgAMD64Helper : public VarArgHelper { continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); ShadowBase = - getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); + getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, 8); + if (MS.TrackOrigins) + OriginBase = + getOriginPtrForVAArgument(A->getType(), IRB, OverflowOffset); OverflowOffset += alignTo(ArgSize, 8); } // Take fixed arguments into account for GpOffset and FpOffset, // but don't actually store shadows for them. + // TODO(glider): don't call get*PtrForVAArgument() for them. 
if (IsFixed) continue; - IRB.CreateAlignedStore(MSV.getShadow(A), ShadowBase, - kShadowTLSAlignment); + if (!ShadowBase) + continue; + Value *Shadow = MSV.getShadow(A); + IRB.CreateAlignedStore(Shadow, ShadowBase, kShadowTLSAlignment); + if (MS.TrackOrigins) { + Value *Origin = MSV.getOrigin(A); + unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType()); + MSV.paintOrigin(IRB, Origin, OriginBase, StoreSize, + std::max(kShadowTLSAlignment, kMinOriginAlignment)); + } } } Constant *OverflowSize = @@ -3353,11 +3760,25 @@ struct VarArgAMD64Helper : public VarArgHelper { /// Compute the shadow address for a given va_arg. Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - int ArgOffset) { + unsigned ArgOffset, unsigned ArgSize) { + // Make sure we don't overflow __msan_va_arg_tls. + if (ArgOffset + ArgSize > kParamTLSSize) + return nullptr; Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), - "_msarg"); + "_msarg_va_s"); + } + + /// Compute the origin address for a given va_arg. + Value *getOriginPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, int ArgOffset) { + Value *Base = IRB.CreatePointerCast(MS.VAArgOriginTLS, MS.IntptrTy); + // getOriginPtrForVAArgument() is always called after + // getShadowPtrForVAArgument(), so __msan_va_arg_origin_tls can never + // overflow. + Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); + return IRB.CreateIntToPtr(Base, PointerType::get(MS.OriginTy, 0), + "_msarg_va_o"); } void unpoisonVAListTagForInst(IntrinsicInst &I) { @@ -3402,6 +3823,10 @@ struct VarArgAMD64Helper : public VarArgHelper { VAArgOverflowSize); VAArgTLSCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); IRB.CreateMemCpy(VAArgTLSCopy, 8, MS.VAArgTLS, 8, CopySize); + if (MS.TrackOrigins) { + VAArgTLSOriginCopy = IRB.CreateAlloca(Type::getInt8Ty(*MS.C), CopySize); + IRB.CreateMemCpy(VAArgTLSOriginCopy, 8, MS.VAArgOriginTLS, 8, CopySize); + } } // Instrument va_start. @@ -3423,6 +3848,9 @@ struct VarArgAMD64Helper : public VarArgHelper { Alignment, /*isStore*/ true); IRB.CreateMemCpy(RegSaveAreaShadowPtr, Alignment, VAArgTLSCopy, Alignment, AMD64FpEndOffset); + if (MS.TrackOrigins) + IRB.CreateMemCpy(RegSaveAreaOriginPtr, Alignment, VAArgTLSOriginCopy, + Alignment, AMD64FpEndOffset); Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 8)), @@ -3436,6 +3864,12 @@ struct VarArgAMD64Helper : public VarArgHelper { AMD64FpEndOffset); IRB.CreateMemCpy(OverflowArgAreaShadowPtr, Alignment, SrcPtr, Alignment, VAArgOverflowSize); + if (MS.TrackOrigins) { + SrcPtr = IRB.CreateConstGEP1_32(IRB.getInt8Ty(), VAArgTLSOriginCopy, + AMD64FpEndOffset); + IRB.CreateMemCpy(OverflowArgAreaOriginPtr, Alignment, SrcPtr, Alignment, + VAArgOverflowSize); + } } } }; @@ -3469,9 +3903,11 @@ struct VarArgMIPS64Helper : public VarArgHelper { if (ArgSize < 8) VAArgOffset += (8 - ArgSize); } - Base = getShadowPtrForVAArgument(A->getType(), IRB, VAArgOffset); + Base = getShadowPtrForVAArgument(A->getType(), IRB, VAArgOffset, ArgSize); VAArgOffset += ArgSize; VAArgOffset = alignTo(VAArgOffset, 8); + if (!Base) + continue; IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } @@ -3483,7 +3919,10 @@ struct VarArgMIPS64Helper : public VarArgHelper { /// Compute the shadow address for a given va_arg. 
Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - int ArgOffset) { + unsigned ArgOffset, unsigned ArgSize) { + // Make sure we don't overflow __msan_va_arg_tls. + if (ArgOffset + ArgSize > kParamTLSSize) + return nullptr; Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), @@ -3614,11 +4053,11 @@ struct VarArgAArch64Helper : public VarArgHelper { Value *Base; switch (AK) { case AK_GeneralPurpose: - Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset); + Base = getShadowPtrForVAArgument(A->getType(), IRB, GrOffset, 8); GrOffset += 8; break; case AK_FloatingPoint: - Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset); + Base = getShadowPtrForVAArgument(A->getType(), IRB, VrOffset, 8); VrOffset += 16; break; case AK_Memory: @@ -3627,7 +4066,8 @@ struct VarArgAArch64Helper : public VarArgHelper { if (IsFixed) continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); + Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset, + alignTo(ArgSize, 8)); OverflowOffset += alignTo(ArgSize, 8); break; } @@ -3635,6 +4075,8 @@ struct VarArgAArch64Helper : public VarArgHelper { // bother to actually store a shadow. if (IsFixed) continue; + if (!Base) + continue; IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } Constant *OverflowSize = @@ -3644,7 +4086,10 @@ struct VarArgAArch64Helper : public VarArgHelper { /// Compute the shadow address for a given va_arg. Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - int ArgOffset) { + unsigned ArgOffset, unsigned ArgSize) { + // Make sure we don't overflow __msan_va_arg_tls. 
+ if (ArgOffset + ArgSize > kParamTLSSize) + return nullptr; Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), @@ -3849,14 +4294,17 @@ struct VarArgPowerPC64Helper : public VarArgHelper { ArgAlign = 8; VAArgOffset = alignTo(VAArgOffset, ArgAlign); if (!IsFixed) { - Value *Base = getShadowPtrForVAArgument(RealTy, IRB, - VAArgOffset - VAArgBase); - Value *AShadowPtr, *AOriginPtr; - std::tie(AShadowPtr, AOriginPtr) = MSV.getShadowOriginPtr( - A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment, /*isStore*/ false); - - IRB.CreateMemCpy(Base, kShadowTLSAlignment, AShadowPtr, - kShadowTLSAlignment, ArgSize); + Value *Base = getShadowPtrForVAArgument( + RealTy, IRB, VAArgOffset - VAArgBase, ArgSize); + if (Base) { + Value *AShadowPtr, *AOriginPtr; + std::tie(AShadowPtr, AOriginPtr) = + MSV.getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), + kShadowTLSAlignment, /*isStore*/ false); + + IRB.CreateMemCpy(Base, kShadowTLSAlignment, AShadowPtr, + kShadowTLSAlignment, ArgSize); + } } VAArgOffset += alignTo(ArgSize, 8); } else { @@ -3884,8 +4332,9 @@ struct VarArgPowerPC64Helper : public VarArgHelper { } if (!IsFixed) { Base = getShadowPtrForVAArgument(A->getType(), IRB, - VAArgOffset - VAArgBase); - IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); + VAArgOffset - VAArgBase, ArgSize); + if (Base) + IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); } VAArgOffset += ArgSize; VAArgOffset = alignTo(VAArgOffset, 8); @@ -3903,7 +4352,10 @@ struct VarArgPowerPC64Helper : public VarArgHelper { /// Compute the shadow address for a given va_arg. Value *getShadowPtrForVAArgument(Type *Ty, IRBuilder<> &IRB, - int ArgOffset) { + unsigned ArgOffset, unsigned ArgSize) { + // Make sure we don't overflow __msan_va_arg_tls. + if (ArgOffset + ArgSize > kParamTLSSize) + return nullptr; Value *Base = IRB.CreatePointerCast(MS.VAArgTLS, MS.IntptrTy); Base = IRB.CreateAdd(Base, ConstantInt::get(MS.IntptrTy, ArgOffset)); return IRB.CreateIntToPtr(Base, PointerType::get(MSV.getShadowTy(Ty), 0), @@ -4005,10 +4457,8 @@ static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, return new VarArgNoOpHelper(Func, Msan, Visitor); } -bool MemorySanitizer::runOnFunction(Function &F) { - if (&F == MsanCtorFunction) - return false; - MemorySanitizerVisitor Visitor(F, *this); +bool MemorySanitizer::sanitizeFunction(Function &F, TargetLibraryInfo &TLI) { + MemorySanitizerVisitor Visitor(F, *this, TLI); // Clear out readonly/readnone attributes. AttrBuilder B; diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 307b7eaa2196..f043325f5bba 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -63,7 +63,7 @@ #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/IR/Attributes.h" @@ -141,6 +141,11 @@ static cl::opt<std::string> cl::value_desc("filename"), cl::desc("Specify the path of profile data file. 
This is" "mainly for test purpose.")); +static cl::opt<std::string> PGOTestProfileRemappingFile( + "pgo-test-profile-remapping-file", cl::init(""), cl::Hidden, + cl::value_desc("filename"), + cl::desc("Specify the path of profile remapping file. This is mainly for " + "test purpose.")); // Command line option to disable value profiling. The default is false: // i.e. value profiling is enabled by default. This is for debug purpose. @@ -539,7 +544,7 @@ public: MIVisitor.countMemIntrinsics(Func); NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts(); NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics(); - ValueSites[IPVK_IndirectCallTarget] = findIndirectCallSites(Func); + ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func); ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func); FuncName = getPGOFuncName(F); @@ -581,7 +586,7 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { std::vector<char> Indexes; JamCRC JC; for (auto &BB : F) { - const TerminatorInst *TI = BB.getTerminator(); + const Instruction *TI = BB.getTerminator(); for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) { BasicBlock *Succ = TI->getSuccessor(I); auto BI = findBBInfo(Succ); @@ -693,7 +698,7 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { // Instrument the SrcBB if it has a single successor, // otherwise, the DestBB if this is not a critical edge. - TerminatorInst *TI = SrcBB->getTerminator(); + Instruction *TI = SrcBB->getTerminator(); if (TI->getNumSuccessors() <= 1) return SrcBB; if (!E->IsCritical) @@ -749,12 +754,12 @@ static void instrumentOneFunc( if (DisableValueProfiling) return; - unsigned NumIndirectCallSites = 0; + unsigned NumIndirectCalls = 0; for (auto &I : FuncInfo.ValueSites[IPVK_IndirectCallTarget]) { CallSite CS(I); Value *Callee = CS.getCalledValue(); LLVM_DEBUG(dbgs() << "Instrument one indirect call: CallSite Index = " - << NumIndirectCallSites << "\n"); + << NumIndirectCalls << "\n"); IRBuilder<> Builder(I); assert(Builder.GetInsertPoint() != I->getParent()->end() && "Cannot get the Instrumentation point"); @@ -764,9 +769,9 @@ static void instrumentOneFunc( Builder.getInt64(FuncInfo.FunctionHash), Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()), Builder.getInt32(IPVK_IndirectCallTarget), - Builder.getInt32(NumIndirectCallSites++)}); + Builder.getInt32(NumIndirectCalls++)}); } - NumOfPGOICall += NumIndirectCallSites; + NumOfPGOICall += NumIndirectCalls; // Now instrument memop intrinsic calls. FuncInfo.MIVisitor.instrumentMemIntrinsics( @@ -854,7 +859,7 @@ public: FreqAttr(FFA_Normal) {} // Read counts for the instrumented BB from profile. - bool readCounters(IndexedInstrProfReader *PGOReader); + bool readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros); // Populate the counts for all BBs. void populateCounters(); @@ -899,6 +904,7 @@ public: FuncInfo.dumpInfo(Str); } + uint64_t getProgramMaxCount() const { return ProgramMaxCount; } private: Function &F; Module *M; @@ -1008,7 +1014,7 @@ void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) { // Read the profile from ProfileFileName and assign the value to the // instrumented BB and the edges. This function also updates ProgramMaxCount. // Return true if the profile are successfully read, and false on errors. 
-bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) { +bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros) { auto &Ctx = M->getContext(); Expected<InstrProfRecord> Result = PGOReader->getInstrProfRecord(FuncInfo.FuncName, FuncInfo.FunctionHash); @@ -1048,6 +1054,7 @@ bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) { LLVM_DEBUG(dbgs() << " " << I << ": " << CountFromProfile[I] << "\n"); ValueSum += CountFromProfile[I]; } + AllZeros = (ValueSum == 0); LLVM_DEBUG(dbgs() << "SUM = " << ValueSum << "\n"); @@ -1162,7 +1169,7 @@ void PGOUseFunc::setBranchWeights() { // Generate MD_prof metadata for every branch instruction. LLVM_DEBUG(dbgs() << "\nSetting branch weights.\n"); for (auto &BB : F) { - TerminatorInst *TI = BB.getTerminator(); + Instruction *TI = BB.getTerminator(); if (TI->getNumSuccessors() < 2) continue; if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || @@ -1208,7 +1215,7 @@ void PGOUseFunc::annotateIrrLoopHeaderWeights() { // to become an irreducible loop header after the indirectbr tail // duplication. if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) { - TerminatorInst *TI = BB.getTerminator(); + Instruction *TI = BB.getTerminator(); const UseBBInfo &BBCountInfo = getBBInfo(&BB); setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue); } @@ -1429,13 +1436,14 @@ PreservedAnalyses PGOInstrumentationGen::run(Module &M, } static bool annotateAllFunctions( - Module &M, StringRef ProfileFileName, + Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI, function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) { LLVM_DEBUG(dbgs() << "Read in profile counters: "); auto &Ctx = M.getContext(); // Read the counter array from file. - auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName); + auto ReaderOrErr = + IndexedInstrProfReader::create(ProfileFileName, ProfileRemappingFileName); if (Error E = ReaderOrErr.takeError()) { handleAllErrors(std::move(E), [&](const ErrorInfoBase &EI) { Ctx.diagnose( @@ -1471,8 +1479,15 @@ static bool annotateAllFunctions( // later in getInstrBB() to avoid invalidating it. 
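An illustrative standalone sketch, not from the patch: the shape of the new readCounters() contract, where the return value still means "a profile record was found" and the all-zero case travels through a reference parameter. A std::vector stands in for the real counter record.

#include <cstdint>
#include <numeric>
#include <vector>

// Out-parameter pattern; names and types are stand-ins.
inline bool readCountersSketch(const std::vector<std::uint64_t> &Counts,
                               bool &AllZeros) {
  if (Counts.empty())
    return false;                    // no record: caller skips this function
  std::uint64_t Sum =
      std::accumulate(Counts.begin(), Counts.end(), std::uint64_t{0});
  AllZeros = (Sum == 0);             // reported separately from success
  return true;
}

The hunk that follows shows the caller acting on the flag: the entry count is pinned to a real zero and, when the rest of the program still has nonzero counts, the function is queued as cold.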
SplitIndirectBrCriticalEdges(F, BPI, BFI); PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI); - if (!Func.readCounters(PGOReader.get())) + bool AllZeros = false; + if (!Func.readCounters(PGOReader.get(), AllZeros)) continue; + if (AllZeros) { + F.setEntryCount(ProfileCount(0, Function::PCT_Real)); + if (Func.getProgramMaxCount() != 0) + ColdFunctions.push_back(&F); + continue; + } Func.populateCounters(); Func.setBranchWeights(); Func.annotateValueSites(); @@ -1529,10 +1544,14 @@ static bool annotateAllFunctions( return true; } -PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename) - : ProfileFileName(std::move(Filename)) { +PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename, + std::string RemappingFilename) + : ProfileFileName(std::move(Filename)), + ProfileRemappingFileName(std::move(RemappingFilename)) { if (!PGOTestProfileFile.empty()) ProfileFileName = PGOTestProfileFile; + if (!PGOTestProfileRemappingFile.empty()) + ProfileRemappingFileName = PGOTestProfileRemappingFile; } PreservedAnalyses PGOInstrumentationUse::run(Module &M, @@ -1547,7 +1566,8 @@ PreservedAnalyses PGOInstrumentationUse::run(Module &M, return &FAM.getResult<BlockFrequencyAnalysis>(F); }; - if (!annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI)) + if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName, + LookupBPI, LookupBFI)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); @@ -1564,7 +1584,7 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI(); }; - return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI); + return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI); } static std::string getSimpleNodeName(const BasicBlock *Node) { diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index a4dd48c8dd6a..0ba8d5765e8c 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Mangler.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" @@ -211,8 +212,8 @@ private: bool IsLeafFunc = true); Function *CreateInitCallsForSections(Module &M, const char *InitFunctionName, Type *Ty, const char *Section); - std::pair<GlobalVariable *, GlobalVariable *> - CreateSecStartEnd(Module &M, const char *Section, Type *Ty); + std::pair<Value *, Value *> CreateSecStartEnd(Module &M, const char *Section, + Type *Ty); void SetNoSanitizeMetadata(Instruction *I) { I->setMetadata(I->getModule()->getMDKindID("nosanitize"), @@ -234,6 +235,7 @@ private: Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy, *Int16Ty, *Int8Ty, *Int8PtrTy; Module *CurModule; + std::string CurModuleUniqueId; Triple TargetTriple; LLVMContext *C; const DataLayout *DL; @@ -249,7 +251,7 @@ private: } // namespace -std::pair<GlobalVariable *, GlobalVariable *> +std::pair<Value *, Value *> SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section, Type *Ty) { GlobalVariable *SecStart = @@ -260,22 +262,28 @@ SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section, new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, nullptr, getSectionEnd(Section)); SecEnd->setVisibility(GlobalValue::HiddenVisibility); - - return 
std::make_pair(SecStart, SecEnd); + IRBuilder<> IRB(M.getContext()); + Value *SecEndPtr = IRB.CreatePointerCast(SecEnd, Ty); + if (!TargetTriple.isOSBinFormatCOFF()) + return std::make_pair(IRB.CreatePointerCast(SecStart, Ty), SecEndPtr); + + // Account for the fact that on windows-msvc __start_* symbols actually + // point to a uint64_t before the start of the array. + auto SecStartI8Ptr = IRB.CreatePointerCast(SecStart, Int8PtrTy); + auto GEP = IRB.CreateGEP(SecStartI8Ptr, + ConstantInt::get(IntptrTy, sizeof(uint64_t))); + return std::make_pair(IRB.CreatePointerCast(GEP, Ty), SecEndPtr); } - Function *SanitizerCoverageModule::CreateInitCallsForSections( Module &M, const char *InitFunctionName, Type *Ty, const char *Section) { - IRBuilder<> IRB(M.getContext()); auto SecStartEnd = CreateSecStartEnd(M, Section, Ty); auto SecStart = SecStartEnd.first; auto SecEnd = SecStartEnd.second; Function *CtorFunc; std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( - M, SanCovModuleCtorName, InitFunctionName, {Ty, Ty}, - {IRB.CreatePointerCast(SecStart, Ty), IRB.CreatePointerCast(SecEnd, Ty)}); + M, SanCovModuleCtorName, InitFunctionName, {Ty, Ty}, {SecStart, SecEnd}); if (TargetTriple.supportsCOMDAT()) { // Use comdat to dedup CtorFunc. @@ -284,6 +292,17 @@ Function *SanitizerCoverageModule::CreateInitCallsForSections( } else { appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); } + + if (TargetTriple.isOSBinFormatCOFF()) { + // In COFF files, if the contructors are set as COMDAT (they are because + // COFF supports COMDAT) and the linker flag /OPT:REF (strip unreferenced + // functions and data) is used, the constructors get stripped. To prevent + // this, give the constructors weak ODR linkage and ensure the linker knows + // to include the sancov constructor. This way the linker can deduplicate + // the constructors but always leave one copy. + CtorFunc->setLinkage(GlobalValue::WeakODRLinkage); + appendToUsed(M, CtorFunc); + } return CtorFunc; } @@ -293,6 +312,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { C = &(M.getContext()); DL = &M.getDataLayout(); CurModule = &M; + CurModuleUniqueId = getUniqueModuleId(CurModule); TargetTriple = Triple(M.getTargetTriple()); FunctionGuardArray = nullptr; Function8bitCounterArray = nullptr; @@ -397,9 +417,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { Function *InitFunction = declareSanitizerInitFunction( M, SanCovPCsInitName, {IntptrPtrTy, IntptrPtrTy}); IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); - IRBCtor.CreateCall(InitFunction, - {IRB.CreatePointerCast(SecStartEnd.first, IntptrPtrTy), - IRB.CreatePointerCast(SecStartEnd.second, IntptrPtrTy)}); + IRBCtor.CreateCall(InitFunction, {SecStartEnd.first, SecStartEnd.second}); } // We don't reference these arrays directly in any of our runtime functions, // so we need to prevent them from being dead stripped. @@ -549,11 +567,19 @@ GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection( auto Array = new GlobalVariable( *CurModule, ArrayTy, false, GlobalVariable::PrivateLinkage, Constant::getNullValue(ArrayTy), "__sancov_gen_"); - if (auto Comdat = F.getComdat()) - Array->setComdat(Comdat); + + if (TargetTriple.supportsCOMDAT() && !F.isInterposable()) + if (auto Comdat = + GetOrCreateFunctionComdat(F, TargetTriple, CurModuleUniqueId)) + Array->setComdat(Comdat); Array->setSection(getSectionName(Section)); Array->setAlignment(Ty->isPointerTy() ? 
DL->getPointerSize() : Ty->getPrimitiveSizeInBits() / 8); + GlobalsToAppendToUsed.push_back(Array); + GlobalsToAppendToCompilerUsed.push_back(Array); + MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F)); + Array->addMetadata(LLVMContext::MD_associated, *MD); + return Array; } @@ -587,24 +613,16 @@ SanitizerCoverageModule::CreatePCArray(Function &F, void SanitizerCoverageModule::CreateFunctionLocalArrays( Function &F, ArrayRef<BasicBlock *> AllBlocks) { - if (Options.TracePCGuard) { + if (Options.TracePCGuard) FunctionGuardArray = CreateFunctionLocalArrayInSection( AllBlocks.size(), F, Int32Ty, SanCovGuardsSectionName); - GlobalsToAppendToUsed.push_back(FunctionGuardArray); - } - if (Options.Inline8bitCounters) { + + if (Options.Inline8bitCounters) Function8bitCounterArray = CreateFunctionLocalArrayInSection( AllBlocks.size(), F, Int8Ty, SanCovCountersSectionName); - GlobalsToAppendToCompilerUsed.push_back(Function8bitCounterArray); - MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F)); - Function8bitCounterArray->addMetadata(LLVMContext::MD_associated, *MD); - } - if (Options.PCTable) { + + if (Options.PCTable) FunctionPCsArray = CreatePCArray(F, AllBlocks); - GlobalsToAppendToCompilerUsed.push_back(FunctionPCsArray); - MDNode *MD = MDNode::get(F.getContext(), ValueAsMetadata::get(&F)); - FunctionPCsArray->addMetadata(LLVMContext::MD_associated, *MD); - } } bool SanitizerCoverageModule::InjectCoverage(Function &F, @@ -806,8 +824,13 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, std::string SanitizerCoverageModule::getSectionName(const std::string &Section) const { - if (TargetTriple.getObjectFormat() == Triple::COFF) - return ".SCOV$M"; + if (TargetTriple.isOSBinFormatCOFF()) { + if (Section == SanCovCountersSectionName) + return ".SCOV$CM"; + if (Section == SanCovPCsSectionName) + return ".SCOVP$M"; + return ".SCOV$GM"; // For SanCovGuardsSectionName. + } if (TargetTriple.isOSBinFormatMachO()) return "__DATA,__" + Section; return "__" + Section; diff --git a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index fa1e5a157a0f..077364e15c4f 100644 --- a/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -19,6 +19,7 @@ // The rest is handled by the run-time library. //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Instrumentation/ThreadSanitizer.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" @@ -86,15 +87,16 @@ static const char *const kTsanInitName = "__tsan_init"; namespace { /// ThreadSanitizer: instrument the code in module to find races. -struct ThreadSanitizer : public FunctionPass { - ThreadSanitizer() : FunctionPass(ID) {} - StringRef getPassName() const override; - void getAnalysisUsage(AnalysisUsage &AU) const override; - bool runOnFunction(Function &F) override; - bool doInitialization(Module &M) override; - static char ID; // Pass identification, replacement for typeid. - - private: +/// +/// Instantiating ThreadSanitizer inserts the tsan runtime library API function +/// declarations into the module if they don't exist already. Instantiating +/// ensures the __tsan_init function is in the list of global constructors for +/// the module. 
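A minimal sketch, outside the patch, of the refactoring shape the comment above describes: the instrumentation logic becomes an ordinary class, and the legacy pass only owns an Optional of it and forwards per-function work, which is what lets the new pass manager construct the same class directly. All types here are stand-ins.

#include <optional>

struct ModuleStub {};   // stand-in for llvm::Module
struct FunctionStub {}; // stand-in for llvm::Function

struct SanitizerImpl {                      // plays the ThreadSanitizer role
  explicit SanitizerImpl(ModuleStub &) {}   // module-level setup happens here
  bool sanitizeFunction(FunctionStub &) { return true; }
};

struct LegacyAdapter {                      // plays the legacy-pass role
  std::optional<SanitizerImpl> Impl;
  bool doInitialization(ModuleStub &M) { Impl.emplace(M); return true; }
  bool runOnFunction(FunctionStub &F) { return Impl->sanitizeFunction(F); }
};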
+struct ThreadSanitizer { + ThreadSanitizer(Module &M); + bool sanitizeFunction(Function &F, const TargetLibraryInfo &TLI); + +private: void initializeCallbacks(Module &M); bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL); bool instrumentAtomic(Instruction *I, const DataLayout &DL); @@ -130,27 +132,55 @@ struct ThreadSanitizer : public FunctionPass { Function *MemmoveFn, *MemcpyFn, *MemsetFn; Function *TsanCtorFunction; }; + +struct ThreadSanitizerLegacyPass : FunctionPass { + ThreadSanitizerLegacyPass() : FunctionPass(ID) {} + StringRef getPassName() const override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; + bool doInitialization(Module &M) override; + static char ID; // Pass identification, replacement for typeid. +private: + Optional<ThreadSanitizer> TSan; +}; } // namespace -char ThreadSanitizer::ID = 0; -INITIALIZE_PASS_BEGIN( - ThreadSanitizer, "tsan", - "ThreadSanitizer: detects data races.", - false, false) +PreservedAnalyses ThreadSanitizerPass::run(Function &F, + FunctionAnalysisManager &FAM) { + ThreadSanitizer TSan(*F.getParent()); + if (TSan.sanitizeFunction(F, FAM.getResult<TargetLibraryAnalysis>(F))) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +char ThreadSanitizerLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(ThreadSanitizerLegacyPass, "tsan", + "ThreadSanitizer: detects data races.", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END( - ThreadSanitizer, "tsan", - "ThreadSanitizer: detects data races.", - false, false) +INITIALIZE_PASS_END(ThreadSanitizerLegacyPass, "tsan", + "ThreadSanitizer: detects data races.", false, false) -StringRef ThreadSanitizer::getPassName() const { return "ThreadSanitizer"; } +StringRef ThreadSanitizerLegacyPass::getPassName() const { + return "ThreadSanitizerLegacyPass"; +} -void ThreadSanitizer::getAnalysisUsage(AnalysisUsage &AU) const { +void ThreadSanitizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<TargetLibraryInfoWrapperPass>(); } -FunctionPass *llvm::createThreadSanitizerPass() { - return new ThreadSanitizer(); +bool ThreadSanitizerLegacyPass::doInitialization(Module &M) { + TSan.emplace(M); + return true; +} + +bool ThreadSanitizerLegacyPass::runOnFunction(Function &F) { + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + TSan->sanitizeFunction(F, TLI); + return true; +} + +FunctionPass *llvm::createThreadSanitizerLegacyPassPass() { + return new ThreadSanitizerLegacyPass(); } void ThreadSanitizer::initializeCallbacks(Module &M) { @@ -252,16 +282,16 @@ void ThreadSanitizer::initializeCallbacks(Module &M) { IRB.getInt32Ty(), IntptrTy)); } -bool ThreadSanitizer::doInitialization(Module &M) { +ThreadSanitizer::ThreadSanitizer(Module &M) { const DataLayout &DL = M.getDataLayout(); IntptrTy = DL.getIntPtrType(M.getContext()); - std::tie(TsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( - M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{}, - /*InitArgs=*/{}); - - appendToGlobalCtors(M, TsanCtorFunction, 0); - - return true; + std::tie(TsanCtorFunction, std::ignore) = + getOrCreateSanitizerCtorAndInitFunctions( + M, kTsanModuleCtorName, kTsanInitName, /*InitArgTypes=*/{}, + /*InitArgs=*/{}, + // This callback is invoked when the functions are created the first + // time. 
Hook them into the global ctors list in that case: + [&](Function *Ctor, Function *) { appendToGlobalCtors(M, Ctor, 0); }); } static bool isVtableAccess(Instruction *I) { @@ -402,7 +432,8 @@ void ThreadSanitizer::InsertRuntimeIgnores(Function &F) { } } -bool ThreadSanitizer::runOnFunction(Function &F) { +bool ThreadSanitizer::sanitizeFunction(Function &F, + const TargetLibraryInfo &TLI) { // This is required to prevent instrumenting call to __tsan_init from within // the module constructor. if (&F == TsanCtorFunction) @@ -416,8 +447,6 @@ bool ThreadSanitizer::runOnFunction(Function &F) { bool HasCalls = false; bool SanitizeFunction = F.hasFnAttribute(Attribute::SanitizeThread); const DataLayout &DL = F.getParent()->getDataLayout(); - const TargetLibraryInfo *TLI = - &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); // Traverse all instructions, collect loads/stores/returns, check for calls. for (auto &BB : F) { @@ -428,7 +457,7 @@ bool ThreadSanitizer::runOnFunction(Function &F) { LocalLoadsAndStores.push_back(&Inst); else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) { if (CallInst *CI = dyn_cast<CallInst>(&Inst)) - maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI); + maybeMarkSanitizerLibraryCallNoBuiltin(CI, &TLI); if (isa<MemIntrinsic>(Inst)) MemIntrinCalls.push_back(&Inst); HasCalls = true; diff --git a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index ba4924c9cb2d..7f6b157304a3 100644 --- a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -26,6 +26,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/ErrorHandling.h" @@ -74,27 +75,27 @@ public: switch (kind) { case ARCRuntimeEntryPointKind::AutoreleaseRV: - return getI8XRetI8XEntryPoint(AutoreleaseRV, - "objc_autoreleaseReturnValue", true); + return getIntrinsicEntryPoint(AutoreleaseRV, + Intrinsic::objc_autoreleaseReturnValue); case ARCRuntimeEntryPointKind::Release: - return getVoidRetI8XEntryPoint(Release, "objc_release"); + return getIntrinsicEntryPoint(Release, Intrinsic::objc_release); case ARCRuntimeEntryPointKind::Retain: - return getI8XRetI8XEntryPoint(Retain, "objc_retain", true); + return getIntrinsicEntryPoint(Retain, Intrinsic::objc_retain); case ARCRuntimeEntryPointKind::RetainBlock: - return getI8XRetI8XEntryPoint(RetainBlock, "objc_retainBlock", false); + return getIntrinsicEntryPoint(RetainBlock, Intrinsic::objc_retainBlock); case ARCRuntimeEntryPointKind::Autorelease: - return getI8XRetI8XEntryPoint(Autorelease, "objc_autorelease", true); + return getIntrinsicEntryPoint(Autorelease, Intrinsic::objc_autorelease); case ARCRuntimeEntryPointKind::StoreStrong: - return getI8XRetI8XXI8XEntryPoint(StoreStrong, "objc_storeStrong"); + return getIntrinsicEntryPoint(StoreStrong, Intrinsic::objc_storeStrong); case ARCRuntimeEntryPointKind::RetainRV: - return getI8XRetI8XEntryPoint(RetainRV, - "objc_retainAutoreleasedReturnValue", true); + return getIntrinsicEntryPoint(RetainRV, + Intrinsic::objc_retainAutoreleasedReturnValue); case ARCRuntimeEntryPointKind::RetainAutorelease: - return getI8XRetI8XEntryPoint(RetainAutorelease, "objc_retainAutorelease", - true); + return getIntrinsicEntryPoint(RetainAutorelease, + Intrinsic::objc_retainAutorelease); case ARCRuntimeEntryPointKind::RetainAutoreleaseRV: - return getI8XRetI8XEntryPoint(RetainAutoreleaseRV, - 
"objc_retainAutoreleaseReturnValue", true); + return getIntrinsicEntryPoint(RetainAutoreleaseRV, + Intrinsic::objc_retainAutoreleaseReturnValue); } llvm_unreachable("Switch should be a covered switch."); @@ -131,54 +132,11 @@ private: /// Declaration for objc_retainAutoreleaseReturnValue(). Constant *RetainAutoreleaseRV = nullptr; - Constant *getVoidRetI8XEntryPoint(Constant *&Decl, StringRef Name) { + Constant *getIntrinsicEntryPoint(Constant *&Decl, Intrinsic::ID IntID) { if (Decl) return Decl; - LLVMContext &C = TheModule->getContext(); - Type *Params[] = { PointerType::getUnqual(Type::getInt8Ty(C)) }; - AttributeList Attr = AttributeList().addAttribute( - C, AttributeList::FunctionIndex, Attribute::NoUnwind); - FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params, - /*isVarArg=*/false); - return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); - } - - Constant *getI8XRetI8XEntryPoint(Constant *&Decl, StringRef Name, - bool NoUnwind = false) { - if (Decl) - return Decl; - - LLVMContext &C = TheModule->getContext(); - Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C)); - Type *Params[] = { I8X }; - FunctionType *Fty = FunctionType::get(I8X, Params, /*isVarArg=*/false); - AttributeList Attr = AttributeList(); - - if (NoUnwind) - Attr = Attr.addAttribute(C, AttributeList::FunctionIndex, - Attribute::NoUnwind); - - return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); - } - - Constant *getI8XRetI8XXI8XEntryPoint(Constant *&Decl, StringRef Name) { - if (Decl) - return Decl; - - LLVMContext &C = TheModule->getContext(); - Type *I8X = PointerType::getUnqual(Type::getInt8Ty(C)); - Type *I8XX = PointerType::getUnqual(I8X); - Type *Params[] = { I8XX, I8X }; - - AttributeList Attr = AttributeList().addAttribute( - C, AttributeList::FunctionIndex, Attribute::NoUnwind); - Attr = Attr.addParamAttribute(C, 0, Attribute::NoCapture); - - FunctionType *Fty = FunctionType::get(Type::getVoidTy(C), Params, - /*isVarArg=*/false); - - return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); + return Decl = Intrinsic::getDeclaration(TheModule, IntID); } }; diff --git a/lib/Transforms/ObjCARC/DependencyAnalysis.cpp b/lib/Transforms/ObjCARC/DependencyAnalysis.cpp index 464805051c65..4bd5fd1acd4c 100644 --- a/lib/Transforms/ObjCARC/DependencyAnalysis.cpp +++ b/lib/Transforms/ObjCARC/DependencyAnalysis.cpp @@ -45,18 +45,15 @@ bool llvm::objcarc::CanAlterRefCount(const Instruction *Inst, const Value *Ptr, default: break; } - ImmutableCallSite CS(Inst); - assert(CS && "Only calls can alter reference counts!"); + const auto *Call = cast<CallBase>(Inst); // See if AliasAnalysis can help us with the call. 
- FunctionModRefBehavior MRB = PA.getAA()->getModRefBehavior(CS); + FunctionModRefBehavior MRB = PA.getAA()->getModRefBehavior(Call); if (AliasAnalysis::onlyReadsMemory(MRB)) return false; if (AliasAnalysis::onlyAccessesArgPointees(MRB)) { const DataLayout &DL = Inst->getModule()->getDataLayout(); - for (ImmutableCallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); - I != E; ++I) { - const Value *Op = *I; + for (const Value *Op : Call->args()) { if (IsPotentialRetainableObjPtr(Op, *PA.getAA()) && PA.related(Ptr, Op, DL)) return true; @@ -266,13 +263,10 @@ llvm::objcarc::FindDependencies(DependenceKind Flavor, for (const BasicBlock *BB : Visited) { if (BB == StartBB) continue; - const TerminatorInst *TI = cast<TerminatorInst>(&BB->back()); - for (succ_const_iterator SI(TI), SE(TI, false); SI != SE; ++SI) { - const BasicBlock *Succ = *SI; + for (const BasicBlock *Succ : successors(BB)) if (Succ != StartBB && !Visited.count(Succ)) { DependingInsts.insert(reinterpret_cast<Instruction *>(-1)); return; } - } } } diff --git a/lib/Transforms/ObjCARC/ObjCARC.h b/lib/Transforms/ObjCARC/ObjCARC.h index 1dbe72c7569f..751c8f30e814 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.h +++ b/lib/Transforms/ObjCARC/ObjCARC.h @@ -58,7 +58,7 @@ static inline void EraseInstruction(Instruction *CI) { // Replace the return value with the argument. assert((IsForwarding(GetBasicARCInstKind(CI)) || (IsNoopOnNull(GetBasicARCInstKind(CI)) && - isa<ConstantPointerNull>(OldArg))) && + IsNullOrUndef(OldArg->stripPointerCasts()))) && "Can't delete non-forwarding instruction with users!"); CI->replaceAllUsesWith(OldArg); } diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index 1f1ea9f58739..abe2871c0b8f 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -522,7 +522,7 @@ bool ObjCARCContract::tryToPeepholeInstruction( TailOkForStoreStrongs = false; return true; case ARCInstKind::IntrinsicUser: - // Remove calls to @clang.arc.use(...). + // Remove calls to @llvm.objc.clang.arc.use(...). Inst->eraseFromParent(); return true; default: diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index 21e2848030fc..9a02174556fc 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -600,6 +600,17 @@ ObjCARCOpt::OptimizeRetainRVCall(Function &F, Instruction *RetainRV) { } } + // Track PHIs which are equivalent to our Arg. + SmallDenseSet<const Value*, 2> EquivalentArgs; + EquivalentArgs.insert(Arg); + + // Add PHIs that are equivalent to Arg to ArgUsers. + if (const PHINode *PN = dyn_cast<PHINode>(Arg)) { + SmallVector<const Value *, 2> ArgUsers; + getEquivalentPHIs(*PN, ArgUsers); + EquivalentArgs.insert(ArgUsers.begin(), ArgUsers.end()); + } + // Check for being preceded by an objc_autoreleaseReturnValue on the same // pointer. In this case, we can delete the pair. 
BasicBlock::iterator I = RetainRV->getIterator(), @@ -609,7 +620,7 @@ ObjCARCOpt::OptimizeRetainRVCall(Function &F, Instruction *RetainRV) { --I; while (I != Begin && IsNoopInstruction(&*I)); if (GetBasicARCInstKind(&*I) == ARCInstKind::AutoreleaseRV && - GetArgRCIdentityRoot(&*I) == Arg) { + EquivalentArgs.count(GetArgRCIdentityRoot(&*I))) { Changed = true; ++NumPeeps; @@ -914,8 +925,8 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { GetRCIdentityRoot(PN->getIncomingValue(i)); if (IsNullOrUndef(Incoming)) HasNull = true; - else if (cast<TerminatorInst>(PN->getIncomingBlock(i)->back()) - .getNumSuccessors() != 1) { + else if (PN->getIncomingBlock(i)->getTerminator()->getNumSuccessors() != + 1) { HasCriticalEdges = true; break; } @@ -1084,18 +1095,15 @@ ObjCARCOpt::CheckForCFGHazards(const BasicBlock *BB, "Unknown top down sequence state."); const Value *Arg = I->first; - const TerminatorInst *TI = cast<TerminatorInst>(&BB->back()); bool SomeSuccHasSame = false; bool AllSuccsHaveSame = true; bool NotAllSeqEqualButKnownSafe = false; - succ_const_iterator SI(TI), SE(TI, false); - - for (; SI != SE; ++SI) { + for (const BasicBlock *Succ : successors(BB)) { // If VisitBottomUp has pointer information for this successor, take // what we know about it. const DenseMap<const BasicBlock *, BBState>::iterator BBI = - BBStates.find(*SI); + BBStates.find(Succ); assert(BBI != BBStates.end()); const BottomUpPtrState &SuccS = BBI->second.getPtrBottomUpState(Arg); const Sequence SuccSSeq = SuccS.GetSeq(); @@ -1414,21 +1422,20 @@ ComputePostOrders(Function &F, BasicBlock *EntryBB = &F.getEntryBlock(); BBState &MyStates = BBStates[EntryBB]; MyStates.SetAsEntry(); - TerminatorInst *EntryTI = cast<TerminatorInst>(&EntryBB->back()); + Instruction *EntryTI = EntryBB->getTerminator(); SuccStack.push_back(std::make_pair(EntryBB, succ_iterator(EntryTI))); Visited.insert(EntryBB); OnStack.insert(EntryBB); do { dfs_next_succ: BasicBlock *CurrBB = SuccStack.back().first; - TerminatorInst *TI = cast<TerminatorInst>(&CurrBB->back()); - succ_iterator SE(TI, false); + succ_iterator SE(CurrBB->getTerminator(), false); while (SuccStack.back().second != SE) { BasicBlock *SuccBB = *SuccStack.back().second++; if (Visited.insert(SuccBB).second) { - TerminatorInst *TI = cast<TerminatorInst>(&SuccBB->back()); - SuccStack.push_back(std::make_pair(SuccBB, succ_iterator(TI))); + SuccStack.push_back( + std::make_pair(SuccBB, succ_iterator(SuccBB->getTerminator()))); BBStates[CurrBB].addSucc(SuccBB); BBState &SuccStates = BBStates[SuccBB]; SuccStates.addPred(CurrBB); diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index ce09a477b5f5..b0602d96798c 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -30,9 +30,10 @@ #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" @@ -102,7 +103,7 @@ struct BlockInfoType { BasicBlock *BB = nullptr; /// Cache of BB->getTerminator(). - TerminatorInst *Terminator = nullptr; + Instruction *Terminator = nullptr; /// Post-order numbering of reverse control flow graph. unsigned PostOrder; @@ -115,7 +116,7 @@ class AggressiveDeadCodeElimination { // ADCE does not use DominatorTree per se, but it updates it to preserve the // analysis. 
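A standalone sketch, not part of the patch, of what switching ADCE from DominatorTree& to DominatorTree* buys: the pass never queries the tree, so it only has to patch up a tree that some earlier pass already computed. In the patch the possibly-null pointer is handed to DomTreeUpdater; the explicit null check below just makes the idea visible with stand-in types.

struct DomTreeStub {
  void applyUpdates() {} // stand-in for applying deleted-edge updates
};

struct AdceSketch {
  DomTreeStub *DT; // may be null: no cached tree, nothing to preserve
  explicit AdceSketch(DomTreeStub *DT) : DT(DT) {}

  void afterRemovingControlFlow() {
    if (DT) // only a tree that already exists needs to stay consistent
      DT->applyUpdates();
  }
};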
- DominatorTree &DT; + DominatorTree *DT; PostDominatorTree &PDT; /// Mapping of blocks to associated information, an element in BlockInfoVec. @@ -190,7 +191,7 @@ class AggressiveDeadCodeElimination { void makeUnconditional(BasicBlock *BB, BasicBlock *Target); public: - AggressiveDeadCodeElimination(Function &F, DominatorTree &DT, + AggressiveDeadCodeElimination(Function &F, DominatorTree *DT, PostDominatorTree &PDT) : F(F), DT(DT), PDT(PDT) {} @@ -205,7 +206,7 @@ bool AggressiveDeadCodeElimination::performDeadCodeElimination() { return removeDeadInstructions(); } -static bool isUnconditionalBranch(TerminatorInst *Term) { +static bool isUnconditionalBranch(Instruction *Term) { auto *BR = dyn_cast<BranchInst>(Term); return BR && BR->isUnconditional(); } @@ -276,7 +277,7 @@ void AggressiveDeadCodeElimination::initialize() { // treat all edges to a block already seen as loop back edges // and mark the branch live it if there is a back edge. for (auto *BB: depth_first_ext(&F.getEntryBlock(), State)) { - TerminatorInst *Term = BB->getTerminator(); + Instruction *Term = BB->getTerminator(); if (isLive(Term)) continue; @@ -330,7 +331,7 @@ bool AggressiveDeadCodeElimination::isAlwaysLive(Instruction &I) { return false; return true; } - if (!isa<TerminatorInst>(I)) + if (!I.isTerminator()) return false; if (RemoveControlFlowFlag && (isa<BranchInst>(I) || isa<SwitchInst>(I))) return false; @@ -507,7 +508,7 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { if (isLive(&I)) continue; - if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) { + if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&I)) { // Check if the scope of this variable location is alive. if (AliveScopes.count(DII->getDebugLoc()->getScope())) continue; @@ -614,8 +615,8 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { } } - DT.applyUpdates(DeletedEdges); - PDT.applyUpdates(DeletedEdges); + DomTreeUpdater(DT, &PDT, DomTreeUpdater::UpdateStrategy::Eager) + .applyUpdates(DeletedEdges); NumBranchesRemoved += 1; } @@ -642,7 +643,7 @@ void AggressiveDeadCodeElimination::computeReversePostOrder() { void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB, BasicBlock *Target) { - TerminatorInst *PredTerm = BB->getTerminator(); + Instruction *PredTerm = BB->getTerminator(); // Collect the live debug info scopes attached to this instruction. if (const DILocation *DL = PredTerm->getDebugLoc()) collectLiveScopes(*DL); @@ -671,7 +672,9 @@ void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB, // //===----------------------------------------------------------------------===// PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) { - auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); + // ADCE does not need DominatorTree, but require DominatorTree here + // to update analysis if it is already available. + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F); if (!AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination()) return PreservedAnalyses::all(); @@ -697,15 +700,16 @@ struct ADCELegacyPass : public FunctionPass { if (skipFunction(F)) return false; - auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + // ADCE does not need DominatorTree, but require DominatorTree here + // to update analysis if it is already available. + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); return AggressiveDeadCodeElimination(F, DT, PDT) .performDeadCodeElimination(); } void getAnalysisUsage(AnalysisUsage &AU) const override { - // We require DominatorTree here only to update and thus preserve it. - AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<PostDominatorTreeWrapperPass>(); if (!RemoveControlFlowFlag) AU.setPreservesCFG(); @@ -723,7 +727,6 @@ char ADCELegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", false, false) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_END(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", false, false) diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp index 3a8ef073cb48..d3c9b9a270aa 100644 --- a/lib/Transforms/Scalar/BDCE.cpp +++ b/lib/Transforms/Scalar/BDCE.cpp @@ -38,7 +38,8 @@ STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)"); /// instruction may need to be cleared of assumptions that can no longer be /// guaranteed correct. static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { - assert(I->getType()->isIntegerTy() && "Trivializing a non-integer value?"); + assert(I->getType()->isIntOrIntVectorTy() && + "Trivializing a non-integer value?"); // Initialize the worklist with eligible direct users. SmallVector<Instruction *, 16> WorkList; @@ -46,13 +47,13 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // If all bits of a user are demanded, then we know that nothing below that // in the def-use chain needs to be changed. auto *J = dyn_cast<Instruction>(JU); - if (J && J->getType()->isSized() && + if (J && J->getType()->isIntOrIntVectorTy() && !DB.getDemandedBits(J).isAllOnesValue()) WorkList.push_back(J); - // Note that we need to check for unsized types above before asking for + // Note that we need to check for non-int types above before asking for // demanded bits. Normally, the only way to reach an instruction with an - // unsized type is via an instruction that has side effects (or otherwise + // non-int type is via an instruction that has side effects (or otherwise // will demand its input bits). However, if we have a readnone function // that returns an unsized type (e.g., void), we must avoid asking for the // demanded bits of the function call's return value. A void-returning @@ -78,7 +79,7 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // If all bits of a user are demanded, then we know that nothing below // that in the def-use chain needs to be changed. auto *K = dyn_cast<Instruction>(KU); - if (K && !Visited.count(K) && K->getType()->isSized() && + if (K && !Visited.count(K) && K->getType()->isIntOrIntVectorTy() && !DB.getDemandedBits(K).isAllOnesValue()) WorkList.push_back(K); } @@ -95,30 +96,41 @@ static bool bitTrackingDCE(Function &F, DemandedBits &DB) { if (I.mayHaveSideEffects() && I.use_empty()) continue; - if (I.getType()->isIntegerTy() && - !DB.getDemandedBits(&I).getBoolValue()) { - // For live instructions that have all dead bits, first make them dead by - // replacing all uses with something else. Then, if they don't need to - // remain live (because they have side effects, etc.) we can remove them. 
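For illustration only, not part of the patch: a concrete arithmetic case of the dead uses this BDCE hunk starts zeroing out, assuming the usual demanded-bits propagation through trunc, add, and shl.

#include <cassert>
#include <cstdint>

// Only the low byte of the result is demanded, and no bit of 'b' can reach
// it through the shift, so the use of 'b' below is dead: replacing it with
// zero leaves the observable result unchanged.
inline std::uint8_t lowByteSketch(std::uint32_t a, std::uint32_t b) {
  std::uint32_t t = b << 8;            // bits 0..7 of t are always zero
  std::uint32_t x = a + t;             // carries only propagate upward
  return static_cast<std::uint8_t>(x); // demands bits 0..7 only
}

inline void demandedBitsCheck() {
  assert(lowByteSketch(0x12345678u, 0xdeadbeefu) ==
         lowByteSketch(0x12345678u, 0u));
}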
- LLVM_DEBUG(dbgs() << "BDCE: Trivializing: " << I << " (all bits dead)\n"); + // Remove instructions that are dead, either because they were not reached + // during analysis or have no demanded bits. + if (DB.isInstructionDead(&I) || + (I.getType()->isIntOrIntVectorTy() && + DB.getDemandedBits(&I).isNullValue() && + wouldInstructionBeTriviallyDead(&I))) { + salvageDebugInfo(I); + Worklist.push_back(&I); + I.dropAllReferences(); + Changed = true; + continue; + } + + for (Use &U : I.operands()) { + // DemandedBits only detects dead integer uses. + if (!U->getType()->isIntOrIntVectorTy()) + continue; + + if (!isa<Instruction>(U) && !isa<Argument>(U)) + continue; + + if (!DB.isUseDead(&U)) + continue; + + LLVM_DEBUG(dbgs() << "BDCE: Trivializing: " << U << " (all bits dead)\n"); clearAssumptionsOfUsers(&I, DB); // FIXME: In theory we could substitute undef here instead of zero. // This should be reconsidered once we settle on the semantics of // undef, poison, etc. - Value *Zero = ConstantInt::get(I.getType(), 0); + U.set(ConstantInt::get(U->getType(), 0)); ++NumSimplified; - I.replaceNonMetadataUsesWith(Zero); Changed = true; } - if (!DB.isInstructionDead(&I)) - continue; - - salvageDebugInfo(I); - Worklist.push_back(&I); - I.dropAllReferences(); - Changed = true; } for (Instruction *&I : Worklist) { diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index fce37d4bffb8..e3548ce5cd0a 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -45,6 +45,7 @@ add_llvm_library(LLVMScalarOpts LowerAtomic.cpp LowerExpectIntrinsic.cpp LowerGuardIntrinsic.cpp + MakeGuardsExplicit.cpp MemCpyOptimizer.cpp MergeICmps.cpp MergedLoadStoreMotion.cpp @@ -68,6 +69,7 @@ add_llvm_library(LLVMScalarOpts StraightLineStrengthReduce.cpp StructurizeCFG.cpp TailRecursionElimination.cpp + WarnMissedTransforms.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/Scalar/CallSiteSplitting.cpp b/lib/Transforms/Scalar/CallSiteSplitting.cpp index 5ebfbf8a879b..a806d6faed60 100644 --- a/lib/Transforms/Scalar/CallSiteSplitting.cpp +++ b/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -149,14 +149,14 @@ static void recordCondition(CallSite CS, BasicBlock *From, BasicBlock *To, /// Record ICmp conditions relevant to any argument in CS following Pred's /// single predecessors. If there are conflicting conditions along a path, like -/// x == 1 and x == 0, the first condition will be used. +/// x == 1 and x == 0, the first condition will be used. We stop once we reach +/// an edge to StopAt. static void recordConditions(CallSite CS, BasicBlock *Pred, - ConditionsTy &Conditions) { - recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); + ConditionsTy &Conditions, BasicBlock *StopAt) { BasicBlock *From = Pred; BasicBlock *To = Pred; SmallPtrSet<BasicBlock *, 4> Visited; - while (!Visited.count(From->getSinglePredecessor()) && + while (To != StopAt && !Visited.count(From->getSinglePredecessor()) && (From = From->getSinglePredecessor())) { recordCondition(CS, From, To, Conditions); Visited.insert(From); @@ -197,7 +197,7 @@ static bool canSplitCallSite(CallSite CS, TargetTransformInfo &TTI) { isa<IndirectBrInst>(Preds[1]->getTerminator())) return false; - // BasicBlock::canSplitPredecessors is more agressive, so checking for + // BasicBlock::canSplitPredecessors is more aggressive, so checking for // BasicBlock::isEHPad as well. 
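An illustrative sketch, not from the patch, of the single-predecessor walk the recordConditions() change in this file performs, reduced to a generic graph; StopAt plays the role of the call site's immediate dominator, past which every path to the call carries the same conditions.

#include <unordered_set>

struct BlockSketch {
  BlockSketch *SinglePred = nullptr; // null when there is no unique predecessor
};

// Record something for each edge From -> To along the unique-predecessor
// chain, stopping at StopAt or when the chain would revisit a block.
template <typename RecordFn>
void walkConditions(BlockSketch *From, BlockSketch *StopAt, RecordFn Record) {
  std::unordered_set<BlockSketch *> Visited;
  BlockSketch *To = From;
  while (To != StopAt && From->SinglePred &&
         !Visited.count(From->SinglePred)) {
    From = From->SinglePred;
    Record(From, To);
    Visited.insert(From);
    To = From;
  }
}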
if (!CallSiteBB->canSplitPredecessors() || CallSiteBB->isEHPad()) return false; @@ -248,7 +248,7 @@ static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI, ReturnInst* RI = dyn_cast<ReturnInst>(&*II); assert(RI && "`musttail` call must be followed by `ret` instruction"); - TerminatorInst *TI = SplitBB->getTerminator(); + Instruction *TI = SplitBB->getTerminator(); Value *V = NewCI; if (BCI) V = cloneInstForMustTail(BCI, TI, V); @@ -302,7 +302,7 @@ static void copyMustTailReturn(BasicBlock *SplitBB, Instruction *CI, static void splitCallSite( CallSite CS, const SmallVectorImpl<std::pair<BasicBlock *, ConditionsTy>> &Preds, - DominatorTree *DT) { + DomTreeUpdater &DTU) { Instruction *Instr = CS.getInstruction(); BasicBlock *TailBB = Instr->getParent(); bool IsMustTailCall = CS.isMustTailCall(); @@ -312,8 +312,10 @@ static void splitCallSite( // `musttail` calls must be followed by optional `bitcast`, and `ret`. The // split blocks will be terminated right after that so there're no users for // this phi in a `TailBB`. - if (!IsMustTailCall && !Instr->use_empty()) + if (!IsMustTailCall && !Instr->use_empty()) { CallPN = PHINode::Create(Instr->getType(), Preds.size(), "phi.call"); + CallPN->setDebugLoc(Instr->getDebugLoc()); + } LLVM_DEBUG(dbgs() << "split call-site : " << *Instr << " into \n"); @@ -325,7 +327,7 @@ static void splitCallSite( BasicBlock *PredBB = Preds[i].first; BasicBlock *SplitBlock = DuplicateInstructionsInSplitBetween( TailBB, PredBB, &*std::next(Instr->getIterator()), ValueToValueMaps[i], - DT); + DTU); assert(SplitBlock && "Unexpected new basic block split."); Instruction *NewCI = @@ -363,11 +365,13 @@ static void splitCallSite( // attempting removal. SmallVector<BasicBlock *, 2> Splits(predecessors((TailBB))); assert(Splits.size() == 2 && "Expected exactly 2 splits!"); - for (unsigned i = 0; i < Splits.size(); i++) + for (unsigned i = 0; i < Splits.size(); i++) { Splits[i]->getTerminator()->eraseFromParent(); + DTU.deleteEdge(Splits[i], TailBB); + } // Erase the tail block once done with musttail patching - TailBB->eraseFromParent(); + DTU.deleteBB(TailBB); return; } @@ -394,6 +398,7 @@ static void splitCallSite( if (isa<PHINode>(CurrentI)) continue; PHINode *NewPN = PHINode::Create(CurrentI->getType(), Preds.size()); + NewPN->setDebugLoc(CurrentI->getDebugLoc()); for (auto &Mapping : ValueToValueMaps) NewPN->addIncoming(Mapping[CurrentI], cast<Instruction>(Mapping[CurrentI])->getParent()); @@ -435,49 +440,73 @@ static bool isPredicatedOnPHI(CallSite CS) { return false; } -static bool tryToSplitOnPHIPredicatedArgument(CallSite CS, DominatorTree *DT) { +using PredsWithCondsTy = SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2>; + +// Check if any of the arguments in CS are predicated on a PHI node and return +// the set of predecessors we should use for splitting. +static PredsWithCondsTy shouldSplitOnPHIPredicatedArgument(CallSite CS) { if (!isPredicatedOnPHI(CS)) - return false; + return {}; auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); - SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS = { - {Preds[0], {}}, {Preds[1], {}}}; - splitCallSite(CS, PredsCS, DT); - return true; + return {{Preds[0], {}}, {Preds[1], {}}}; } -static bool tryToSplitOnPredicatedArgument(CallSite CS, DominatorTree *DT) { +// Checks if any of the arguments in CS are predicated in a predecessor and +// returns a list of predecessors with the conditions that hold on their edges +// to CS. 
+static PredsWithCondsTy shouldSplitOnPredicatedArgument(CallSite CS, + DomTreeUpdater &DTU) { auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); if (Preds[0] == Preds[1]) - return false; + return {}; + + // We can stop recording conditions once we reached the immediate dominator + // for the block containing the call site. Conditions in predecessors of the + // that node will be the same for all paths to the call site and splitting + // is not beneficial. + assert(DTU.hasDomTree() && "We need a DTU with a valid DT!"); + auto *CSDTNode = DTU.getDomTree().getNode(CS.getInstruction()->getParent()); + BasicBlock *StopAt = CSDTNode ? CSDTNode->getIDom()->getBlock() : nullptr; SmallVector<std::pair<BasicBlock *, ConditionsTy>, 2> PredsCS; for (auto *Pred : make_range(Preds.rbegin(), Preds.rend())) { ConditionsTy Conditions; - recordConditions(CS, Pred, Conditions); + // Record condition on edge BB(CS) <- Pred + recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); + // Record conditions following Pred's single predecessors. + recordConditions(CS, Pred, Conditions, StopAt); PredsCS.push_back({Pred, Conditions}); } - if (std::all_of(PredsCS.begin(), PredsCS.end(), - [](const std::pair<BasicBlock *, ConditionsTy> &P) { - return P.second.empty(); - })) - return false; + if (all_of(PredsCS, [](const std::pair<BasicBlock *, ConditionsTy> &P) { + return P.second.empty(); + })) + return {}; - splitCallSite(CS, PredsCS, DT); - return true; + return PredsCS; } static bool tryToSplitCallSite(CallSite CS, TargetTransformInfo &TTI, - DominatorTree *DT) { + DomTreeUpdater &DTU) { + // Check if we can split the call site. if (!CS.arg_size() || !canSplitCallSite(CS, TTI)) return false; - return tryToSplitOnPredicatedArgument(CS, DT) || - tryToSplitOnPHIPredicatedArgument(CS, DT); + + auto PredsWithConds = shouldSplitOnPredicatedArgument(CS, DTU); + if (PredsWithConds.empty()) + PredsWithConds = shouldSplitOnPHIPredicatedArgument(CS); + if (PredsWithConds.empty()) + return false; + + splitCallSite(CS, PredsWithConds, DTU); + return true; } static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI, - TargetTransformInfo &TTI, DominatorTree *DT) { + TargetTransformInfo &TTI, DominatorTree &DT) { + + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy); bool Changed = false; for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) { BasicBlock &BB = *BI++; @@ -501,7 +530,7 @@ static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI, // Check if such path is possible before attempting the splitting. bool IsMustTail = CS.isMustTailCall(); - Changed |= tryToSplitCallSite(CS, TTI, DT); + Changed |= tryToSplitCallSite(CS, TTI, DTU); // There're no interesting instructions after this. The call site // itself might have been erased on splitting. @@ -522,6 +551,7 @@ struct CallSiteSplittingLegacyPass : public FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); FunctionPass::getAnalysisUsage(AU); } @@ -532,9 +562,8 @@ struct CallSiteSplittingLegacyPass : public FunctionPass { auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - return doCallSiteSplitting(F, TLI, TTI, - DTWP ? 
&DTWP->getDomTree() : nullptr); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + return doCallSiteSplitting(F, TLI, TTI, DT); } }; } // namespace @@ -544,6 +573,7 @@ INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting", "Call-site splitting", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting", "Call-site splitting", false, false) FunctionPass *llvm::createCallSiteSplittingPass() { @@ -554,7 +584,7 @@ PreservedAnalyses CallSiteSplittingPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &TTI = AM.getResult<TargetIRAnalysis>(F); - auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); if (!doCallSiteSplitting(F, TLI, TTI, DT)) return PreservedAnalyses::all(); diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index 55759e8b1661..beac0d967a98 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -82,6 +82,16 @@ static cl::opt<bool> ConstHoistWithBlockFrequency( "chance to execute const materialization more frequently than " "without hoisting.")); +static cl::opt<bool> ConstHoistGEP( + "consthoist-gep", cl::init(false), cl::Hidden, + cl::desc("Try hoisting constant gep expressions")); + +static cl::opt<unsigned> +MinNumOfDependentToRebase("consthoist-min-num-to-rebase", + cl::desc("Do not rebase if number of dependent constants of a Base is less " + "than this number."), + cl::init(0), cl::Hidden); + namespace { /// The constant hoisting pass. @@ -340,7 +350,7 @@ SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( /// /// The operand at index Idx is not necessarily the constant integer itself. It /// could also be a cast instruction or a constant expression that uses the -// constant integer. +/// constant integer. void ConstantHoistingPass::collectConstantCandidates( ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, ConstantInt *ConstInt) { @@ -358,12 +368,13 @@ void ConstantHoistingPass::collectConstantCandidates( if (Cost > TargetTransformInfo::TCC_Basic) { ConstCandMapType::iterator Itr; bool Inserted; - std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(ConstInt, 0)); + ConstPtrUnionType Cand = ConstInt; + std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(Cand, 0)); if (Inserted) { - ConstCandVec.push_back(ConstantCandidate(ConstInt)); - Itr->second = ConstCandVec.size() - 1; + ConstIntCandVec.push_back(ConstantCandidate(ConstInt)); + Itr->second = ConstIntCandVec.size() - 1; } - ConstCandVec[Itr->second].addUser(Inst, Idx, Cost); + ConstIntCandVec[Itr->second].addUser(Inst, Idx, Cost); LLVM_DEBUG(if (isa<ConstantInt>(Inst->getOperand(Idx))) dbgs() << "Collect constant " << *ConstInt << " from " << *Inst << " with cost " << Cost << '\n'; @@ -374,6 +385,48 @@ void ConstantHoistingPass::collectConstantCandidates( } } +/// Record constant GEP expression for instruction Inst at operand index Idx. 
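A standalone sketch, not part of the patch, of the arithmetic view behind the new -consthoist-gep option and the collectConstantCandidates() overload that follows: a constant GEP over a global is just the global's address plus a fixed byte offset, so candidates sharing a base can be rebased on one materialized address. The 4096-byte global and the offsets are assumptions, kept small in the spirit of the isIntN(32) check.

struct GlobalSketch {
  unsigned char bytes[4096]; // stand-in for a global the folded GEPs index into
};

// Two folded "GEP @g, <constant indices>" uses with different offsets share
// one base; each access is then base + small immediate, which can fold into
// the load/store addressing mode or a single add.
inline unsigned char loadFieldA(GlobalSketch &g) { return *(g.bytes + 8); }
inline unsigned char loadFieldB(GlobalSketch &g) { return *(g.bytes + 24); }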
+void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, + ConstantExpr *ConstExpr) { + // TODO: Handle vector GEPs + if (ConstExpr->getType()->isVectorTy()) + return; + + GlobalVariable *BaseGV = dyn_cast<GlobalVariable>(ConstExpr->getOperand(0)); + if (!BaseGV) + return; + + // Get offset from the base GV. + PointerType *GVPtrTy = dyn_cast<PointerType>(BaseGV->getType()); + IntegerType *PtrIntTy = DL->getIntPtrType(*Ctx, GVPtrTy->getAddressSpace()); + APInt Offset(DL->getTypeSizeInBits(PtrIntTy), /*val*/0, /*isSigned*/true); + auto *GEPO = cast<GEPOperator>(ConstExpr); + if (!GEPO->accumulateConstantOffset(*DL, Offset)) + return; + + if (!Offset.isIntN(32)) + return; + + // A constant GEP expression that has a GlobalVariable as base pointer is + // usually lowered to a load from constant pool. Such operation is unlikely + // to be cheaper than compute it by <Base + Offset>, which can be lowered to + // an ADD instruction or folded into Load/Store instruction. + int Cost = TTI->getIntImmCost(Instruction::Add, 1, Offset, PtrIntTy); + ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV]; + ConstCandMapType::iterator Itr; + bool Inserted; + ConstPtrUnionType Cand = ConstExpr; + std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(Cand, 0)); + if (Inserted) { + ExprCandVec.push_back(ConstantCandidate( + ConstantInt::get(Type::getInt32Ty(*Ctx), Offset.getLimitedValue()), + ConstExpr)); + Itr->second = ExprCandVec.size() - 1; + } + ExprCandVec[Itr->second].addUser(Inst, Idx, Cost); +} + /// Check the operand for instruction Inst at index Idx. void ConstantHoistingPass::collectConstantCandidates( ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) { @@ -402,6 +455,10 @@ void ConstantHoistingPass::collectConstantCandidates( // Visit constant expressions that have constant integers. if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) { + // Handle constant gep expressions. + if (ConstHoistGEP && ConstExpr->isGEPWithNoNotionalOverIndexing()) + collectConstantCandidates(ConstCandMap, Inst, Idx, ConstExpr); + // Only visit constant cast expressions. if (!ConstExpr->isCast()) return; @@ -544,7 +601,8 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, /// Find the base constant within the given range and rebase all other /// constants with respect to the base constant. void ConstantHoistingPass::findAndMakeBaseConstant( - ConstCandVecType::iterator S, ConstCandVecType::iterator E) { + ConstCandVecType::iterator S, ConstCandVecType::iterator E, + SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec) { auto MaxCostItr = S; unsigned NumUses = maximizeConstantsInRange(S, E, MaxCostItr); @@ -552,26 +610,37 @@ void ConstantHoistingPass::findAndMakeBaseConstant( if (NumUses <= 1) return; + ConstantInt *ConstInt = MaxCostItr->ConstInt; + ConstantExpr *ConstExpr = MaxCostItr->ConstExpr; ConstantInfo ConstInfo; - ConstInfo.BaseConstant = MaxCostItr->ConstInt; - Type *Ty = ConstInfo.BaseConstant->getType(); + ConstInfo.BaseInt = ConstInt; + ConstInfo.BaseExpr = ConstExpr; + Type *Ty = ConstInt->getType(); // Rebase the constants with respect to the base constant. for (auto ConstCand = S; ConstCand != E; ++ConstCand) { - APInt Diff = ConstCand->ConstInt->getValue() - - ConstInfo.BaseConstant->getValue(); + APInt Diff = ConstCand->ConstInt->getValue() - ConstInt->getValue(); Constant *Offset = Diff == 0 ? nullptr : ConstantInt::get(Ty, Diff); + Type *ConstTy = + ConstCand->ConstExpr ? 
ConstCand->ConstExpr->getType() : nullptr; ConstInfo.RebasedConstants.push_back( - RebasedConstantInfo(std::move(ConstCand->Uses), Offset)); + RebasedConstantInfo(std::move(ConstCand->Uses), Offset, ConstTy)); } - ConstantVec.push_back(std::move(ConstInfo)); + ConstInfoVec.push_back(std::move(ConstInfo)); } /// Finds and combines constant candidates that can be easily /// rematerialized with an add from a common base constant. -void ConstantHoistingPass::findBaseConstants() { +void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) { + // If BaseGV is nullptr, find base among candidate constant integers; + // Otherwise find base among constant GEPs that share the same BaseGV. + ConstCandVecType &ConstCandVec = BaseGV ? + ConstGEPCandMap[BaseGV] : ConstIntCandVec; + ConstInfoVecType &ConstInfoVec = BaseGV ? + ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; + // Sort the constants by value and type. This invalidates the mapping! - llvm::sort(ConstCandVec.begin(), ConstCandVec.end(), + std::stable_sort(ConstCandVec.begin(), ConstCandVec.end(), [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) { if (LHS.ConstInt->getType() != RHS.ConstInt->getType()) return LHS.ConstInt->getType()->getBitWidth() < @@ -585,20 +654,40 @@ void ConstantHoistingPass::findBaseConstants() { for (auto CC = std::next(ConstCandVec.begin()), E = ConstCandVec.end(); CC != E; ++CC) { if (MinValItr->ConstInt->getType() == CC->ConstInt->getType()) { + Type *MemUseValTy = nullptr; + for (auto &U : CC->Uses) { + auto *UI = U.Inst; + if (LoadInst *LI = dyn_cast<LoadInst>(UI)) { + MemUseValTy = LI->getType(); + break; + } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) { + // Make sure the constant is used as pointer operand of the StoreInst. + if (SI->getPointerOperand() == SI->getOperand(U.OpndIdx)) { + MemUseValTy = SI->getValueOperand()->getType(); + break; + } + } + } + // Check if the constant is in range of an add with immediate. APInt Diff = CC->ConstInt->getValue() - MinValItr->ConstInt->getValue(); if ((Diff.getBitWidth() <= 64) && - TTI->isLegalAddImmediate(Diff.getSExtValue())) + TTI->isLegalAddImmediate(Diff.getSExtValue()) && + // Check if Diff can be used as offset in addressing mode of the user + // memory instruction. + (!MemUseValTy || TTI->isLegalAddressingMode(MemUseValTy, + /*BaseGV*/nullptr, /*BaseOffset*/Diff.getSExtValue(), + /*HasBaseReg*/true, /*Scale*/0))) continue; } // We either have now a different constant type or the constant is not in // range of an add with immediate anymore. - findAndMakeBaseConstant(MinValItr, CC); + findAndMakeBaseConstant(MinValItr, CC, ConstInfoVec); // Start a new base constant search. MinValItr = CC; } // Finalize the last base constant search. - findAndMakeBaseConstant(MinValItr, ConstCandVec.end()); + findAndMakeBaseConstant(MinValItr, ConstCandVec.end(), ConstInfoVec); } /// Updates the operand at Idx in instruction Inst with the result of @@ -633,12 +722,28 @@ static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) { /// users. void ConstantHoistingPass::emitBaseConstants(Instruction *Base, Constant *Offset, + Type *Ty, const ConstantUser &ConstUser) { Instruction *Mat = Base; + + // The same offset can be dereferenced to different types in nested struct. 
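As an illustration (a sketch, not code from this patch): for a rebased GEP constant, the bitcast-to-i8*, byte-offset GEP, bitcast-back sequence built here is the IR counterpart of ordinary byte-offset pointer arithmetic in C++. The types and field layout below are hypothetical.

#include <cstddef>
#include <cstdint>

struct Inner { int32_t A; int32_t B; };
struct Outer { Inner X; Inner Y; };   // hypothetical layout of the global

Outer G;                              // stands in for the base GlobalVariable

// Rematerialize a derived address as Base + ByteOff instead of loading a
// separate address constant: cast to a byte pointer, add the offset, cast back.
int32_t *rebasedAddr(std::size_t ByteOff) {
  char *Base = reinterpret_cast<char *>(&G);            // "bitcast" to i8*
  return reinterpret_cast<int32_t *>(Base + ByteOff);   // "gep" + "bitcast" back
}

// For example, the address of G.Y.B corresponds to
// ByteOff == offsetof(Outer, Y) + offsetof(Inner, B).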
+ if (!Offset && Ty && Ty != Base->getType()) + Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0); + if (Offset) { Instruction *InsertionPt = findMatInsertPt(ConstUser.Inst, ConstUser.OpndIdx); - Mat = BinaryOperator::Create(Instruction::Add, Base, Offset, + if (Ty) { + // Constant being rebased is a ConstantExpr. + PointerType *Int8PtrTy = Type::getInt8PtrTy(*Ctx, + cast<PointerType>(Ty)->getAddressSpace()); + Base = new BitCastInst(Base, Int8PtrTy, "base_bitcast", InsertionPt); + Mat = GetElementPtrInst::Create(Int8PtrTy->getElementType(), Base, + Offset, "mat_gep", InsertionPt); + Mat = new BitCastInst(Mat, Ty, "mat_bitcast", InsertionPt); + } else + // Constant being rebased is a ConstantInt. + Mat = BinaryOperator::Create(Instruction::Add, Base, Offset, "const_mat", InsertionPt); LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0) @@ -682,6 +787,14 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, // Visit constant expression. if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) { + if (ConstExpr->isGEPWithNoNotionalOverIndexing()) { + // Operand is a ConstantGEP, replace it. + updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat); + return; + } + + // Aside from constant GEPs, only constant cast expressions are collected. + assert(ConstExpr->isCast() && "ConstExpr should be a cast"); Instruction *ConstExprInst = ConstExpr->getAsInstruction(); ConstExprInst->setOperand(0, Mat); ConstExprInst->insertBefore(findMatInsertPt(ConstUser.Inst, @@ -705,28 +818,22 @@ void ConstantHoistingPass::emitBaseConstants(Instruction *Base, /// Hoist and hide the base constant behind a bitcast and emit /// materialization code for derived constants. -bool ConstantHoistingPass::emitBaseConstants() { +bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) { bool MadeChange = false; - for (auto const &ConstInfo : ConstantVec) { - // Hoist and hide the base constant behind a bitcast. + SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec = + BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec; + for (auto const &ConstInfo : ConstInfoVec) { SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo); assert(!IPSet.empty() && "IPSet is empty"); unsigned UsesNum = 0; unsigned ReBasesNum = 0; + unsigned NotRebasedNum = 0; for (Instruction *IP : IPSet) { - IntegerType *Ty = ConstInfo.BaseConstant->getType(); - Instruction *Base = - new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP); - - Base->setDebugLoc(IP->getDebugLoc()); - - LLVM_DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant - << ") to BB " << IP->getParent()->getName() << '\n' - << *Base << '\n'); - - // Emit materialization code for all rebased constants. + // First, collect constants depending on this IP of the base. unsigned Uses = 0; + using RebasedUse = std::tuple<Constant *, Type *, ConstantUser>; + SmallVector<RebasedUse, 4> ToBeRebased; for (auto const &RCI : ConstInfo.RebasedConstants) { for (auto const &U : RCI.Uses) { Uses++; @@ -735,31 +842,64 @@ bool ConstantHoistingPass::emitBaseConstants() { // If Base constant is to be inserted in multiple places, // generate rebase for U using the Base dominating U. 
if (IPSet.size() == 1 || - DT->dominates(Base->getParent(), OrigMatInsertBB)) { - emitBaseConstants(Base, RCI.Offset, U); - ReBasesNum++; - } - - Base->setDebugLoc(DILocation::getMergedLocation(Base->getDebugLoc(), U.Inst->getDebugLoc())); + DT->dominates(IP->getParent(), OrigMatInsertBB)) + ToBeRebased.push_back(RebasedUse(RCI.Offset, RCI.Ty, U)); } } UsesNum = Uses; - // Use the same debug location as the last user of the constant. + // If only few constants depend on this IP of base, skip rebasing, + // assuming the base and the rebased have the same materialization cost. + if (ToBeRebased.size() < MinNumOfDependentToRebase) { + NotRebasedNum += ToBeRebased.size(); + continue; + } + + // Emit an instance of the base at this IP. + Instruction *Base = nullptr; + // Hoist and hide the base constant behind a bitcast. + if (ConstInfo.BaseExpr) { + assert(BaseGV && "A base constant expression must have an base GV"); + Type *Ty = ConstInfo.BaseExpr->getType(); + Base = new BitCastInst(ConstInfo.BaseExpr, Ty, "const", IP); + } else { + IntegerType *Ty = ConstInfo.BaseInt->getType(); + Base = new BitCastInst(ConstInfo.BaseInt, Ty, "const", IP); + } + + Base->setDebugLoc(IP->getDebugLoc()); + + LLVM_DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseInt + << ") to BB " << IP->getParent()->getName() << '\n' + << *Base << '\n'); + + // Emit materialization code for rebased constants depending on this IP. + for (auto const &R : ToBeRebased) { + Constant *Off = std::get<0>(R); + Type *Ty = std::get<1>(R); + ConstantUser U = std::get<2>(R); + emitBaseConstants(Base, Off, Ty, U); + ReBasesNum++; + // Use the same debug location as the last user of the constant. + Base->setDebugLoc(DILocation::getMergedLocation( + Base->getDebugLoc(), U.Inst->getDebugLoc())); + } assert(!Base->use_empty() && "The use list is empty!?"); assert(isa<Instruction>(Base->user_back()) && "All uses should be instructions."); } (void)UsesNum; (void)ReBasesNum; + (void)NotRebasedNum; // Expect all uses are rebased after rebase is done. - assert(UsesNum == ReBasesNum && "Not all uses are rebased"); + assert(UsesNum == (ReBasesNum + NotRebasedNum) && + "Not all uses are rebased"); NumConstantsHoisted++; // Base constant is also included in ConstInfo.RebasedConstants, so // deduct 1 from ConstInfo.RebasedConstants.size(). - NumConstantsRebased = ConstInfo.RebasedConstants.size() - 1; + NumConstantsRebased += ConstInfo.RebasedConstants.size() - 1; MadeChange = true; } @@ -781,25 +921,29 @@ bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, this->TTI = &TTI; this->DT = &DT; this->BFI = BFI; + this->DL = &Fn.getParent()->getDataLayout(); + this->Ctx = &Fn.getContext(); this->Entry = &Entry; // Collect all constant candidates. collectConstantCandidates(Fn); - // There are no constant candidates to worry about. - if (ConstCandVec.empty()) - return false; - // Combine constants that can be easily materialized with an add from a common // base constant. - findBaseConstants(); - - // There are no constants to emit. - if (ConstantVec.empty()) - return false; + if (!ConstIntCandVec.empty()) + findBaseConstants(nullptr); + for (auto &MapEntry : ConstGEPCandMap) + if (!MapEntry.second.empty()) + findBaseConstants(MapEntry.first); // Finally hoist the base constant and emit materialization code for dependent // constants. 
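As a rough sketch of the grouping performed here (plain std:: containers, no TTI cost model, and the base is simply the first element of each group rather than the maximum-cost candidate): sorted candidates are split into runs whose spread fits an add-immediate, and every member is recorded as base plus a small delta.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Rebased { int64_t Base; int64_t Delta; };   // value == Base + Delta

std::vector<Rebased> rebase(std::vector<int64_t> Vals, int64_t MaxImm) {
  std::vector<Rebased> Out;
  std::sort(Vals.begin(), Vals.end());
  for (std::size_t I = 0, GroupStart = 0; I < Vals.size(); ++I) {
    // Start a new base once the distance no longer fits the immediate range.
    if (Vals[I] - Vals[GroupStart] > MaxImm)
      GroupStart = I;
    Out.push_back({Vals[GroupStart], Vals[I] - Vals[GroupStart]});
  }
  return Out;
}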
- bool MadeChange = emitBaseConstants(); + bool MadeChange = false; + if (!ConstIntInfoVec.empty()) + MadeChange = emitBaseConstants(nullptr); + for (auto MapEntry : ConstGEPInfoMap) + if (!MapEntry.second.empty()) + MadeChange |= emitBaseConstants(MapEntry.first); + // Cleanup dead instructions. deleteDeadCastInst(); diff --git a/lib/Transforms/Scalar/ConstantProp.cpp b/lib/Transforms/Scalar/ConstantProp.cpp index 46915889ce7c..51032b0625f8 100644 --- a/lib/Transforms/Scalar/ConstantProp.cpp +++ b/lib/Transforms/Scalar/ConstantProp.cpp @@ -18,21 +18,25 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Constant.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" -#include <set> +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "constprop" STATISTIC(NumInstKilled, "Number of instructions killed"); +DEBUG_COUNTER(CPCounter, "constprop-transform", + "Controls which instructions are killed"); namespace { struct ConstantPropagation : public FunctionPass { @@ -66,9 +70,15 @@ bool ConstantPropagation::runOnFunction(Function &F) { return false; // Initialize the worklist to all of the instructions ready to process... - std::set<Instruction*> WorkList; - for (Instruction &I: instructions(&F)) + SmallPtrSet<Instruction *, 16> WorkList; + // The SmallVector of WorkList ensures that we do iteration at stable order. + // We use two containers rather than one SetVector, since remove is + // linear-time, and we don't care enough to remove from Vec. + SmallVector<Instruction *, 16> WorkListVec; + for (Instruction &I : instructions(&F)) { WorkList.insert(&I); + WorkListVec.push_back(&I); + } bool Changed = false; const DataLayout &DL = F.getParent()->getDataLayout(); @@ -76,29 +86,36 @@ bool ConstantPropagation::runOnFunction(Function &F) { &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); while (!WorkList.empty()) { - Instruction *I = *WorkList.begin(); - WorkList.erase(WorkList.begin()); // Get an element from the worklist... - - if (!I->use_empty()) // Don't muck with dead instructions... - if (Constant *C = ConstantFoldInstruction(I, DL, TLI)) { - // Add all of the users of this instruction to the worklist, they might - // be constant propagatable now... - for (User *U : I->users()) - WorkList.insert(cast<Instruction>(U)); - - // Replace all of the uses of a variable with uses of the constant. - I->replaceAllUsesWith(C); - - // Remove the dead instruction. - WorkList.erase(I); - if (isInstructionTriviallyDead(I, TLI)) { - I->eraseFromParent(); - ++NumInstKilled; + SmallVector<Instruction*, 16> NewWorkListVec; + for (auto *I : WorkListVec) { + WorkList.erase(I); // Remove element from the worklist... + + if (!I->use_empty()) // Don't muck with dead instructions... + if (Constant *C = ConstantFoldInstruction(I, DL, TLI)) { + if (!DebugCounter::shouldExecute(CPCounter)) + continue; + + // Add all of the users of this instruction to the worklist, they might + // be constant propagatable now... + for (User *U : I->users()) { + // If user not in the set, then add it to the vector. 
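The worklist change above is a generic pattern: a hash set answers "is it already queued?" while a vector fixes the visiting order, so the results no longer depend on pointer values. A minimal generic sketch with standard containers; Visit receives the item plus an enqueue callback.

#include <unordered_set>
#include <utility>
#include <vector>

template <typename T, typename VisitFn>
void drainWorklist(std::vector<T> Vec, VisitFn Visit) {
  std::unordered_set<T> Queued(Vec.begin(), Vec.end());
  while (!Vec.empty()) {
    std::vector<T> Next;
    for (const T &Item : Vec) {
      Queued.erase(Item);
      // Visit may enqueue follow-up work; queue each item at most once.
      Visit(Item, [&](const T &Succ) {
        if (Queued.insert(Succ).second)
          Next.push_back(Succ);
      });
    }
    Vec = std::move(Next);
  }
}

Instantiated with instruction pointers, this reproduces the WorkList/WorkListVec pairing in the hunk above.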
+ if (WorkList.insert(cast<Instruction>(U)).second) + NewWorkListVec.push_back(cast<Instruction>(U)); + } + + // Replace all of the uses of a variable with uses of the constant. + I->replaceAllUsesWith(C); + + if (isInstructionTriviallyDead(I, TLI)) { + I->eraseFromParent(); + ++NumInstKilled; + } + + // We made a change to the function... + Changed = true; } - - // We made a change to the function... - Changed = true; - } + } + WorkListVec = std::move(NewWorkListVec); } return Changed; } diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 2f2d7f620a29..d0105701c73f 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -19,7 +19,6 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -28,6 +27,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" @@ -44,6 +44,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <utility> @@ -272,10 +273,11 @@ static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) { /// information is sufficient to prove this comparison. Even for local /// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. -static bool processCmp(CmpInst *C, LazyValueInfo *LVI) { - Value *Op0 = C->getOperand(0); - Constant *Op1 = dyn_cast<Constant>(C->getOperand(1)); - if (!Op1) return false; +static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) { + Value *Op0 = Cmp->getOperand(0); + auto *C = dyn_cast<Constant>(Cmp->getOperand(1)); + if (!C) + return false; // As a policy choice, we choose not to waste compile time on anything where // the comparison is testing local values. While LVI can sometimes reason @@ -283,20 +285,18 @@ static bool processCmp(CmpInst *C, LazyValueInfo *LVI) { // the block local query for uses from terminator instructions, but that's // handled in the code for each terminator. auto *I = dyn_cast<Instruction>(Op0); - if (I && I->getParent() == C->getParent()) + if (I && I->getParent() == Cmp->getParent()) return false; LazyValueInfo::Tristate Result = - LVI->getPredicateAt(C->getPredicate(), Op0, Op1, C); - if (Result == LazyValueInfo::Unknown) return false; + LVI->getPredicateAt(Cmp->getPredicate(), Op0, C, Cmp); + if (Result == LazyValueInfo::Unknown) + return false; ++NumCmps; - if (Result == LazyValueInfo::True) - C->replaceAllUsesWith(ConstantInt::getTrue(C->getContext())); - else - C->replaceAllUsesWith(ConstantInt::getFalse(C->getContext())); - C->eraseFromParent(); - + Constant *TorF = ConstantInt::get(Type::getInt1Ty(Cmp->getContext()), Result); + Cmp->replaceAllUsesWith(TorF); + Cmp->eraseFromParent(); return true; } @@ -307,7 +307,9 @@ static bool processCmp(CmpInst *C, LazyValueInfo *LVI) { /// that cannot fire no matter what the incoming edge can safely be removed. If /// a case fires on every incoming edge then the entire switch can be removed /// and replaced with a branch to the case destination. 
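One detail of the processCmp rewrite above worth spelling out: building the replacement with ConstantInt::get(Int1Ty, Result) relies on the Tristate values False and True being encoded as 0 and 1 (the Unknown case has already returned). A plain C++ restatement of that three-valued fold, with the encoding written out as an explicit assumption:

#include <optional>

enum Tristate { Unknown = -1, False = 0, True = 1 };  // assumed encoding

// Fold a comparison only when the lattice answer is known.
std::optional<bool> foldCmp(Tristate Result) {
  if (Result == Unknown)
    return std::nullopt;   // keep the original icmp
  return Result == True;   // False -> false, True -> true
}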
-static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, DominatorTree *DT) { +static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, + DominatorTree *DT) { + DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy); Value *Cond = SI->getCondition(); BasicBlock *BB = SI->getParent(); @@ -372,7 +374,7 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, DominatorTree *DT) ++NumDeadCases; Changed = true; if (--SuccessorsCount[Succ] == 0) - DT->deleteEdge(BB, Succ); + DTU.deleteEdge(BB, Succ); continue; } if (State == LazyValueInfo::True) { @@ -389,15 +391,11 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI, DominatorTree *DT) ++CI; } - if (Changed) { + if (Changed) // If the switch has been simplified to the point where it can be replaced // by a branch then do so now. - DeferredDominance DDT(*DT); ConstantFoldTerminator(BB, /*DeleteDeadConditions = */ false, - /*TLI = */ nullptr, &DDT); - DDT.flush(); - } - + /*TLI = */ nullptr, &DTU); return Changed; } @@ -432,23 +430,21 @@ static bool willNotOverflow(IntrinsicInst *II, LazyValueInfo *LVI) { } static void processOverflowIntrinsic(IntrinsicInst *II) { + IRBuilder<> B(II); Value *NewOp = nullptr; switch (II->getIntrinsicID()) { default: llvm_unreachable("Unexpected instruction."); case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: - NewOp = BinaryOperator::CreateAdd(II->getOperand(0), II->getOperand(1), - II->getName(), II); + NewOp = B.CreateAdd(II->getOperand(0), II->getOperand(1), II->getName()); break; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: - NewOp = BinaryOperator::CreateSub(II->getOperand(0), II->getOperand(1), - II->getName(), II); + NewOp = B.CreateSub(II->getOperand(0), II->getOperand(1), II->getName()); break; } ++NumOverflows; - IRBuilder<> B(II); Value *NewI = B.CreateInsertValue(UndefValue::get(II->getType()), NewOp, 0); NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(II->getContext()), 1); II->replaceAllUsesWith(NewI); @@ -530,17 +526,17 @@ static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) { return false; ++NumUDivs; + IRBuilder<> B{Instr}; auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth); - auto *LHS = CastInst::Create(Instruction::Trunc, Instr->getOperand(0), TruncTy, - Instr->getName() + ".lhs.trunc", Instr); - auto *RHS = CastInst::Create(Instruction::Trunc, Instr->getOperand(1), TruncTy, - Instr->getName() + ".rhs.trunc", Instr); - auto *BO = - BinaryOperator::Create(Instr->getOpcode(), LHS, RHS, Instr->getName(), Instr); - auto *Zext = CastInst::Create(Instruction::ZExt, BO, Instr->getType(), - Instr->getName() + ".zext", Instr); - if (BO->getOpcode() == Instruction::UDiv) - BO->setIsExact(Instr->isExact()); + auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy, + Instr->getName() + ".lhs.trunc"); + auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy, + Instr->getName() + ".rhs.trunc"); + auto *BO = B.CreateBinOp(Instr->getOpcode(), LHS, RHS, Instr->getName()); + auto *Zext = B.CreateZExt(BO, Instr->getType(), Instr->getName() + ".zext"); + if (auto *BinOp = dyn_cast<BinaryOperator>(BO)) + if (BinOp->getOpcode() == Instruction::UDiv) + BinOp->setIsExact(Instr->isExact()); Instr->replaceAllUsesWith(Zext); Instr->eraseFromParent(); @@ -554,6 +550,7 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { ++NumSRems; auto *BO = BinaryOperator::CreateURem(SDI->getOperand(0), SDI->getOperand(1), SDI->getName(), SDI); + 
BO->setDebugLoc(SDI->getDebugLoc()); SDI->replaceAllUsesWith(BO); SDI->eraseFromParent(); @@ -575,6 +572,7 @@ static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { ++NumSDivs; auto *BO = BinaryOperator::CreateUDiv(SDI->getOperand(0), SDI->getOperand(1), SDI->getName(), SDI); + BO->setDebugLoc(SDI->getDebugLoc()); BO->setIsExact(SDI->isExact()); SDI->replaceAllUsesWith(BO); SDI->eraseFromParent(); @@ -597,6 +595,7 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { ++NumAShrs; auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1), SDI->getName(), SDI); + BO->setDebugLoc(SDI->getDebugLoc()); BO->setIsExact(SDI->isExact()); SDI->replaceAllUsesWith(BO); SDI->eraseFromParent(); diff --git a/lib/Transforms/Scalar/DCE.cpp b/lib/Transforms/Scalar/DCE.cpp index 6078967a0f94..4c964e6e888c 100644 --- a/lib/Transforms/Scalar/DCE.cpp +++ b/lib/Transforms/Scalar/DCE.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/Pass.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; @@ -31,6 +32,8 @@ using namespace llvm; STATISTIC(DIEEliminated, "Number of insts removed by DIE pass"); STATISTIC(DCEEliminated, "Number of insts removed"); +DEBUG_COUNTER(DCECounter, "dce-transform", + "Controls which instructions are eliminated"); namespace { //===--------------------------------------------------------------------===// @@ -50,6 +53,8 @@ namespace { for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) { Instruction *Inst = &*DI++; if (isInstructionTriviallyDead(Inst, TLI)) { + if (!DebugCounter::shouldExecute(DCECounter)) + continue; salvageDebugInfo(*Inst); Inst->eraseFromParent(); Changed = true; @@ -77,6 +82,9 @@ static bool DCEInstruction(Instruction *I, SmallSetVector<Instruction *, 16> &WorkList, const TargetLibraryInfo *TLI) { if (isInstructionTriviallyDead(I, TLI)) { + if (!DebugCounter::shouldExecute(DCECounter)) + return false; + salvageDebugInfo(*I); // Null out all of the instruction's operands to see if any operand becomes diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 9a7405e98e7d..469930ca6a19 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -71,7 +71,7 @@ using namespace llvm; STATISTIC(NumRedundantStores, "Number of redundant stores deleted"); STATISTIC(NumFastStores, "Number of stores deleted"); -STATISTIC(NumFastOther , "Number of other instrs removed"); +STATISTIC(NumFastOther, "Number of other instrs removed"); STATISTIC(NumCompletePartials, "Number of stores dead by later partials"); STATISTIC(NumModifiedStores, "Number of stores modified"); @@ -349,11 +349,14 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, InstOverlapIntervalsTy &IOL, AliasAnalysis &AA, const Function *F) { - // If we don't know the sizes of either access, then we can't do a comparison. - if (Later.Size == MemoryLocation::UnknownSize || - Earlier.Size == MemoryLocation::UnknownSize) + // FIXME: Vet that this works for size upper-bounds. Seems unlikely that we'll + // get imprecise values here, though (except for unknown sizes). 
+ if (!Later.Size.isPrecise() || !Earlier.Size.isPrecise()) return OW_Unknown; + const uint64_t LaterSize = Later.Size.getValue(); + const uint64_t EarlierSize = Earlier.Size.getValue(); + const Value *P1 = Earlier.Ptr->stripPointerCasts(); const Value *P2 = Later.Ptr->stripPointerCasts(); @@ -361,7 +364,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // the later store was larger than the earlier store. if (P1 == P2 || AA.isMustAlias(P1, P2)) { // Make sure that the Later size is >= the Earlier size. - if (Later.Size >= Earlier.Size) + if (LaterSize >= EarlierSize) return OW_Complete; } @@ -379,7 +382,7 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // If the "Later" store is to a recognizable object, get its size. uint64_t ObjectSize = getPointerSize(UO2, DL, TLI, F); if (ObjectSize != MemoryLocation::UnknownSize) - if (ObjectSize == Later.Size && ObjectSize >= Earlier.Size) + if (ObjectSize == LaterSize && ObjectSize >= EarlierSize) return OW_Complete; // Okay, we have stores to two completely different pointers. Try to @@ -410,8 +413,8 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // // We have to be careful here as *Off is signed while *.Size is unsigned. if (EarlierOff >= LaterOff && - Later.Size >= Earlier.Size && - uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size) + LaterSize >= EarlierSize && + uint64_t(EarlierOff - LaterOff) + EarlierSize <= LaterSize) return OW_Complete; // We may now overlap, although the overlap is not complete. There might also @@ -420,21 +423,21 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // Note: The correctness of this logic depends on the fact that this function // is not even called providing DepWrite when there are any intervening reads. if (EnablePartialOverwriteTracking && - LaterOff < int64_t(EarlierOff + Earlier.Size) && - int64_t(LaterOff + Later.Size) >= EarlierOff) { + LaterOff < int64_t(EarlierOff + EarlierSize) && + int64_t(LaterOff + LaterSize) >= EarlierOff) { // Insert our part of the overlap into the map. auto &IM = IOL[DepWrite]; LLVM_DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff - << ", " << int64_t(EarlierOff + Earlier.Size) + << ", " << int64_t(EarlierOff + EarlierSize) << ") Later [" << LaterOff << ", " - << int64_t(LaterOff + Later.Size) << ")\n"); + << int64_t(LaterOff + LaterSize) << ")\n"); // Make sure that we only insert non-overlapping intervals and combine // adjacent intervals. The intervals are stored in the map with the ending // offset as the key (in the half-open sense) and the starting offset as // the value. - int64_t LaterIntStart = LaterOff, LaterIntEnd = LaterOff + Later.Size; + int64_t LaterIntStart = LaterOff, LaterIntEnd = LaterOff + LaterSize; // Find any intervals ending at, or after, LaterIntStart which start // before LaterIntEnd. 
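The interval bookkeeping described above is compact enough to sketch separately. Assuming the same layout, a map keyed by half-open end offset with the start offset as the value, the two operations are "insert and merge" and "does one surviving interval cover the earlier store" (details simplified relative to the pass):

#include <algorithm>
#include <cstdint>
#include <map>

using OverlapIntervals = std::map<int64_t, int64_t>;  // end offset -> start offset

void addInterval(OverlapIntervals &IM, int64_t Start, int64_t End) {
  // Absorb every existing interval that overlaps or touches [Start, End).
  auto It = IM.lower_bound(Start);        // first interval ending at/after Start
  while (It != IM.end() && It->second <= End) {
    Start = std::min(Start, It->second);
    End = std::max(End, It->first);
    It = IM.erase(It);
  }
  IM[End] = Start;
}

bool coversEarlierStore(const OverlapIntervals &IM, int64_t EarlierOff,
                        int64_t EarlierSize) {
  if (IM.size() != 1)
    return false;
  const auto &I = *IM.begin();
  return I.second <= EarlierOff && I.first >= EarlierOff + EarlierSize;
}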
@@ -464,10 +467,10 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, ILI = IM.begin(); if (ILI->second <= EarlierOff && - ILI->first >= int64_t(EarlierOff + Earlier.Size)) { + ILI->first >= int64_t(EarlierOff + EarlierSize)) { LLVM_DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" << EarlierOff << ", " - << int64_t(EarlierOff + Earlier.Size) + << int64_t(EarlierOff + EarlierSize) << ") Composite Later [" << ILI->second << ", " << ILI->first << ")\n"); ++NumCompletePartials; @@ -478,13 +481,13 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // Check for an earlier store which writes to all the memory locations that // the later store writes to. if (EnablePartialStoreMerging && LaterOff >= EarlierOff && - int64_t(EarlierOff + Earlier.Size) > LaterOff && - uint64_t(LaterOff - EarlierOff) + Later.Size <= Earlier.Size) { + int64_t(EarlierOff + EarlierSize) > LaterOff && + uint64_t(LaterOff - EarlierOff) + LaterSize <= EarlierSize) { LLVM_DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" << EarlierOff << ", " - << int64_t(EarlierOff + Earlier.Size) + << int64_t(EarlierOff + EarlierSize) << ") by a later store [" << LaterOff << ", " - << int64_t(LaterOff + Later.Size) << ")\n"); + << int64_t(LaterOff + LaterSize) << ")\n"); // TODO: Maybe come up with a better name? return OW_PartialEarlierWithFullLater; } @@ -498,8 +501,8 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // In this case we may want to trim the size of earlier to avoid generating // writes to addresses which will definitely be overwritten later if (!EnablePartialOverwriteTracking && - (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + Earlier.Size) && - int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size))) + (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + EarlierSize) && + int64_t(LaterOff + LaterSize) >= int64_t(EarlierOff + EarlierSize))) return OW_End; // Finally, we also need to check if the later store overwrites the beginning @@ -512,9 +515,8 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, // of earlier to avoid generating writes to addresses which will definitely // be overwritten later. if (!EnablePartialOverwriteTracking && - (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff)) { - assert(int64_t(LaterOff + Later.Size) < - int64_t(EarlierOff + Earlier.Size) && + (LaterOff <= EarlierOff && int64_t(LaterOff + LaterSize) > EarlierOff)) { + assert(int64_t(LaterOff + LaterSize) < int64_t(EarlierOff + EarlierSize) && "Expect to be handled as OW_Complete"); return OW_Begin; } @@ -641,7 +643,7 @@ static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks, for (pred_iterator I = pred_begin(BB), E = pred_end(BB); I != E; ++I) { BasicBlock *Pred = *I; if (Pred == BB) continue; - TerminatorInst *PredTI = Pred->getTerminator(); + Instruction *PredTI = Pred->getTerminator(); if (PredTI->getNumSuccessors() != 1) continue; @@ -832,7 +834,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, continue; } - if (auto CS = CallSite(&*BBI)) { + if (auto *Call = dyn_cast<CallBase>(&*BBI)) { // Remove allocation function calls from the list of dead stack objects; // there can't be any references before the definition. if (isAllocLikeFn(&*BBI, TLI)) @@ -840,15 +842,15 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // If this call does not access memory, it can't be loading any of our // pointers. 
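With the sizes now taken from LocationSize::getValue(), the classification in this function is plain offset arithmetic. A condensed sketch of the three interesting outcomes, mirroring the conditions above (the real function has additional alias-analysis preconditions); the worked case in the comment is OW_Complete.

#include <cstdint>

enum Overwrite { OW_Complete, OW_End, OW_Begin, OW_Other };

// Earlier store covers [EOff, EOff + ESize), later store covers [LOff, LOff + LSize).
// Example: EOff = 4, ESize = 8, LOff = 0, LSize = 16  ->  OW_Complete.
Overwrite classify(int64_t EOff, uint64_t ESize, int64_t LOff, uint64_t LSize) {
  if (EOff >= LOff && LSize >= ESize &&
      uint64_t(EOff - LOff) + ESize <= LSize)
    return OW_Complete;                   // later store covers the earlier one
  if (LOff > EOff && LOff < int64_t(EOff + ESize) &&
      int64_t(LOff + LSize) >= int64_t(EOff + ESize))
    return OW_End;                        // later store clobbers the tail
  if (LOff <= EOff && int64_t(LOff + LSize) > EOff)
    return OW_Begin;                      // later store clobbers the head
  return OW_Other;
}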
- if (AA->doesNotAccessMemory(CS)) + if (AA->doesNotAccessMemory(Call)) continue; // If the call might load from any of our allocas, then any store above // the call is live. DeadStackObjects.remove_if([&](Value *I) { // See if the call site touches the value. - return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI, - BB.getParent()))); + return isRefSet(AA->getModRefInfo( + Call, I, getPointerSize(I, DL, *TLI, BB.getParent()))); }); // If all of the allocas were clobbered by the call then we're not going @@ -1002,11 +1004,10 @@ static bool removePartiallyOverlappedStores(AliasAnalysis *AA, Instruction *EarlierWrite = OI.first; MemoryLocation Loc = getLocForWrite(EarlierWrite); assert(isRemovable(EarlierWrite) && "Expect only removable instruction"); - assert(Loc.Size != MemoryLocation::UnknownSize && "Unexpected mem loc"); const Value *Ptr = Loc.Ptr->stripPointerCasts(); int64_t EarlierStart = 0; - int64_t EarlierSize = int64_t(Loc.Size); + int64_t EarlierSize = int64_t(Loc.Size.getValue()); GetPointerBaseWithConstantOffset(Ptr, EarlierStart, DL); OverlapIntervalsTy &IntervalMap = OI.second; Changed |= @@ -1203,8 +1204,9 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, assert(!EnablePartialOverwriteTracking && "Do not expect to perform " "when partial-overwrite " "tracking is enabled"); - int64_t EarlierSize = DepLoc.Size; - int64_t LaterSize = Loc.Size; + // The overwrite result is known, so these must be known, too. + int64_t EarlierSize = DepLoc.Size.getValue(); + int64_t LaterSize = Loc.Size.getValue(); bool IsOverwriteEnd = (OR == OW_End); MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize, InstWriteOffset, LaterSize, IsOverwriteEnd); diff --git a/lib/Transforms/Scalar/DivRemPairs.cpp b/lib/Transforms/Scalar/DivRemPairs.cpp index e1bc590c5c9a..ffcf34f1cf7a 100644 --- a/lib/Transforms/Scalar/DivRemPairs.cpp +++ b/lib/Transforms/Scalar/DivRemPairs.cpp @@ -21,6 +21,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/Pass.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BypassSlowDivision.h" using namespace llvm; @@ -29,6 +30,8 @@ using namespace llvm; STATISTIC(NumPairs, "Number of div/rem pairs"); STATISTIC(NumHoisted, "Number of instructions hoisted"); STATISTIC(NumDecomposed, "Number of instructions decomposed"); +DEBUG_COUNTER(DRPCounter, "div-rem-pairs-transform", + "Controls transformations in div-rem-pairs pass"); /// Find matching pairs of integer div/rem ops (they have the same numerator, /// denominator, and signedness). If they exist in different basic blocks, bring @@ -93,6 +96,9 @@ static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, if (!DivDominates && !DT.dominates(RemInst, DivInst)) continue; + if (!DebugCounter::shouldExecute(DRPCounter)) + continue; + if (HasDivRemOp) { // The target has a single div/rem operation. Hoist the lower instruction // to make the matched pair visible to the backend. 
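The DebugCounter gating added in the constprop, DCE, and div-rem-pairs hunks follows one small pattern, shown below in isolation. The counter and transform names here are hypothetical, and the counter is usually driven from the command line with something along the lines of -debug-counter=my-transform-skip=N,my-transform-count=M (exact syntax per the DebugCounter documentation).

#include "llvm/Support/DebugCounter.h"
using namespace llvm;

// One counter per transform; the strings are its command-line name and help text.
DEBUG_COUNTER(MyCounter, "my-transform",
              "Controls which candidates this pass is allowed to rewrite");

static bool tryRewriteCandidate(/* Candidate &C */) {
  // Skip the transform when the counter says so, exactly as the hunks above
  // do; this lets a miscompile be bisected down to a single rewrite.
  if (!DebugCounter::shouldExecute(MyCounter))
    return false;
  // ... perform and report the rewrite ...
  return true;
}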
diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index 533d16e088c8..1f09979b3382 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" @@ -54,6 +55,7 @@ #include "llvm/Support/RecyclingAllocator.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/GuardUtils.h" #include <cassert> #include <deque> #include <memory> @@ -602,6 +604,8 @@ private: void removeMSSA(Instruction *Inst) { if (!MSSA) return; + if (VerifyMemorySSA) + MSSA->verifyMemorySSA(); // Removing a store here can leave MemorySSA in an unoptimized state by // creating MemoryPhis that have identical arguments and by creating // MemoryUses whose defining access is not an actual clobber. We handle the @@ -808,7 +812,8 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { LLVM_DEBUG(dbgs() << "Skipping due to debug counter\n"); continue; } - salvageDebugInfo(*Inst); + if (!salvageDebugInfo(*Inst)) + replaceDbgUsesWithUndef(Inst); removeMSSA(Inst); Inst->eraseFromParent(); Changed = true; @@ -863,7 +868,7 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } - if (match(Inst, m_Intrinsic<Intrinsic::experimental_guard>())) { + if (isGuard(Inst)) { if (auto *CondI = dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) { if (SimpleValue::canHandle(CondI)) { diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index 1e0a22cb14b3..9861948c8297 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -38,7 +38,6 @@ #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PHITransAddr.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/Attributes.h" @@ -48,6 +47,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -71,6 +71,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/VNCoercion.h" #include <algorithm> @@ -97,11 +98,16 @@ STATISTIC(NumPRELoad, "Number of loads PRE'd"); static cl::opt<bool> EnablePRE("enable-pre", cl::init(true), cl::Hidden); static cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true)); +static cl::opt<bool> EnableMemDep("enable-gvn-memdep", cl::init(true)); // Maximum allowed recursion depth. 
static cl::opt<uint32_t> -MaxRecurseDepth("max-recurse-depth", cl::Hidden, cl::init(1000), cl::ZeroOrMore, - cl::desc("Max recurse depth (default = 1000)")); +MaxRecurseDepth("gvn-max-recurse-depth", cl::Hidden, cl::init(1000), cl::ZeroOrMore, + cl::desc("Max recurse depth in GVN (default = 1000)")); + +static cl::opt<uint32_t> MaxNumDeps( + "gvn-max-num-deps", cl::Hidden, cl::init(100), cl::ZeroOrMore, + cl::desc("Max number of dependences to attempt Load PRE (default = 100)")); struct llvm::GVN::Expression { uint32_t opcode; @@ -392,18 +398,13 @@ uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[C] = e; return e; - } else if (AA->onlyReadsMemory(C)) { + } else if (MD && AA->onlyReadsMemory(C)) { Expression exp = createExpr(C); auto ValNum = assignExpNewValueNum(exp); if (ValNum.second) { valueNumbering[C] = ValNum.first; return ValNum.first; } - if (!MD) { - uint32_t e = assignExpNewValueNum(exp).first; - valueNumbering[C] = e; - return e; - } MemDepResult local_dep = MD->getDependency(C); @@ -436,7 +437,7 @@ uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { // Non-local case. const MemoryDependenceResults::NonLocalDepInfo &deps = - MD->getNonLocalCallDependency(CallSite(C)); + MD->getNonLocalCallDependency(C); // FIXME: Move the checking logic to MemDep! CallInst* cdep = nullptr; @@ -677,7 +678,7 @@ static bool IsValueFullyAvailableInBlock(BasicBlock *BB, // Optimistically assume that the block is fully available and check to see // if we already know about this block in one lookup. - std::pair<DenseMap<BasicBlock*, char>::iterator, char> IV = + std::pair<DenseMap<BasicBlock*, char>::iterator, bool> IV = FullyAvailableBlocks.insert(std::make_pair(BB, 2)); // If the entry already existed for this block, return the precomputed value. @@ -1074,15 +1075,8 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // because if the index is out of bounds we should deoptimize rather than // access the array. // Check that there is no guard in this block above our instruction. - if (!IsSafeToSpeculativelyExecute) { - auto It = FirstImplicitControlFlowInsts.find(TmpBB); - if (It != FirstImplicitControlFlowInsts.end()) { - assert(It->second->getParent() == TmpBB && - "Implicit control flow map broken?"); - if (OI->dominates(It->second, LI)) - return false; - } - } + if (!IsSafeToSpeculativelyExecute && ICF->isDominatedByICFIFromSameBlock(LI)) + return false; while (TmpBB->getSinglePredecessor()) { TmpBB = TmpBB->getSinglePredecessor(); if (TmpBB == LoadBB) // Infinite (unreachable) loop. @@ -1099,8 +1093,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, return false; // Check that there is no implicit control flow in a block above. - if (!IsSafeToSpeculativelyExecute && - FirstImplicitControlFlowInsts.count(TmpBB)) + if (!IsSafeToSpeculativelyExecute && ICF->hasICF(TmpBB)) return false; } @@ -1322,7 +1315,7 @@ bool GVN::processNonLocalLoad(LoadInst *LI) { // dependencies, this load isn't worth worrying about. Optimizing // it will be too expensive. 
unsigned NumDeps = Deps.size(); - if (NumDeps > 100) + if (NumDeps > MaxNumDeps) return false; // If we had a phi translation failure, we'll have a single entry which is a @@ -1451,37 +1444,6 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { return Changed; } -static void patchReplacementInstruction(Instruction *I, Value *Repl) { - auto *ReplInst = dyn_cast<Instruction>(Repl); - if (!ReplInst) - return; - - // Patch the replacement so that it is not more restrictive than the value - // being replaced. - // Note that if 'I' is a load being replaced by some operation, - // for example, by an arithmetic operation, then andIRFlags() - // would just erase all math flags from the original arithmetic - // operation, which is clearly not wanted and not needed. - if (!isa<LoadInst>(I)) - ReplInst->andIRFlags(I); - - // FIXME: If both the original and replacement value are part of the - // same control-flow region (meaning that the execution of one - // guarantees the execution of the other), then we can combine the - // noalias scopes here and do better than the general conservative - // answer used in combineMetadata(). - - // In general, GVN unifies expressions over different control-flow - // regions, and so we need a conservative combination of the noalias - // scopes. - static const unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group}; - combineMetadata(ReplInst, I, KnownIDs); -} - static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { patchReplacementInstruction(I, Repl); I->replaceAllUsesWith(Repl); @@ -1683,10 +1645,12 @@ static bool isOnlyReachableViaThisEdge(const BasicBlockEdge &E, } void GVN::assignBlockRPONumber(Function &F) { + BlockRPONumber.clear(); uint32_t NextBlockNumber = 1; ReversePostOrderTraversal<Function *> RPOT(&F); for (BasicBlock *BB : RPOT) BlockRPONumber[BB] = NextBlockNumber++; + InvalidBlockRPONumbers = false; } // Tries to replace instruction with const, using information from @@ -1778,6 +1742,9 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, Changed |= NumReplacements > 0; NumGVNEqProp += NumReplacements; + // Cached information for anything that uses LHS will be invalid. + if (MD) + MD->invalidateCachedPointerInfo(LHS); } // Now try to deduce additional equalities from this one. For example, if @@ -1853,6 +1820,9 @@ bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, Root.getStart()); Changed |= NumReplacements > 0; NumGVNEqProp += NumReplacements; + // Cached information for anything that uses NotCmp will be invalid. + if (MD) + MD->invalidateCachedPointerInfo(NotCmp); } } // Ensure that any instruction in scope that gets the "A < B" value number @@ -1975,7 +1945,7 @@ bool GVN::processInstruction(Instruction *I) { // Allocations are always uniquely numbered, so we can save time and memory // by fast failing them. 
- if (isa<AllocaInst>(I) || isa<TerminatorInst>(I) || isa<PHINode>(I)) { + if (isa<AllocaInst>(I) || I->isTerminator() || isa<PHINode>(I)) { addToLeaderTable(Num, I, I->getParent()); return false; } @@ -2020,20 +1990,22 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, TLI = &RunTLI; VN.setAliasAnalysis(&RunAA); MD = RunMD; - OrderedInstructions OrderedInstrs(DT); - OI = &OrderedInstrs; + ImplicitControlFlowTracking ImplicitCFT(DT); + ICF = &ImplicitCFT; VN.setMemDep(MD); ORE = RunORE; + InvalidBlockRPONumbers = true; bool Changed = false; bool ShouldContinue = true; + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); // Merge unconditional branches, allowing PRE to catch more // optimization opportunities. for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ) { BasicBlock *BB = &*FI++; - bool removedBlock = MergeBlockIntoPredecessor(BB, DT, LI, MD); + bool removedBlock = MergeBlockIntoPredecessor(BB, &DTU, LI, nullptr, MD); if (removedBlock) ++NumGVNBlocks; @@ -2052,7 +2024,6 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, // Fabricate val-num for dead-code in order to suppress assertion in // performPRE(). assignValNumForDeadCode(); - assignBlockRPONumber(F); bool PREChanged = true; while (PREChanged) { PREChanged = performPRE(F); @@ -2104,27 +2075,16 @@ bool GVN::processBlock(BasicBlock *BB) { if (!AtStart) --BI; - bool InvalidateImplicitCF = false; - const Instruction *MaybeFirstICF = FirstImplicitControlFlowInsts.lookup(BB); for (auto *I : InstrsToErase) { assert(I->getParent() == BB && "Removing instruction from wrong block?"); LLVM_DEBUG(dbgs() << "GVN removed: " << *I << '\n'); salvageDebugInfo(*I); if (MD) MD->removeInstruction(I); LLVM_DEBUG(verifyRemoved(I)); - if (MaybeFirstICF == I) { - // We have erased the first ICF in block. The map needs to be updated. - InvalidateImplicitCF = true; - // Do not keep dangling pointer on the erased instruction. - MaybeFirstICF = nullptr; - } + ICF->removeInstruction(I); I->eraseFromParent(); } - - OI->invalidateBlock(BB); InstrsToErase.clear(); - if (InvalidateImplicitCF) - fillImplicitControlFlowInfo(BB); if (AtStart) BI = BB->begin(); @@ -2184,7 +2144,7 @@ bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, } bool GVN::performScalarPRE(Instruction *CurInst) { - if (isa<AllocaInst>(CurInst) || isa<TerminatorInst>(CurInst) || + if (isa<AllocaInst>(CurInst) || CurInst->isTerminator() || isa<PHINode>(CurInst) || CurInst->getType()->isVoidTy() || CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() || isa<DbgInfoIntrinsic>(CurInst)) @@ -2197,6 +2157,16 @@ bool GVN::performScalarPRE(Instruction *CurInst) { if (isa<CmpInst>(CurInst)) return false; + // Don't do PRE on GEPs. The inserted PHI would prevent CodeGenPrepare from + // sinking the addressing mode computation back to its uses. Extending the + // GEP's live range increases the register pressure, and therefore it can + // introduce unnecessary spills. + // + // This doesn't prevent Load PRE. PHI translation will make the GEP available + // to the load by moving it to the predecessor block if necessary. + if (isa<GetElementPtrInst>(CurInst)) + return false; + // We don't currently value number ANY inline asm calls. 
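The BlockRPONumber handling threaded through the GVN hunks (assignBlockRPONumber clearing and refilling the map, InvalidBlockRPONumbers being set by splitCriticalEdges and checked in performScalarPRE) is a dirty-flag cache. A stripped-down, self-contained version of that shape, with plain ints standing in for basic blocks:

#include <cstdint>
#include <unordered_map>
#include <vector>

// Lazily cached block numbering, refreshed only after the CFG has changed.
class RPONumbers {
  std::unordered_map<int, uint32_t> Num;  // block id -> RPO number
  bool Invalid = true;

public:
  void invalidate() { Invalid = true; }   // e.g. after splitting a critical edge

  uint32_t get(int Block, const std::vector<int> &CurrentRPO) {
    if (Invalid) {                        // recompute once, on first use
      Num.clear();
      uint32_t Next = 1;
      for (int B : CurrentRPO)
        Num[B] = Next++;
      Invalid = false;
    }
    return Num.at(Block);
  }
};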
if (CallInst *CallI = dyn_cast<CallInst>(CurInst)) if (CallI->isInlineAsm()) @@ -2215,6 +2185,10 @@ bool GVN::performScalarPRE(Instruction *CurInst) { BasicBlock *PREPred = nullptr; BasicBlock *CurrentBlock = CurInst->getParent(); + // Update the RPO numbers for this function. + if (InvalidBlockRPONumbers) + assignBlockRPONumber(*CurrentBlock->getParent()); + SmallVector<std::pair<Value *, BasicBlock *>, 8> predMap; for (BasicBlock *P : predecessors(CurrentBlock)) { // We're not interested in PRE where blocks with predecessors that are @@ -2226,6 +2200,8 @@ bool GVN::performScalarPRE(Instruction *CurInst) { // It is not safe to do PRE when P->CurrentBlock is a loop backedge, and // when CurInst has operand defined in CurrentBlock (so it may be defined // by phi in the loop header). + assert(BlockRPONumber.count(P) && BlockRPONumber.count(CurrentBlock) && + "Invalid BlockRPONumber map."); if (BlockRPONumber[P] >= BlockRPONumber[CurrentBlock] && llvm::any_of(CurInst->operands(), [&](const Use &U) { if (auto *Inst = dyn_cast<Instruction>(U.get())) @@ -2268,13 +2244,8 @@ bool GVN::performScalarPRE(Instruction *CurInst) { // is always executed. An instruction with implicit control flow could // prevent us from doing it. If we cannot speculate the execution, then // PRE should be prohibited. - auto It = FirstImplicitControlFlowInsts.find(CurrentBlock); - if (It != FirstImplicitControlFlowInsts.end()) { - assert(It->second->getParent() == CurrentBlock && - "Implicit control flow map broken?"); - if (OI->dominates(It->second, CurInst)) - return false; - } + if (ICF->isDominatedByICFIFromSameBlock(CurInst)) + return false; } // Don't do PRE across indirect branch. @@ -2335,14 +2306,10 @@ bool GVN::performScalarPRE(Instruction *CurInst) { if (MD) MD->removeInstruction(CurInst); LLVM_DEBUG(verifyRemoved(CurInst)); - bool InvalidateImplicitCF = - FirstImplicitControlFlowInsts.lookup(CurInst->getParent()) == CurInst; // FIXME: Intended to be markInstructionForDeletion(CurInst), but it causes // some assertion failures. - OI->invalidateBlock(CurrentBlock); + ICF->removeInstruction(CurInst); CurInst->eraseFromParent(); - if (InvalidateImplicitCF) - fillImplicitControlFlowInfo(CurrentBlock); ++NumGVNInstr; return true; @@ -2382,6 +2349,7 @@ BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) { SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT)); if (MD) MD->invalidateCachedPredecessors(); + InvalidBlockRPONumbers = true; return BB; } @@ -2391,11 +2359,12 @@ bool GVN::splitCriticalEdges() { if (toSplit.empty()) return false; do { - std::pair<TerminatorInst*, unsigned> Edge = toSplit.pop_back_val(); + std::pair<Instruction *, unsigned> Edge = toSplit.pop_back_val(); SplitCriticalEdge(Edge.first, Edge.second, CriticalEdgeSplittingOptions(DT)); } while (!toSplit.empty()); if (MD) MD->invalidateCachedPredecessors(); + InvalidBlockRPONumbers = true; return true; } @@ -2411,8 +2380,6 @@ bool GVN::iterateOnFunction(Function &F) { ReversePostOrderTraversal<Function *> RPOT(&F); for (BasicBlock *BB : RPOT) - fillImplicitControlFlowInfo(BB); - for (BasicBlock *BB : RPOT) Changed |= processBlock(BB); return Changed; @@ -2423,48 +2390,8 @@ void GVN::cleanupGlobalSets() { LeaderTable.clear(); BlockRPONumber.clear(); TableAllocator.Reset(); - FirstImplicitControlFlowInsts.clear(); -} - -void -GVN::fillImplicitControlFlowInfo(BasicBlock *BB) { - // Make sure that all marked instructions are actually deleted by this point, - // so that we don't need to care about omitting them. 
- assert(InstrsToErase.empty() && "Filling before removed all marked insns?"); - auto MayNotTransferExecutionToSuccessor = [&](const Instruction *I) { - // If a block's instruction doesn't always pass the control to its successor - // instruction, mark the block as having implicit control flow. We use them - // to avoid wrong assumptions of sort "if A is executed and B post-dominates - // A, then B is also executed". This is not true is there is an implicit - // control flow instruction (e.g. a guard) between them. - // - // TODO: Currently, isGuaranteedToTransferExecutionToSuccessor returns false - // for volatile stores and loads because they can trap. The discussion on - // whether or not it is correct is still ongoing. We might want to get rid - // of this logic in the future. Anyways, trapping instructions shouldn't - // introduce implicit control flow, so we explicitly allow them here. This - // must be removed once isGuaranteedToTransferExecutionToSuccessor is fixed. - if (isGuaranteedToTransferExecutionToSuccessor(I)) - return false; - if (isa<LoadInst>(I)) { - assert(cast<LoadInst>(I)->isVolatile() && - "Non-volatile load should transfer execution to successor!"); - return false; - } - if (isa<StoreInst>(I)) { - assert(cast<StoreInst>(I)->isVolatile() && - "Non-volatile store should transfer execution to successor!"); - return false; - } - return true; - }; - FirstImplicitControlFlowInsts.erase(BB); - - for (auto &I : *BB) - if (MayNotTransferExecutionToSuccessor(&I)) { - FirstImplicitControlFlowInsts[BB] = &I; - break; - } + ICF->clear(); + InvalidBlockRPONumbers = true; } /// Verify that the specified instruction does not occur in our @@ -2554,6 +2481,8 @@ void GVN::addDeadBlock(BasicBlock *BB) { PHINode &Phi = cast<PHINode>(*II); Phi.setIncomingValue(Phi.getBasicBlockIndex(P), UndefValue::get(Phi.getType())); + if (MD) + MD->invalidateCachedPointerInfo(&Phi); } } } @@ -2613,8 +2542,8 @@ class llvm::gvn::GVNLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid - explicit GVNLegacyPass(bool NoLoads = false) - : FunctionPass(ID), NoLoads(NoLoads) { + explicit GVNLegacyPass(bool NoMemDepAnalysis = !EnableMemDep) + : FunctionPass(ID), NoMemDepAnalysis(NoMemDepAnalysis) { initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2629,7 +2558,7 @@ public: getAnalysis<DominatorTreeWrapperPass>().getDomTree(), getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), getAnalysis<AAResultsWrapperPass>().getAAResults(), - NoLoads ? nullptr + NoMemDepAnalysis ? nullptr : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(), LIWP ? &LIWP->getLoopInfo() : nullptr, &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); @@ -2639,7 +2568,7 @@ public: AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); - if (!NoLoads) + if (!NoMemDepAnalysis) AU.addRequired<MemoryDependenceWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); @@ -2650,7 +2579,7 @@ public: } private: - bool NoLoads; + bool NoMemDepAnalysis; GVN Impl; }; @@ -2667,6 +2596,6 @@ INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) // The public interface to this file... 
-FunctionPass *llvm::createGVNPass(bool NoLoads) { - return new GVNLegacyPass(NoLoads); +FunctionPass *llvm::createGVNPass(bool NoMemDepAnalysis) { + return new GVNLegacyPass(NoMemDepAnalysis); } diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp index 6d2b25cf6013..76a42d7fe750 100644 --- a/lib/Transforms/Scalar/GVNHoist.cpp +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -246,8 +246,8 @@ static void combineKnownMetadata(Instruction *ReplInst, Instruction *I) { LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, LLVMContext::MD_noalias, LLVMContext::MD_range, LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group}; - combineMetadata(ReplInst, I, KnownIDs); + LLVMContext::MD_invariant_group, LLVMContext::MD_access_group}; + combineMetadata(ReplInst, I, KnownIDs, true); } // This pass hoists common computations across branches sharing common @@ -365,7 +365,7 @@ private: // Return true when a successor of BB dominates A. bool successorDominate(const BasicBlock *BB, const BasicBlock *A) { - for (const BasicBlock *Succ : BB->getTerminator()->successors()) + for (const BasicBlock *Succ : successors(BB)) if (DT->dominates(Succ, A)) return true; @@ -577,15 +577,15 @@ private: // Returns the edge via which an instruction in BB will get the values from. // Returns true when the values are flowing out to each edge. - bool valueAnticipable(CHIArgs C, TerminatorInst *TI) const { + bool valueAnticipable(CHIArgs C, Instruction *TI) const { if (TI->getNumSuccessors() > (unsigned)size(C)) return false; // Not enough args in this CHI. for (auto CHI : C) { BasicBlock *Dest = CHI.Dest; // Find if all the edges have values flowing out of BB. - bool Found = llvm::any_of(TI->successors(), [Dest](const BasicBlock *BB) { - return BB == Dest; }); + bool Found = llvm::any_of( + successors(TI), [Dest](const BasicBlock *BB) { return BB == Dest; }); if (!Found) return false; } @@ -748,11 +748,9 @@ private: // TODO: Remove fully-redundant expressions. // Get instruction from the Map, assume that all the Instructions // with same VNs have same rank (this is an approximation). - llvm::sort(Ranks.begin(), Ranks.end(), - [this, &Map](const VNType &r1, const VNType &r2) { - return (rank(*Map.lookup(r1).begin()) < - rank(*Map.lookup(r2).begin())); - }); + llvm::sort(Ranks, [this, &Map](const VNType &r1, const VNType &r2) { + return (rank(*Map.lookup(r1).begin()) < rank(*Map.lookup(r2).begin())); + }); // - Sort VNs according to their rank, and start with lowest ranked VN // - Take a VN and for each instruction with same VN @@ -784,6 +782,7 @@ private: // which currently have dead terminators that are control // dependence sources of a block which is in NewLiveBlocks. IDFs.setDefiningBlocks(VNBlocks); + IDFBlocks.clear(); IDFs.calculate(IDFBlocks); // Make a map of BB vs instructions to be hoisted. @@ -792,7 +791,7 @@ private: } // Insert empty CHI node for this VN. This is used to factor out // basic blocks where the ANTIC can potentially change. - for (auto IDFB : IDFBlocks) { // TODO: Prune out useless CHI insertions. + for (auto IDFB : IDFBlocks) { for (unsigned i = 0; i < V.size(); ++i) { CHIArg C = {VN, nullptr, nullptr}; // Ignore spurious PDFs. @@ -1100,7 +1099,7 @@ private: break; // Do not value number terminator instructions. 
- if (isa<TerminatorInst>(&I1)) + if (I1.isTerminator()) break; if (auto *Load = dyn_cast<LoadInst>(&I1)) diff --git a/lib/Transforms/Scalar/GVNSink.cpp b/lib/Transforms/Scalar/GVNSink.cpp index 8959038de596..1df5f5400c14 100644 --- a/lib/Transforms/Scalar/GVNSink.cpp +++ b/lib/Transforms/Scalar/GVNSink.cpp @@ -239,7 +239,7 @@ public: SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops; for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)}); - llvm::sort(Ops.begin(), Ops.end()); + llvm::sort(Ops); for (auto &P : Ops) { Blocks.push_back(P.first); Values.push_back(P.second); @@ -258,14 +258,14 @@ public: /// Create a PHI from an array of incoming values and incoming blocks. template <typename VArray, typename BArray> ModelledPHI(const VArray &V, const BArray &B) { - std::copy(V.begin(), V.end(), std::back_inserter(Values)); - std::copy(B.begin(), B.end(), std::back_inserter(Blocks)); + llvm::copy(V, std::back_inserter(Values)); + llvm::copy(B, std::back_inserter(Blocks)); } /// Create a PHI from [I[OpNum] for I in Insts]. template <typename BArray> ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) { - std::copy(B.begin(), B.end(), std::back_inserter(Blocks)); + llvm::copy(B, std::back_inserter(Blocks)); for (auto *I : Insts) Values.push_back(I->getOperand(OpNum)); } @@ -762,7 +762,7 @@ unsigned GVNSink::sinkBB(BasicBlock *BBEnd) { } if (Preds.size() < 2) return 0; - llvm::sort(Preds.begin(), Preds.end()); + llvm::sort(Preds); unsigned NumOrigPreds = Preds.size(); // We can only sink instructions through unconditional branches. @@ -859,7 +859,7 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, // Update metadata and IR flags. for (auto *I : Insts) if (I != I0) { - combineMetadataForCSE(I0, I); + combineMetadataForCSE(I0, I, true); I0->andIRFlags(I); } diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index 055fcbc8436f..efc204d4f74b 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -44,6 +44,8 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/PostDominators.h" @@ -63,22 +65,69 @@ using namespace llvm; #define DEBUG_TYPE "guard-widening" STATISTIC(GuardsEliminated, "Number of eliminated guards"); +STATISTIC(CondBranchEliminated, "Number of eliminated conditional branches"); + +static cl::opt<bool> WidenFrequentBranches( + "guard-widening-widen-frequent-branches", cl::Hidden, + cl::desc("Widen conditions of explicit branches into dominating guards in " + "case if their taken frequency exceeds threshold set by " + "guard-widening-frequent-branch-threshold option"), + cl::init(false)); + +static cl::opt<unsigned> FrequentBranchThreshold( + "guard-widening-frequent-branch-threshold", cl::Hidden, + cl::desc("When WidenFrequentBranches is set to true, this option is used " + "to determine which branches are frequently taken. The criteria " + "that a branch is taken more often than " + "((FrequentBranchThreshold - 1) / FrequentBranchThreshold), then " + "it is considered frequently taken"), + cl::init(1000)); + namespace { +// Get the condition of \p I. It can either be a guard or a conditional branch. 
+static Value *getCondition(Instruction *I) { + if (IntrinsicInst *GI = dyn_cast<IntrinsicInst>(I)) { + assert(GI->getIntrinsicID() == Intrinsic::experimental_guard && + "Bad guard intrinsic?"); + return GI->getArgOperand(0); + } + return cast<BranchInst>(I)->getCondition(); +} + +// Set the condition for \p I to \p NewCond. \p I can either be a guard or a +// conditional branch. +static void setCondition(Instruction *I, Value *NewCond) { + if (IntrinsicInst *GI = dyn_cast<IntrinsicInst>(I)) { + assert(GI->getIntrinsicID() == Intrinsic::experimental_guard && + "Bad guard intrinsic?"); + GI->setArgOperand(0, NewCond); + return; + } + cast<BranchInst>(I)->setCondition(NewCond); +} + +// Eliminates the guard instruction properly. +static void eliminateGuard(Instruction *GuardInst) { + GuardInst->eraseFromParent(); + ++GuardsEliminated; +} + class GuardWideningImpl { DominatorTree &DT; PostDominatorTree *PDT; LoopInfo &LI; + BranchProbabilityInfo *BPI; /// Together, these describe the region of interest. This might be all of /// the blocks within a function, or only a given loop's blocks and preheader. DomTreeNode *Root; std::function<bool(BasicBlock*)> BlockFilter; - /// The set of guards whose conditions have been widened into dominating - /// guards. - SmallVector<Instruction *, 16> EliminatedGuards; + /// The set of guards and conditional branches whose conditions have been + /// widened into dominating guards. + SmallVector<Instruction *, 16> EliminatedGuardsAndBranches; /// The set of guards which have been widened to include conditions to other /// guards. @@ -91,19 +140,7 @@ class GuardWideningImpl { bool eliminateGuardViaWidening( Instruction *Guard, const df_iterator<DomTreeNode *> &DFSI, const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> & - GuardsPerBlock); - - // Get the condition from \p GuardInst. - Value *getGuardCondition(Instruction *GuardInst); - - // Set the condition for \p GuardInst. - void setGuardCondition(Instruction *GuardInst, Value *NewCond); - - // Whether or not the particular instruction is a guard. - bool isGuard(const Instruction *I); - - // Eliminates the guard instruction properly. - void eliminateGuard(Instruction *GuardInst); + GuardsPerBlock, bool InvertCondition = false); /// Used to keep track of which widening potential is more effective. enum WideningScore { @@ -127,11 +164,13 @@ class GuardWideningImpl { /// Compute the score for widening the condition in \p DominatedGuard /// (contained in \p DominatedGuardLoop) into \p DominatingGuard (contained in - /// \p DominatingGuardLoop). + /// \p DominatingGuardLoop). If \p InvertCond is set, then we widen the + /// inverted condition of the dominating guard. WideningScore computeWideningScore(Instruction *DominatedGuard, Loop *DominatedGuardLoop, Instruction *DominatingGuard, - Loop *DominatingGuardLoop); + Loop *DominatingGuardLoop, + bool InvertCond); /// Helper to check if \p V can be hoisted to \p InsertPos. bool isAvailableAt(Value *V, Instruction *InsertPos) { @@ -147,13 +186,14 @@ class GuardWideningImpl { void makeAvailableAt(Value *V, Instruction *InsertPos); /// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try - /// to generate an expression computing the logical AND of \p Cond0 and \p - /// Cond1. Return true if the expression computing the AND is only as + /// to generate an expression computing the logical AND of \p Cond0 and (\p + /// Cond1 XOR \p InvertCondition). 
+ /// Return true if the expression computing the AND is only as /// expensive as computing one of the two. If \p InsertPt is true then /// actually generate the resulting expression, make it available at \p /// InsertPt and return it in \p Result (else no change to the IR is made). bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt, - Value *&Result); + Value *&Result, bool InvertCondition); /// Represents a range check of the form \c Base + \c Offset u< \c Length, /// with the constraint that \c Length is not negative. \c CheckInst is the @@ -214,25 +254,31 @@ class GuardWideningImpl { /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of /// computing only one of the two expressions? - bool isWideningCondProfitable(Value *Cond0, Value *Cond1) { + bool isWideningCondProfitable(Value *Cond0, Value *Cond1, bool InvertCond) { Value *ResultUnused; - return widenCondCommon(Cond0, Cond1, /*InsertPt=*/nullptr, ResultUnused); + return widenCondCommon(Cond0, Cond1, /*InsertPt=*/nullptr, ResultUnused, + InvertCond); } - /// Widen \p ToWiden to fail if \p NewCondition is false (in addition to - /// whatever it is already checking). - void widenGuard(Instruction *ToWiden, Value *NewCondition) { + /// If \p InvertCondition is false, Widen \p ToWiden to fail if + /// \p NewCondition is false, otherwise make it fail if \p NewCondition is + /// true (in addition to whatever it is already checking). + void widenGuard(Instruction *ToWiden, Value *NewCondition, + bool InvertCondition) { Value *Result; - widenCondCommon(ToWiden->getOperand(0), NewCondition, ToWiden, Result); - setGuardCondition(ToWiden, Result); + widenCondCommon(ToWiden->getOperand(0), NewCondition, ToWiden, Result, + InvertCondition); + setCondition(ToWiden, Result); } public: explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree *PDT, - LoopInfo &LI, DomTreeNode *Root, + LoopInfo &LI, BranchProbabilityInfo *BPI, + DomTreeNode *Root, std::function<bool(BasicBlock*)> BlockFilter) - : DT(DT), PDT(PDT), LI(LI), Root(Root), BlockFilter(BlockFilter) {} + : DT(DT), PDT(PDT), LI(LI), BPI(BPI), Root(Root), BlockFilter(BlockFilter) + {} /// The entry point for this pass. bool run(); @@ -242,6 +288,12 @@ public: bool GuardWideningImpl::run() { DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> GuardsInBlock; bool Changed = false; + Optional<BranchProbability> LikelyTaken = None; + if (WidenFrequentBranches && BPI) { + unsigned Threshold = FrequentBranchThreshold; + assert(Threshold > 0 && "Zero threshold makes no sense!"); + LikelyTaken = BranchProbability(Threshold - 1, Threshold); + } for (auto DFI = df_begin(Root), DFE = df_end(Root); DFI != DFE; ++DFI) { @@ -257,12 +309,31 @@ bool GuardWideningImpl::run() { for (auto *II : CurrentList) Changed |= eliminateGuardViaWidening(II, DFI, GuardsInBlock); + if (WidenFrequentBranches && BPI) + if (auto *BI = dyn_cast<BranchInst>(BB->getTerminator())) + if (BI->isConditional()) { + // If one of branches of a conditional is likely taken, try to + // eliminate it. 
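For illustration of the frequency criterion used just below: with the default guard-widening-frequent-branch-threshold of 1000, an edge only qualifies when BPI reports its probability as at least 999/1000. A minimal sketch of that test, assuming nothing beyond the standard llvm::BranchProbability helper (the pass builds the same fraction from the option value):

  #include "llvm/Support/BranchProbability.h"

  // Returns true when an edge is "frequently taken" in the sense described
  // above, i.e. its probability is >= (Threshold - 1) / Threshold.
  static bool isFrequentlyTaken(llvm::BranchProbability EdgeProb,
                                unsigned Threshold = 1000) {
    return EdgeProb >= llvm::BranchProbability(Threshold - 1, Threshold);
  }
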
+ if (BPI->getEdgeProbability(BB, 0U) >= *LikelyTaken) + Changed |= eliminateGuardViaWidening(BI, DFI, GuardsInBlock); + else if (BPI->getEdgeProbability(BB, 1U) >= *LikelyTaken) + Changed |= eliminateGuardViaWidening(BI, DFI, GuardsInBlock, + /*InvertCondition*/true); + } } - assert(EliminatedGuards.empty() || Changed); - for (auto *II : EliminatedGuards) - if (!WidenedGuards.count(II)) - eliminateGuard(II); + assert(EliminatedGuardsAndBranches.empty() || Changed); + for (auto *I : EliminatedGuardsAndBranches) + if (!WidenedGuards.count(I)) { + assert(isa<ConstantInt>(getCondition(I)) && "Should be!"); + if (isGuard(I)) + eliminateGuard(I); + else { + assert(isa<BranchInst>(I) && + "Eliminated something other than guard or branch?"); + ++CondBranchEliminated; + } + } return Changed; } @@ -270,7 +341,13 @@ bool GuardWideningImpl::run() { bool GuardWideningImpl::eliminateGuardViaWidening( Instruction *GuardInst, const df_iterator<DomTreeNode *> &DFSI, const DenseMap<BasicBlock *, SmallVector<Instruction *, 8>> & - GuardsInBlock) { + GuardsInBlock, bool InvertCondition) { + // Ignore trivial true or false conditions. These instructions will be + // trivially eliminated by any cleanup pass. Do not erase them because other + // guards can possibly be widened into them. + if (isa<ConstantInt>(getCondition(GuardInst))) + return false; + Instruction *BestSoFar = nullptr; auto BestScoreSoFar = WS_IllegalOrNegative; auto *GuardInstLoop = LI.getLoopFor(GuardInst->getParent()); @@ -304,7 +381,7 @@ bool GuardWideningImpl::eliminateGuardViaWidening( assert((i == (e - 1)) == (GuardInst->getParent() == CurBB) && "Bad DFS?"); - if (i == (e - 1)) { + if (i == (e - 1) && CurBB->getTerminator() != GuardInst) { // Corner case: make sure we're only looking at guards strictly dominating // GuardInst when visiting GuardInst->getParent(). auto NewEnd = std::find(I, E, GuardInst); @@ -314,9 +391,10 @@ bool GuardWideningImpl::eliminateGuardViaWidening( for (auto *Candidate : make_range(I, E)) { auto Score = - computeWideningScore(GuardInst, GuardInstLoop, Candidate, CurLoop); - LLVM_DEBUG(dbgs() << "Score between " << *getGuardCondition(GuardInst) - << " and " << *getGuardCondition(Candidate) << " is " + computeWideningScore(GuardInst, GuardInstLoop, Candidate, CurLoop, + InvertCondition); + LLVM_DEBUG(dbgs() << "Score between " << *getCondition(GuardInst) + << " and " << *getCondition(Candidate) << " is " << scoreTypeToString(Score) << "\n"); if (Score > BestScoreSoFar) { BestScoreSoFar = Score; @@ -336,41 +414,19 @@ bool GuardWideningImpl::eliminateGuardViaWidening( LLVM_DEBUG(dbgs() << "Widening " << *GuardInst << " into " << *BestSoFar << " with score " << scoreTypeToString(BestScoreSoFar) << "\n"); - widenGuard(BestSoFar, getGuardCondition(GuardInst)); - setGuardCondition(GuardInst, ConstantInt::getTrue(GuardInst->getContext())); - EliminatedGuards.push_back(GuardInst); + widenGuard(BestSoFar, getCondition(GuardInst), InvertCondition); + auto NewGuardCondition = InvertCondition + ? 
ConstantInt::getFalse(GuardInst->getContext()) + : ConstantInt::getTrue(GuardInst->getContext()); + setCondition(GuardInst, NewGuardCondition); + EliminatedGuardsAndBranches.push_back(GuardInst); WidenedGuards.insert(BestSoFar); return true; } -Value *GuardWideningImpl::getGuardCondition(Instruction *GuardInst) { - IntrinsicInst *GI = cast<IntrinsicInst>(GuardInst); - assert(GI->getIntrinsicID() == Intrinsic::experimental_guard && - "Bad guard intrinsic?"); - return GI->getArgOperand(0); -} - -void GuardWideningImpl::setGuardCondition(Instruction *GuardInst, - Value *NewCond) { - IntrinsicInst *GI = cast<IntrinsicInst>(GuardInst); - assert(GI->getIntrinsicID() == Intrinsic::experimental_guard && - "Bad guard intrinsic?"); - GI->setArgOperand(0, NewCond); -} - -bool GuardWideningImpl::isGuard(const Instruction* I) { - using namespace llvm::PatternMatch; - return match(I, m_Intrinsic<Intrinsic::experimental_guard>()); -} - -void GuardWideningImpl::eliminateGuard(Instruction *GuardInst) { - GuardInst->eraseFromParent(); - ++GuardsEliminated; -} - GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( Instruction *DominatedGuard, Loop *DominatedGuardLoop, - Instruction *DominatingGuard, Loop *DominatingGuardLoop) { + Instruction *DominatingGuard, Loop *DominatingGuardLoop, bool InvertCond) { bool HoistingOutOfLoop = false; if (DominatingGuardLoop != DominatedGuardLoop) { @@ -383,7 +439,7 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( HoistingOutOfLoop = true; } - if (!isAvailableAt(getGuardCondition(DominatedGuard), DominatingGuard)) + if (!isAvailableAt(getCondition(DominatedGuard), DominatingGuard)) return WS_IllegalOrNegative; // If the guard was conditional executed, it may never be reached @@ -394,8 +450,8 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( // here. TODO: evaluate cost model for spurious deopt // NOTE: As written, this also lets us hoist right over another guard which // is essentially just another spelling for control flow. - if (isWideningCondProfitable(getGuardCondition(DominatedGuard), - getGuardCondition(DominatingGuard))) + if (isWideningCondProfitable(getCondition(DominatedGuard), + getCondition(DominatingGuard), InvertCond)) return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive; if (HoistingOutOfLoop) @@ -416,8 +472,7 @@ GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( return false; // TODO: diamond, triangle cases if (!PDT) return true; - return !PDT->dominates(DominatedGuard->getParent(), - DominatingGuard->getParent()); + return !PDT->dominates(DominatedBlock, DominatingBlock); }; return MaybeHoistingOutOfIf() ? 
WS_IllegalOrNegative : WS_Neutral; @@ -459,7 +514,8 @@ void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) { } bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, - Instruction *InsertPt, Value *&Result) { + Instruction *InsertPt, Value *&Result, + bool InvertCondition) { using namespace llvm::PatternMatch; { @@ -469,6 +525,8 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, ICmpInst::Predicate Pred0, Pred1; if (match(Cond0, m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) && match(Cond1, m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) { + if (InvertCondition) + Pred1 = ICmpInst::getInversePredicate(Pred1); ConstantRange CR0 = ConstantRange::makeExactICmpRegion(Pred0, RHS0->getValue()); @@ -502,7 +560,9 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, { SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks; - if (parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && + // TODO: Support InvertCondition case? + if (!InvertCondition && + parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && combineRangeChecks(Checks, CombinedChecks)) { if (InsertPt) { Result = nullptr; @@ -526,7 +586,8 @@ bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, if (InsertPt) { makeAvailableAt(Cond0, InsertPt); makeAvailableAt(Cond1, InsertPt); - + if (InvertCondition) + Cond1 = BinaryOperator::CreateNot(Cond1, "inverted", InsertPt); Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt); } @@ -636,9 +697,8 @@ bool GuardWideningImpl::combineRangeChecks( // CurrentChecks.size() will typically be 3 here, but so far there has been // no need to hard-code that fact. - llvm::sort(CurrentChecks.begin(), CurrentChecks.end(), - [&](const GuardWideningImpl::RangeCheck &LHS, - const GuardWideningImpl::RangeCheck &RHS) { + llvm::sort(CurrentChecks, [&](const GuardWideningImpl::RangeCheck &LHS, + const GuardWideningImpl::RangeCheck &RHS) { return LHS.getOffsetValue().slt(RHS.getOffsetValue()); }); @@ -728,7 +788,10 @@ PreservedAnalyses GuardWideningPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LI = AM.getResult<LoopAnalysis>(F); auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); - if (!GuardWideningImpl(DT, &PDT, LI, DT.getRootNode(), + BranchProbabilityInfo *BPI = nullptr; + if (WidenFrequentBranches) + BPI = AM.getCachedResult<BranchProbabilityAnalysis>(F); + if (!GuardWideningImpl(DT, &PDT, LI, BPI, DT.getRootNode(), [](BasicBlock*) { return true; } ).run()) return PreservedAnalyses::all(); @@ -751,7 +814,10 @@ struct GuardWideningLegacyPass : public FunctionPass { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - return GuardWideningImpl(DT, &PDT, LI, DT.getRootNode(), + BranchProbabilityInfo *BPI = nullptr; + if (WidenFrequentBranches) + BPI = &getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + return GuardWideningImpl(DT, &PDT, LI, BPI, DT.getRootNode(), [](BasicBlock*) { return true; } ).run(); } @@ -760,6 +826,8 @@ struct GuardWideningLegacyPass : public FunctionPass { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<PostDominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); + if (WidenFrequentBranches) + AU.addRequired<BranchProbabilityInfoWrapperPass>(); } }; @@ -785,11 +853,16 @@ struct LoopGuardWideningLegacyPass : public LoopPass { auto 
BlockFilter = [&](BasicBlock *BB) { return BB == RootBB || L->contains(BB); }; - return GuardWideningImpl(DT, PDT, LI, + BranchProbabilityInfo *BPI = nullptr; + if (WidenFrequentBranches) + BPI = &getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + return GuardWideningImpl(DT, PDT, LI, BPI, DT.getNode(RootBB), BlockFilter).run(); } void getAnalysisUsage(AnalysisUsage &AU) const override { + if (WidenFrequentBranches) + AU.addRequired<BranchProbabilityInfoWrapperPass>(); AU.setPreservesCFG(); getLoopAnalysisUsage(AU); AU.addPreserved<PostDominatorTreeWrapperPass>(); @@ -805,6 +878,8 @@ INITIALIZE_PASS_BEGIN(GuardWideningLegacyPass, "guard-widening", "Widen guards", INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +if (WidenFrequentBranches) + INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_END(GuardWideningLegacyPass, "guard-widening", "Widen guards", false, false) @@ -814,6 +889,8 @@ INITIALIZE_PASS_BEGIN(LoopGuardWideningLegacyPass, "loop-guard-widening", INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +if (WidenFrequentBranches) + INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) INITIALIZE_PASS_END(LoopGuardWideningLegacyPass, "loop-guard-widening", "Widen guards (within a single loop, as a loop pass)", false, false) diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index 8656e88b79cb..48d8e457ba7c 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -134,26 +134,23 @@ class IndVarSimplify { const TargetTransformInfo *TTI; SmallVector<WeakTrackingVH, 16> DeadInsts; - bool Changed = false; bool isValidRewrite(Value *FromVal, Value *ToVal); - void handleFloatingPointIV(Loop *L, PHINode *PH); - void rewriteNonIntegerIVs(Loop *L); + bool handleFloatingPointIV(Loop *L, PHINode *PH); + bool rewriteNonIntegerIVs(Loop *L); - void simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); + bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); - void rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); - void rewriteFirstIterationLoopExitValues(Loop *L); - - Value *linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, - PHINode *IndVar, SCEVExpander &Rewriter); + bool rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); + bool rewriteFirstIterationLoopExitValues(Loop *L); + bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) const; - void sinkUnusedInvariants(Loop *L); + bool linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, + PHINode *IndVar, SCEVExpander &Rewriter); - Value *expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, - Instruction *InsertPt, Type *Ty); + bool sinkUnusedInvariants(Loop *L); public: IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, @@ -284,7 +281,7 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) { /// is converted into /// for(int i = 0; i < 10000; ++i) /// bar((double)i); -void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { +bool IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); unsigned BackEdge = IncomingEdge^1; @@ -293,12 
+290,12 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { int64_t InitValue; if (!InitValueVal || !ConvertToSInt(InitValueVal->getValueAPF(), InitValue)) - return; + return false; // Check IV increment. Reject this PN if increment operation is not // an add or increment value can not be represented by an integer. auto *Incr = dyn_cast<BinaryOperator>(PN->getIncomingValue(BackEdge)); - if (Incr == nullptr || Incr->getOpcode() != Instruction::FAdd) return; + if (Incr == nullptr || Incr->getOpcode() != Instruction::FAdd) return false; // If this is not an add of the PHI with a constantfp, or if the constant fp // is not an integer, bail out. @@ -306,15 +303,15 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { int64_t IncValue; if (IncValueVal == nullptr || Incr->getOperand(0) != PN || !ConvertToSInt(IncValueVal->getValueAPF(), IncValue)) - return; + return false; // Check Incr uses. One user is PN and the other user is an exit condition // used by the conditional terminator. Value::user_iterator IncrUse = Incr->user_begin(); Instruction *U1 = cast<Instruction>(*IncrUse++); - if (IncrUse == Incr->user_end()) return; + if (IncrUse == Incr->user_end()) return false; Instruction *U2 = cast<Instruction>(*IncrUse++); - if (IncrUse != Incr->user_end()) return; + if (IncrUse != Incr->user_end()) return false; // Find exit condition, which is an fcmp. If it doesn't exist, or if it isn't // only used by a branch, we can't transform it. @@ -323,7 +320,7 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { Compare = dyn_cast<FCmpInst>(U2); if (!Compare || !Compare->hasOneUse() || !isa<BranchInst>(Compare->user_back())) - return; + return false; BranchInst *TheBr = cast<BranchInst>(Compare->user_back()); @@ -335,7 +332,7 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { if (!L->contains(TheBr->getParent()) || (L->contains(TheBr->getSuccessor(0)) && L->contains(TheBr->getSuccessor(1)))) - return; + return false; // If it isn't a comparison with an integer-as-fp (the exit value), we can't // transform it. @@ -343,12 +340,12 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { int64_t ExitValue; if (ExitValueVal == nullptr || !ConvertToSInt(ExitValueVal->getValueAPF(), ExitValue)) - return; + return false; // Find new predicate for integer comparison. CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE; switch (Compare->getPredicate()) { - default: return; // Unknown comparison. + default: return false; // Unknown comparison. case CmpInst::FCMP_OEQ: case CmpInst::FCMP_UEQ: NewPred = CmpInst::ICMP_EQ; break; case CmpInst::FCMP_ONE: @@ -371,24 +368,24 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // The start/stride/exit values must all fit in signed i32. if (!isInt<32>(InitValue) || !isInt<32>(IncValue) || !isInt<32>(ExitValue)) - return; + return false; // If not actually striding (add x, 0.0), avoid touching the code. if (IncValue == 0) - return; + return false; // Positive and negative strides have different safety conditions. if (IncValue > 0) { // If we have a positive stride, we require the init to be less than the // exit value. if (InitValue >= ExitValue) - return; + return false; uint32_t Range = uint32_t(ExitValue-InitValue); // Check for infinite loop, either: // while (i <= Exit) or until (i > Exit) if (NewPred == CmpInst::ICMP_SLE || NewPred == CmpInst::ICMP_SGT) { - if (++Range == 0) return; // Range overflows. + if (++Range == 0) return false; // Range overflows. 
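A worked instance of the infinite-loop guard above, using illustrative values rather than anything taken from the patch:

  // Suppose InitValue = INT32_MIN, ExitValue = INT32_MAX and the rewritten
  // predicate is ICMP_SLE ("while (i <= Exit)").  Then
  //   Range = uint32_t(ExitValue - InitValue) = 0xFFFFFFFF,
  // and the inclusive trip count would be Range + 1 = 2^32, which does not
  // fit in 32 bits; ++Range wraps to 0 and the candidate is rejected (now by
  // returning false).
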
} unsigned Leftover = Range % uint32_t(IncValue); @@ -398,23 +395,23 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // around and do things the fp IV wouldn't. if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && Leftover != 0) - return; + return false; // If the stride would wrap around the i32 before exiting, we can't // transform the IV. if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue) - return; + return false; } else { // If we have a negative stride, we require the init to be greater than the // exit value. if (InitValue <= ExitValue) - return; + return false; uint32_t Range = uint32_t(InitValue-ExitValue); // Check for infinite loop, either: // while (i >= Exit) or until (i < Exit) if (NewPred == CmpInst::ICMP_SGE || NewPred == CmpInst::ICMP_SLT) { - if (++Range == 0) return; // Range overflows. + if (++Range == 0) return false; // Range overflows. } unsigned Leftover = Range % uint32_t(-IncValue); @@ -424,12 +421,12 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // around and do things the fp IV wouldn't. if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && Leftover != 0) - return; + return false; // If the stride would wrap around the i32 before exiting, we can't // transform the IV. if (Leftover != 0 && int32_t(ExitValue+IncValue) > ExitValue) - return; + return false; } IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext()); @@ -475,10 +472,10 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { PN->replaceAllUsesWith(Conv); RecursivelyDeleteTriviallyDeadInstructions(PN, TLI); } - Changed = true; + return true; } -void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { +bool IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { // First step. Check to see if there are any floating-point recurrences. // If there are, change them into integer recurrences, permitting analysis by // the SCEV routines. @@ -488,15 +485,17 @@ void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { for (PHINode &PN : Header->phis()) PHIs.push_back(&PN); + bool Changed = false; for (unsigned i = 0, e = PHIs.size(); i != e; ++i) if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHIs[i])) - handleFloatingPointIV(L, PN); + Changed |= handleFloatingPointIV(L, PN); // If the loop previously had floating-point IV, ScalarEvolution // may not have been able to compute a trip count. Now that we've done some // re-writing, the trip count may be computable. if (Changed) SE->forgetLoop(L); + return Changed; } namespace { @@ -521,24 +520,34 @@ struct RewritePhi { } // end anonymous namespace -Value *IndVarSimplify::expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, - Loop *L, Instruction *InsertPt, - Type *ResultTy) { - // Before expanding S into an expensive LLVM expression, see if we can use an - // already existing value as the expansion for S. - if (Value *ExistingValue = Rewriter.getExactExistingExpansion(S, InsertPt, L)) - if (ExistingValue->getType() == ResultTy) - return ExistingValue; - - // We didn't find anything, fall back to using SCEVExpander. - return Rewriter.expandCodeFor(S, ResultTy, InsertPt); -} - //===----------------------------------------------------------------------===// // rewriteLoopExitValues - Optimize IV users outside the loop. // As a side effect, reduces the amount of IV processing within the loop. 
//===----------------------------------------------------------------------===// +bool IndVarSimplify::hasHardUserWithinLoop(const Loop *L, const Instruction *I) const { + SmallPtrSet<const Instruction *, 8> Visited; + SmallVector<const Instruction *, 8> WorkList; + Visited.insert(I); + WorkList.push_back(I); + while (!WorkList.empty()) { + const Instruction *Curr = WorkList.pop_back_val(); + // This use is outside the loop, nothing to do. + if (!L->contains(Curr)) + continue; + // Do we assume it is a "hard" use which will not be eliminated easily? + if (Curr->mayHaveSideEffects()) + return true; + // Otherwise, add all its users to worklist. + for (auto U : Curr->users()) { + auto *UI = cast<Instruction>(U); + if (Visited.insert(UI).second) + WorkList.push_back(UI); + } + } + return false; +} + /// Check to see if this loop has a computable loop-invariant execution count. /// If so, this means that we can compute the final value of any expressions /// that are recurrent in the loop, and substitute the exit values from the loop @@ -549,7 +558,7 @@ Value *IndVarSimplify::expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, /// happen later, except that it's more powerful in some cases, because it's /// able to brute-force evaluate arbitrary instructions as long as they have /// constant operands at the beginning of the loop. -void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { +bool IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // Check a pre-condition. assert(L->isRecursivelyLCSSAForm(*DT, *LI) && "Indvars did not preserve LCSSA!"); @@ -610,48 +619,14 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { !isSafeToExpand(ExitValue, *SE)) continue; - // Computing the value outside of the loop brings no benefit if : - // - it is definitely used inside the loop in a way which can not be - // optimized away. - // - no use outside of the loop can take advantage of hoisting the - // computation out of the loop - if (ExitValue->getSCEVType()>=scMulExpr) { - unsigned NumHardInternalUses = 0; - unsigned NumSoftExternalUses = 0; - unsigned NumUses = 0; - for (auto IB = Inst->user_begin(), IE = Inst->user_end(); - IB != IE && NumUses <= 6; ++IB) { - Instruction *UseInstr = cast<Instruction>(*IB); - unsigned Opc = UseInstr->getOpcode(); - NumUses++; - if (L->contains(UseInstr)) { - if (Opc == Instruction::Call || Opc == Instruction::Ret) - NumHardInternalUses++; - } else { - if (Opc == Instruction::PHI) { - // Do not count the Phi as a use. LCSSA may have inserted - // plenty of trivial ones. - NumUses--; - for (auto PB = UseInstr->user_begin(), - PE = UseInstr->user_end(); - PB != PE && NumUses <= 6; ++PB, ++NumUses) { - unsigned PhiOpc = cast<Instruction>(*PB)->getOpcode(); - if (PhiOpc != Instruction::Call && PhiOpc != Instruction::Ret) - NumSoftExternalUses++; - } - continue; - } - if (Opc != Instruction::Call && Opc != Instruction::Ret) - NumSoftExternalUses++; - } - } - if (NumUses <= 6 && NumHardInternalUses && !NumSoftExternalUses) - continue; - } + // Computing the value outside of the loop brings no benefit if it is + // definitely used inside the loop in a way which can not be optimized + // away. 
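To make the "hard" in-loop user notion concrete, a small hand-written example with hypothetical names (the pass itself, of course, inspects the IR): the exit value of sum is computable, but sum also feeds a side-effecting call inside the loop, so hasHardUserWithinLoop() returns true and the non-constant exit value is not materialized for the use after the loop.

  extern void bar(int);  // assumed opaque: mayHaveSideEffects() is true
  extern void use(int);

  void example(int n) {
    int sum = 0;
    for (int i = 0; i != n; ++i) {
      sum += i;
      bar(sum);          // "hard" user inside the loop
    }
    use(sum);            // out-of-loop use that would otherwise be rewritten
  }
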
+ if (!isa<SCEVConstant>(ExitValue) && hasHardUserWithinLoop(L, Inst)) + continue; bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst); - Value *ExitVal = - expandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, PN->getType()); + Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst); LLVM_DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << '\n' @@ -662,6 +637,16 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { continue; } +#ifndef NDEBUG + // If we reuse an instruction from a loop which is neither L nor one of + // its containing loops, we end up breaking LCSSA form for this loop by + // creating a new use of its instruction. + if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal)) + if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) + if (EVL != L) + assert(EVL->contains(L) && "LCSSA breach detected!"); +#endif + // Collect all the candidate PHINodes to be rewritten. RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost); } @@ -670,6 +655,7 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); + bool Changed = false; // Transformation. for (const RewritePhi &Phi : RewritePhiSet) { PHINode *PN = Phi.PN; @@ -703,6 +689,7 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { // The insertion point instruction may have been deleted; clear it out // so that the rewriter doesn't trip over it later. Rewriter.clearInsertPoint(); + return Changed; } //===---------------------------------------------------------------------===// @@ -714,7 +701,7 @@ void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) { /// exits. If so, we know that if the exit path is taken, it is at the first /// loop iteration. This lets us predict exit values of PHI nodes that live in /// loop header. -void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { +bool IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { // Verify the input to the pass is already in LCSSA form. assert(L->isLCSSAForm(*DT)); @@ -723,6 +710,7 @@ void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { auto *LoopHeader = L->getHeader(); assert(LoopHeader && "Invalid loop"); + bool MadeAnyChanges = false; for (auto *ExitBB : ExitBlocks) { // If there are no more PHI nodes in this exit block, then no more // values defined inside the loop are used on this path. 
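A rough sketch of the situation rewriteFirstIterationLoopExitValues() targets (hypothetical code, not from the patch): the exit condition is loop-invariant and guards an exit from the header, so if that exit is ever taken it is taken before the induction variable has been incremented, and the exit value can be predicted from the preheader incoming value.

  //   int i = 0;
  // loop:
  //   if (c) goto exit;      // c is loop-invariant
  //   ...
  //   i = i + 1; goto loop;
  // exit:
  //   use(i);                // i here is known to be 0, the preheader value
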
@@ -769,12 +757,14 @@ void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { if (PreheaderIdx != -1) { assert(ExitVal->getParent() == LoopHeader && "ExitVal must be in loop header"); + MadeAnyChanges = true; PN.setIncomingValue(IncomingValIdx, ExitVal->getIncomingValue(PreheaderIdx)); } } } } + return MadeAnyChanges; } /// Check whether it is possible to delete the loop after rewriting exit @@ -1024,6 +1014,8 @@ protected: Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter); bool widenLoopCompare(NarrowIVDefUse DU); + bool widenWithVariantLoadUse(NarrowIVDefUse DU); + void widenWithVariantLoadUseCodegen(NarrowIVDefUse DU); void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef); }; @@ -1368,6 +1360,146 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { return true; } +/// If the narrow use is an instruction whose two operands are the defining +/// instruction of DU and a load instruction, then we have the following: +/// if the load is hoisted outside the loop, then we do not reach this function +/// as scalar evolution analysis works fine in widenIVUse with variables +/// hoisted outside the loop and efficient code is subsequently generated by +/// not emitting truncate instructions. But when the load is not hoisted +/// (whether due to limitation in alias analysis or due to a true legality), +/// then scalar evolution can not proceed with loop variant values and +/// inefficient code is generated. This function handles the non-hoisted load +/// special case by making the optimization generate the same type of code for +/// hoisted and non-hoisted load (widen use and eliminate sign extend +/// instruction). This special case is important especially when the induction +/// variables are affecting addressing mode in code generation. +bool WidenIV::widenWithVariantLoadUse(NarrowIVDefUse DU) { + Instruction *NarrowUse = DU.NarrowUse; + Instruction *NarrowDef = DU.NarrowDef; + Instruction *WideDef = DU.WideDef; + + // Handle the common case of add<nsw/nuw> + const unsigned OpCode = NarrowUse->getOpcode(); + // Only Add/Sub/Mul instructions are supported. + if (OpCode != Instruction::Add && OpCode != Instruction::Sub && + OpCode != Instruction::Mul) + return false; + + // The operand that is not defined by NarrowDef of DU. Let's call it the + // other operand. + unsigned ExtendOperIdx = DU.NarrowUse->getOperand(0) == NarrowDef ? 1 : 0; + assert(DU.NarrowUse->getOperand(1 - ExtendOperIdx) == DU.NarrowDef && + "bad DU"); + + const SCEV *ExtendOperExpr = nullptr; + const OverflowingBinaryOperator *OBO = + cast<OverflowingBinaryOperator>(NarrowUse); + ExtendKind ExtKind = getExtendKind(NarrowDef); + if (ExtKind == SignExtended && OBO->hasNoSignedWrap()) + ExtendOperExpr = SE->getSignExtendExpr( + SE->getSCEV(NarrowUse->getOperand(ExtendOperIdx)), WideType); + else if (ExtKind == ZeroExtended && OBO->hasNoUnsignedWrap()) + ExtendOperExpr = SE->getZeroExtendExpr( + SE->getSCEV(NarrowUse->getOperand(ExtendOperIdx)), WideType); + else + return false; + + // We are interested in the other operand being a load instruction. + // But, we should look into relaxing this restriction later on. 
+ auto *I = dyn_cast<Instruction>(NarrowUse->getOperand(ExtendOperIdx)); + if (I && I->getOpcode() != Instruction::Load) + return false; + + // Verifying that Defining operand is an AddRec + const SCEV *Op1 = SE->getSCEV(WideDef); + const SCEVAddRecExpr *AddRecOp1 = dyn_cast<SCEVAddRecExpr>(Op1); + if (!AddRecOp1 || AddRecOp1->getLoop() != L) + return false; + // Verifying that other operand is an Extend. + if (ExtKind == SignExtended) { + if (!isa<SCEVSignExtendExpr>(ExtendOperExpr)) + return false; + } else { + if (!isa<SCEVZeroExtendExpr>(ExtendOperExpr)) + return false; + } + + if (ExtKind == SignExtended) { + for (Use &U : NarrowUse->uses()) { + SExtInst *User = dyn_cast<SExtInst>(U.getUser()); + if (!User || User->getType() != WideType) + return false; + } + } else { // ExtKind == ZeroExtended + for (Use &U : NarrowUse->uses()) { + ZExtInst *User = dyn_cast<ZExtInst>(U.getUser()); + if (!User || User->getType() != WideType) + return false; + } + } + + return true; +} + +/// Special Case for widening with variant Loads (see +/// WidenIV::widenWithVariantLoadUse). This is the code generation part. +void WidenIV::widenWithVariantLoadUseCodegen(NarrowIVDefUse DU) { + Instruction *NarrowUse = DU.NarrowUse; + Instruction *NarrowDef = DU.NarrowDef; + Instruction *WideDef = DU.WideDef; + + ExtendKind ExtKind = getExtendKind(NarrowDef); + + LLVM_DEBUG(dbgs() << "Cloning arithmetic IVUser: " << *NarrowUse << "\n"); + + // Generating a widening use instruction. + Value *LHS = (NarrowUse->getOperand(0) == NarrowDef) + ? WideDef + : createExtendInst(NarrowUse->getOperand(0), WideType, + ExtKind, NarrowUse); + Value *RHS = (NarrowUse->getOperand(1) == NarrowDef) + ? WideDef + : createExtendInst(NarrowUse->getOperand(1), WideType, + ExtKind, NarrowUse); + + auto *NarrowBO = cast<BinaryOperator>(NarrowUse); + auto *WideBO = BinaryOperator::Create(NarrowBO->getOpcode(), LHS, RHS, + NarrowBO->getName()); + IRBuilder<> Builder(NarrowUse); + Builder.Insert(WideBO); + WideBO->copyIRFlags(NarrowBO); + + if (ExtKind == SignExtended) + ExtendKindMap[NarrowUse] = SignExtended; + else + ExtendKindMap[NarrowUse] = ZeroExtended; + + // Update the Use. + if (ExtKind == SignExtended) { + for (Use &U : NarrowUse->uses()) { + SExtInst *User = dyn_cast<SExtInst>(U.getUser()); + if (User && User->getType() == WideType) { + LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by " + << *WideBO << "\n"); + ++NumElimExt; + User->replaceAllUsesWith(WideBO); + DeadInsts.emplace_back(User); + } + } + } else { // ExtKind == ZeroExtended + for (Use &U : NarrowUse->uses()) { + ZExtInst *User = dyn_cast<ZExtInst>(U.getUser()); + if (User && User->getType() == WideType) { + LLVM_DEBUG(dbgs() << "INDVARS: eliminating " << *User << " replaced by " + << *WideBO << "\n"); + ++NumElimExt; + User->replaceAllUsesWith(WideBO); + DeadInsts.emplace_back(User); + } + } + } +} + /// Determine whether an individual user of the narrow IV can be widened. If so, /// return the wide clone of the user. Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { @@ -1465,6 +1597,16 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { if (widenLoopCompare(DU)) return nullptr; + // We are here about to generate a truncate instruction that may hurt + // performance because the scalar evolution expression computed earlier + // in WideAddRec.first does not indicate a polynomial induction expression. 
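As a concrete shape of the pattern the two functions above handle (hypothetical C-level code, assuming the narrow add carries the nsw flag so its loaded operand is sign-extended): the add mixes the IV with a value loaded inside the loop, so its SCEV is not an affine recurrence, yet every use of the add is a sign extension to the wide type.

  //   for (int i = 0; i < n; ++i) {
  //     int t = i + b[i];          // IV operand + loop-variant load
  //     sum += a[(long)t];         // the only use of t is a sign extension
  //   }
  // Instead of truncating the widened IV to feed t, widenWithVariantLoadUse()
  // accepts this shape and widenWithVariantLoadUseCodegen() rebuilds t in the
  // wide type as (long)i + (long)b[i], replacing the sext users with the wide
  // add so no trunc/sext pair survives to disturb addressing-mode selection.
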
+ // In that case, look at the operands of the use instruction to determine + // if we can still widen the use instead of truncating its operand. + if (widenWithVariantLoadUse(DU)) { + widenWithVariantLoadUseCodegen(DU); + return nullptr; + } + // This user does not evaluate to a recurrence after widening, so don't // follow it. Instead insert a Trunc to kill off the original use, // eventually isolating the original narrow IV so it can be removed. @@ -1781,7 +1923,7 @@ public: /// candidates for simplification. /// /// Sign/Zero extend elimination is interleaved with IV simplification. -void IndVarSimplify::simplifyAndExtend(Loop *L, +bool IndVarSimplify::simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI) { SmallVector<WideIVInfo, 8> WideIVs; @@ -1798,6 +1940,7 @@ void IndVarSimplify::simplifyAndExtend(Loop *L, // for all current phis, then determines whether any IVs can be // widened. Widening adds new phis to LoopPhis, inducing another round of // simplification on the wide IVs. + bool Changed = false; while (!LoopPhis.empty()) { // Evaluate as many IV expressions as possible before widening any IVs. This // forces SCEV to set no-wrap flags before evaluating sign/zero @@ -1827,6 +1970,7 @@ void IndVarSimplify::simplifyAndExtend(Loop *L, } } } + return Changed; } //===----------------------------------------------------------------------===// @@ -2193,11 +2337,9 @@ static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, /// able to rewrite the exit tests of any loop where the SCEV analysis can /// determine a loop-invariant trip count of the loop, which is actually a much /// broader range than just linear tests. -Value *IndVarSimplify:: -linearFunctionTestReplace(Loop *L, - const SCEV *BackedgeTakenCount, - PHINode *IndVar, - SCEVExpander &Rewriter) { +bool IndVarSimplify:: +linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, + PHINode *IndVar, SCEVExpander &Rewriter) { assert(canExpandBackedgeTakenCount(L, SE, Rewriter) && "precondition"); // Initialize CmpIndVar and IVCount to their preincremented values. @@ -2320,8 +2462,7 @@ linearFunctionTestReplace(Loop *L, DeadInsts.push_back(OrigCond); ++NumLFTR; - Changed = true; - return Cond; + return true; } //===----------------------------------------------------------------------===// @@ -2331,13 +2472,14 @@ linearFunctionTestReplace(Loop *L, /// If there's a single exit block, sink any loop-invariant values that /// were defined in the preheader but not used inside the loop into the /// exit block to reduce register pressure in the loop. -void IndVarSimplify::sinkUnusedInvariants(Loop *L) { +bool IndVarSimplify::sinkUnusedInvariants(Loop *L) { BasicBlock *ExitBlock = L->getExitBlock(); - if (!ExitBlock) return; + if (!ExitBlock) return false; BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) return; + if (!Preheader) return false; + bool MadeAnyChanges = false; BasicBlock::iterator InsertPt = ExitBlock->getFirstInsertionPt(); BasicBlock::iterator I(Preheader->getTerminator()); while (I != Preheader->begin()) { @@ -2407,10 +2549,13 @@ void IndVarSimplify::sinkUnusedInvariants(Loop *L) { Done = true; } + MadeAnyChanges = true; ToMove->moveBefore(*ExitBlock, InsertPt); if (Done) break; InsertPt = ToMove->getIterator(); } + + return MadeAnyChanges; } //===----------------------------------------------------------------------===// @@ -2421,6 +2566,7 @@ bool IndVarSimplify::run(Loop *L) { // We need (and expect!) the incoming loop to be in LCSSA. 
assert(L->isRecursivelyLCSSAForm(*DT, *LI) && "LCSSA required to run indvars!"); + bool Changed = false; // If LoopSimplify form is not available, stay out of trouble. Some notes: // - LSR currently only supports LoopSimplify-form loops. Indvars' @@ -2436,7 +2582,7 @@ bool IndVarSimplify::run(Loop *L) { // If there are any floating-point recurrences, attempt to // transform them to use integer recurrences. - rewriteNonIntegerIVs(L); + Changed |= rewriteNonIntegerIVs(L); const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); @@ -2453,7 +2599,7 @@ bool IndVarSimplify::run(Loop *L) { // other expressions involving loop IVs have been evaluated. This helps SCEV // set no-wrap flags before normalizing sign/zero extension. Rewriter.disableCanonicalMode(); - simplifyAndExtend(L, Rewriter, LI); + Changed |= simplifyAndExtend(L, Rewriter, LI); // Check to see if this loop has a computable loop-invariant execution count. // If so, this means that we can compute the final value of any expressions @@ -2463,7 +2609,7 @@ bool IndVarSimplify::run(Loop *L) { // if (ReplaceExitValue != NeverRepl && !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) - rewriteLoopExitValues(L, Rewriter); + Changed |= rewriteLoopExitValues(L, Rewriter); // Eliminate redundant IV cycles. NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); @@ -2484,8 +2630,8 @@ bool IndVarSimplify::run(Loop *L) { // explicitly check any assumptions made by SCEV. Brittle. const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(BackedgeTakenCount); if (!AR || AR->getLoop()->getLoopPreheader()) - (void)linearFunctionTestReplace(L, BackedgeTakenCount, IndVar, - Rewriter); + Changed |= linearFunctionTestReplace(L, BackedgeTakenCount, IndVar, + Rewriter); } } // Clear the rewriter cache, because values that are in the rewriter's cache @@ -2498,18 +2644,18 @@ bool IndVarSimplify::run(Loop *L) { while (!DeadInsts.empty()) if (Instruction *Inst = dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val())) - RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI); + Changed |= RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI); // The Rewriter may not be used from this point on. // Loop-invariant instructions in the preheader that aren't used in the // loop may be sunk below the loop to reduce register pressure. - sinkUnusedInvariants(L); + Changed |= sinkUnusedInvariants(L); // rewriteFirstIterationLoopExitValues does not rely on the computation of // trip count and therefore can further simplify exit values in addition to // rewriteLoopExitValues. - rewriteFirstIterationLoopExitValues(L); + Changed |= rewriteFirstIterationLoopExitValues(L); // Clean up dead instructions. Changed |= DeleteDeadPHIs(L->getHeader(), TLI); diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index c5ed6d5c1b87..1c701bbee185 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -133,34 +133,16 @@ namespace { /// taken by the containing loop's induction variable. /// class InductiveRangeCheck { - // Classifies a range check - enum RangeCheckKind : unsigned { - // Range check of the form "0 <= I". - RANGE_CHECK_LOWER = 1, - - // Range check of the form "I < L" where L is known positive. - RANGE_CHECK_UPPER = 2, - - // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER - // conditions. - RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER, - - // Unrecognized range check condition. 
- RANGE_CHECK_UNKNOWN = (unsigned)-1 - }; - - static StringRef rangeCheckKindToStr(RangeCheckKind); const SCEV *Begin = nullptr; const SCEV *Step = nullptr; const SCEV *End = nullptr; Use *CheckUse = nullptr; - RangeCheckKind Kind = RANGE_CHECK_UNKNOWN; bool IsSigned = true; - static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI, - ScalarEvolution &SE, Value *&Index, - Value *&Length, bool &IsSigned); + static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, + Value *&Index, Value *&Length, + bool &IsSigned); static void extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse, @@ -175,7 +157,6 @@ public: void print(raw_ostream &OS) const { OS << "InductiveRangeCheck:\n"; - OS << " Kind: " << rangeCheckKindToStr(Kind) << "\n"; OS << " Begin: "; Begin->print(OS); OS << " Step: "; @@ -283,32 +264,11 @@ INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination", false, false) -StringRef InductiveRangeCheck::rangeCheckKindToStr( - InductiveRangeCheck::RangeCheckKind RCK) { - switch (RCK) { - case InductiveRangeCheck::RANGE_CHECK_UNKNOWN: - return "RANGE_CHECK_UNKNOWN"; - - case InductiveRangeCheck::RANGE_CHECK_UPPER: - return "RANGE_CHECK_UPPER"; - - case InductiveRangeCheck::RANGE_CHECK_LOWER: - return "RANGE_CHECK_LOWER"; - - case InductiveRangeCheck::RANGE_CHECK_BOTH: - return "RANGE_CHECK_BOTH"; - } - - llvm_unreachable("unknown range check type!"); -} - /// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot -/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set -/// `Index` and `Length` to `nullptr`. Otherwise set `Index` to the value being -/// range checked, and set `Length` to the upper limit `Index` is being range -/// checked with if (and only if) the range check type is stronger or equal to -/// RANGE_CHECK_UPPER. -InductiveRangeCheck::RangeCheckKind +/// be interpreted as a range check, return false and set `Index` and `Length` +/// to `nullptr`. Otherwise set `Index` to the value being range checked, and +/// set `Length` to the upper limit `Index` is being range checked. +bool InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, Value *&Length, bool &IsSigned) { @@ -322,7 +282,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, switch (Pred) { default: - return RANGE_CHECK_UNKNOWN; + return false; case ICmpInst::ICMP_SLE: std::swap(LHS, RHS); @@ -331,9 +291,9 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, IsSigned = true; if (match(RHS, m_ConstantInt<0>())) { Index = LHS; - return RANGE_CHECK_LOWER; + return true; // Lower. } - return RANGE_CHECK_UNKNOWN; + return false; case ICmpInst::ICMP_SLT: std::swap(LHS, RHS); @@ -342,15 +302,15 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, IsSigned = true; if (match(RHS, m_ConstantInt<-1>())) { Index = LHS; - return RANGE_CHECK_LOWER; + return true; // Lower. } if (IsLoopInvariant(LHS)) { Index = RHS; Length = LHS; - return RANGE_CHECK_UPPER; + return true; // Upper. } - return RANGE_CHECK_UNKNOWN; + return false; case ICmpInst::ICMP_ULT: std::swap(LHS, RHS); @@ -360,9 +320,9 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, if (IsLoopInvariant(LHS)) { Index = RHS; Length = LHS; - return RANGE_CHECK_BOTH; + return true; // Both lower and upper. 
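An aside on why the ICMP_ULT case just above can report both bounds at once (standard unsigned-compare reasoning, not specific to this patch): for a non-negative, loop-invariant Length, the single unsigned test also rejects negative indices, because a negative value reinterprets as a huge unsigned number. A scalar sketch with made-up names:

  #include <cstdint>

  // Equivalent of "Index u< Length" with Length known non-negative.
  bool inBounds(int32_t Index, int32_t Length) {
    // If Index < 0, static_cast<uint32_t>(Index) >= 2^31 > Length, so the
    // test fails; hence it is the same as (0 <= Index && Index < Length).
    return static_cast<uint32_t>(Index) < static_cast<uint32_t>(Length);
  }
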
} - return RANGE_CHECK_UNKNOWN; + return false; } llvm_unreachable("default clause returns!"); @@ -391,8 +351,7 @@ void InductiveRangeCheck::extractRangeChecksFromCond( Value *Length = nullptr, *Index; bool IsSigned; - auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned); - if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) + if (!parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned)) return; const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index)); @@ -408,7 +367,6 @@ void InductiveRangeCheck::extractRangeChecksFromCond( if (Length) End = SE.getSCEV(Length); else { - assert(RCKind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); // So far we can only reach this point for Signed range check. This may // change in future. In this case we will need to pick Unsigned max for the // unsigned range check. @@ -422,7 +380,6 @@ void InductiveRangeCheck::extractRangeChecksFromCond( IRC.Begin = IndexAddRec->getStart(); IRC.Step = IndexAddRec->getStepRecurrence(SE); IRC.CheckUse = &ConditionUse; - IRC.Kind = RCKind; IRC.IsSigned = IsSigned; Checks.push_back(IRC); } @@ -689,17 +646,6 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, PN->setIncomingBlock(i, ReplaceBy); } -static bool CannotBeMaxInLoop(const SCEV *BoundSCEV, Loop *L, - ScalarEvolution &SE, bool Signed) { - unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); - APInt Max = Signed ? APInt::getSignedMaxValue(BitWidth) : - APInt::getMaxValue(BitWidth); - auto Predicate = Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; - return SE.isAvailableAtLoopEntry(BoundSCEV, L) && - SE.isLoopEntryGuardedByCond(L, Predicate, BoundSCEV, - SE.getConstant(Max)); -} - /// Given a loop with an deccreasing induction variable, is it possible to /// safely calculate the bounds of a new loop using the given Predicate. static bool isSafeDecreasingBound(const SCEV *Start, @@ -795,31 +741,6 @@ static bool isSafeIncreasingBound(const SCEV *Start, SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit)); } -static bool CannotBeMinInLoop(const SCEV *BoundSCEV, Loop *L, - ScalarEvolution &SE, bool Signed) { - unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth(); - APInt Min = Signed ? APInt::getSignedMinValue(BitWidth) : - APInt::getMinValue(BitWidth); - auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; - return SE.isAvailableAtLoopEntry(BoundSCEV, L) && - SE.isLoopEntryGuardedByCond(L, Predicate, BoundSCEV, - SE.getConstant(Min)); -} - -static bool isKnownNonNegativeInLoop(const SCEV *BoundSCEV, const Loop *L, - ScalarEvolution &SE) { - const SCEV *Zero = SE.getZero(BoundSCEV->getType()); - return SE.isAvailableAtLoopEntry(BoundSCEV, L) && - SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGE, BoundSCEV, Zero); -} - -static bool isKnownNegativeInLoop(const SCEV *BoundSCEV, const Loop *L, - ScalarEvolution &SE) { - const SCEV *Zero = SE.getZero(BoundSCEV->getType()); - return SE.isAvailableAtLoopEntry(BoundSCEV, L) && - SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, BoundSCEV, Zero); -} - Optional<LoopStructure> LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo *BPI, Loop &L, @@ -977,12 +898,12 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, // ... ... 
// } } if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && - CannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) { + cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) { Pred = ICmpInst::ICMP_UGT; RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); DecreasedRightValueByOne = true; - } else if (CannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) { + } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) { Pred = ICmpInst::ICMP_SGT; RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); @@ -1042,11 +963,11 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, // ... ... // } } if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) && - CannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { + cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) { Pred = ICmpInst::ICMP_ULT; RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); IncreasedRightValueByOne = true; - } else if (CannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { + } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) { Pred = ICmpInst::ICMP_SLT; RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); IncreasedRightValueByOne = true; @@ -1339,29 +1260,20 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // EnterLoopCond - is it okay to start executing this `LS'? Value *EnterLoopCond = nullptr; - if (Increasing) - EnterLoopCond = IsSignedPredicate - ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) - : B.CreateICmpULT(LS.IndVarStart, ExitSubloopAt); - else - EnterLoopCond = IsSignedPredicate - ? B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt) - : B.CreateICmpUGT(LS.IndVarStart, ExitSubloopAt); + auto Pred = + Increasing + ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT) + : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT); + EnterLoopCond = B.CreateICmp(Pred, LS.IndVarStart, ExitSubloopAt); B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); PreheaderJump->eraseFromParent(); LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); - Value *TakeBackedgeLoopCond = nullptr; - if (Increasing) - TakeBackedgeLoopCond = IsSignedPredicate - ? B.CreateICmpSLT(LS.IndVarBase, ExitSubloopAt) - : B.CreateICmpULT(LS.IndVarBase, ExitSubloopAt); - else - TakeBackedgeLoopCond = IsSignedPredicate - ? B.CreateICmpSGT(LS.IndVarBase, ExitSubloopAt) - : B.CreateICmpUGT(LS.IndVarBase, ExitSubloopAt); + Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, LS.IndVarBase, + ExitSubloopAt); + Value *CondForBranch = LS.LatchBrExitIdx == 1 ? TakeBackedgeLoopCond : B.CreateNot(TakeBackedgeLoopCond); @@ -1373,15 +1285,7 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. - Value *IterationsLeft = nullptr; - if (Increasing) - IterationsLeft = IsSignedPredicate - ? B.CreateICmpSLT(LS.IndVarBase, LS.LoopExitAt) - : B.CreateICmpULT(LS.IndVarBase, LS.LoopExitAt); - else - IterationsLeft = IsSignedPredicate - ? 
B.CreateICmpSGT(LS.IndVarBase, LS.LoopExitAt) - : B.CreateICmpUGT(LS.IndVarBase, LS.LoopExitAt); + Value *IterationsLeft = B.CreateICmp(Pred, LS.IndVarBase, LS.LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); BranchInst *BranchToContinuation = @@ -1513,16 +1417,14 @@ bool LoopConstrainer::run() { if (Increasing) ExitPreLoopAtSCEV = *SR.LowLimit; + else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); else { - if (CannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE, - IsSignedPredicate)) - ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); - else { - LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "preloop exit limit. HighLimit = " - << *(*SR.HighLimit) << "\n"); - return false; - } + LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "preloop exit limit. HighLimit = " + << *(*SR.HighLimit) << "\n"); + return false; } if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { @@ -1542,16 +1444,14 @@ bool LoopConstrainer::run() { if (Increasing) ExitMainLoopAtSCEV = *SR.HighLimit; + else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, + IsSignedPredicate)) + ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); else { - if (CannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE, - IsSignedPredicate)) - ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); - else { - LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " - << "mainloop exit limit. LowLimit = " - << *(*SR.LowLimit) << "\n"); - return false; - } + LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "mainloop exit limit. LowLimit = " + << *(*SR.LowLimit) << "\n"); + return false; } if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index 1d66472f93c8..48de56a02834 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -25,12 +25,12 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -38,6 +38,7 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -65,6 +66,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> @@ -285,7 +287,7 @@ bool JumpThreading::runOnFunction(Function &F) { auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - DeferredDominance DDT(*DT); + DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy); std::unique_ptr<BlockFrequencyInfo> 
BFI; std::unique_ptr<BranchProbabilityInfo> BPI; bool HasProfileData = F.hasProfileData(); @@ -295,7 +297,7 @@ bool JumpThreading::runOnFunction(Function &F) { BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = Impl.runImpl(F, TLI, LVI, AA, &DDT, HasProfileData, + bool Changed = Impl.runImpl(F, TLI, LVI, AA, &DTU, HasProfileData, std::move(BFI), std::move(BPI)); if (PrintLVIAfterJumpThreading) { dbgs() << "LVI for function '" << F.getName() << "':\n"; @@ -312,7 +314,7 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); - DeferredDominance DDT(DT); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); std::unique_ptr<BlockFrequencyInfo> BFI; std::unique_ptr<BranchProbabilityInfo> BPI; @@ -322,7 +324,7 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, BFI.reset(new BlockFrequencyInfo(F, *BPI, LI)); } - bool Changed = runImpl(F, &TLI, &LVI, &AA, &DDT, HasProfileData, + bool Changed = runImpl(F, &TLI, &LVI, &AA, &DTU, HasProfileData, std::move(BFI), std::move(BPI)); if (!Changed) @@ -336,14 +338,14 @@ PreservedAnalyses JumpThreadingPass::run(Function &F, bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, LazyValueInfo *LVI_, AliasAnalysis *AA_, - DeferredDominance *DDT_, bool HasProfileData_, + DomTreeUpdater *DTU_, bool HasProfileData_, std::unique_ptr<BlockFrequencyInfo> BFI_, std::unique_ptr<BranchProbabilityInfo> BPI_) { LLVM_DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = TLI_; LVI = LVI_; AA = AA_; - DDT = DDT_; + DTU = DTU_; BFI.reset(); BPI.reset(); // When profile data is available, we need to update edge weights after @@ -360,7 +362,9 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // JumpThreading must not processes blocks unreachable from entry. It's a // waste of compute time and can potentially lead to hangs. SmallPtrSet<BasicBlock *, 16> Unreachable; - DominatorTree &DT = DDT->flush(); + assert(DTU && "DTU isn't passed into JumpThreading before using it."); + assert(DTU->hasDomTree() && "JumpThreading relies on DomTree to proceed."); + DominatorTree &DT = DTU->getDomTree(); for (auto &BB : F) if (!DT.isReachableFromEntry(&BB)) Unreachable.insert(&BB); @@ -379,7 +383,7 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // Stop processing BB if it's the entry or is now deleted. The following // routines attempt to eliminate BB and locating a suitable replacement // for the entry is non-trivial. - if (&BB == &F.getEntryBlock() || DDT->pendingDeletedBB(&BB)) + if (&BB == &F.getEntryBlock() || DTU->isBBPendingDeletion(&BB)) continue; if (pred_empty(&BB)) { @@ -390,7 +394,7 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, << '\n'); LoopHeaders.erase(&BB); LVI->eraseBlock(&BB); - DeleteDeadBlock(&BB, DDT); + DeleteDeadBlock(&BB, DTU); Changed = true; continue; } @@ -404,9 +408,9 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, // Don't alter Loop headers and latches to ensure another pass can // detect and transform nested loops later. !LoopHeaders.count(&BB) && !LoopHeaders.count(BI->getSuccessor(0)) && - TryToSimplifyUncondBranchFromEmptyBlock(&BB, DDT)) { - // BB is valid for cleanup here because we passed in DDT. F remains - // BB's parent until a DDT->flush() event. 
+ TryToSimplifyUncondBranchFromEmptyBlock(&BB, DTU)) { + // BB is valid for cleanup here because we passed in DTU. F remains + // BB's parent until a DTU->getDomTree() event. LVI->eraseBlock(&BB); Changed = true; } @@ -415,7 +419,8 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, } while (Changed); LoopHeaders.clear(); - DDT->flush(); + // Flush only the Dominator Tree. + DTU->getDomTree(); LVI->enableDT(); return EverChanged; } @@ -569,9 +574,11 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) { /// BB in the result vector. /// /// This returns true if there were any known values. -bool JumpThreadingPass::ComputeValueKnownInPredecessors( +bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl( Value *V, BasicBlock *BB, PredValueInfo &Result, - ConstantPreference Preference, Instruction *CxtI) { + ConstantPreference Preference, + DenseSet<std::pair<Value *, BasicBlock *>> &RecursionSet, + Instruction *CxtI) { // This method walks up use-def chains recursively. Because of this, we could // get into an infinite loop going around loops in the use-def chain. To // prevent this, keep track of what (value, block) pairs we've already visited @@ -579,10 +586,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( if (!RecursionSet.insert(std::make_pair(V, BB)).second) return false; - // An RAII help to remove this pair from the recursion set once the recursion - // stack pops back out again. - RecursionSetRemover remover(RecursionSet, std::make_pair(V, BB)); - // If V is a constant, then it is known in all predecessors. if (Constant *KC = getKnownConstant(V, Preference)) { for (BasicBlock *Pred : predecessors(BB)) @@ -609,7 +612,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( // "X < 4" and "X < 3" is known true but "X < 4" itself is not available. // Perhaps getConstantOnEdge should be smart enough to do this? - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -626,7 +629,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( /// If I is a PHI node, then we know the incoming values for any constants. 
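The move from a member RecursionSet plus an RAII remover to an explicit reference parameter on ComputeValueKnownInPredecessorsImpl is the usual pattern for cycle-safe recursion over use-def chains: the visited set lives for one top-level query and is threaded through every recursive call. A distilled, non-LLVM sketch with invented names (`Node`, `visitImpl`, `visit`):

#include <set>
#include <utility>

struct Node { int Value; int Block; Node *Left = nullptr; Node *Right = nullptr; };

// The visited set is passed by reference so a (value, block) pair is expanded
// at most once per top-level query, which cuts off cycles.
bool visitImpl(Node *N, std::set<std::pair<int, int>> &Visited) {
  if (!N || !Visited.insert({N->Value, N->Block}).second)
    return false;                    // already seen: stop recursing
  bool Found = (N->Value == 0);      // stand-in for "known constant here"
  Found |= visitImpl(N->Left, Visited);
  Found |= visitImpl(N->Right, Visited);
  return Found;
}

bool visit(Node *N) {
  std::set<std::pair<int, int>> Visited; // fresh set per top-level query
  return visitImpl(N, Visited);
}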
if (PHINode *PN = dyn_cast<PHINode>(I)) { - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -652,7 +655,8 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( Value *Source = CI->getOperand(0); if (!isa<PHINode>(Source) && !isa<CmpInst>(Source)) return false; - ComputeValueKnownInPredecessors(Source, BB, Result, Preference, CxtI); + ComputeValueKnownInPredecessorsImpl(Source, BB, Result, Preference, + RecursionSet, CxtI); if (Result.empty()) return false; @@ -672,10 +676,10 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( I->getOpcode() == Instruction::And) { PredValueInfoTy LHSVals, RHSVals; - ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals, - WantInteger, CxtI); - ComputeValueKnownInPredecessors(I->getOperand(1), BB, RHSVals, - WantInteger, CxtI); + ComputeValueKnownInPredecessorsImpl(I->getOperand(0), BB, LHSVals, + WantInteger, RecursionSet, CxtI); + ComputeValueKnownInPredecessorsImpl(I->getOperand(1), BB, RHSVals, + WantInteger, RecursionSet, CxtI); if (LHSVals.empty() && RHSVals.empty()) return false; @@ -710,8 +714,8 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( if (I->getOpcode() == Instruction::Xor && isa<ConstantInt>(I->getOperand(1)) && cast<ConstantInt>(I->getOperand(1))->isOne()) { - ComputeValueKnownInPredecessors(I->getOperand(0), BB, Result, - WantInteger, CxtI); + ComputeValueKnownInPredecessorsImpl(I->getOperand(0), BB, Result, + WantInteger, RecursionSet, CxtI); if (Result.empty()) return false; @@ -728,8 +732,8 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( && "A binary operator creating a block address?"); if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) { PredValueInfoTy LHSVals; - ComputeValueKnownInPredecessors(BO->getOperand(0), BB, LHSVals, - WantInteger, CxtI); + ComputeValueKnownInPredecessorsImpl(BO->getOperand(0), BB, LHSVals, + WantInteger, RecursionSet, CxtI); // Try to use constant folding to simplify the binary operator. for (const auto &LHSVal : LHSVals) { @@ -759,7 +763,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( const DataLayout &DL = PN->getModule()->getDataLayout(); // We can do this simplification if any comparisons fold to true or false. // See if any do. - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -806,7 +810,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( if (!isa<Instruction>(CmpLHS) || cast<Instruction>(CmpLHS)->getParent() != BB) { - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -838,7 +842,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) { if (!isa<Instruction>(AddLHS) || cast<Instruction>(AddLHS)->getParent() != BB) { - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -874,8 +878,8 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( // Try to find a constant value for the LHS of a comparison, // and evaluate it statically if we can. 
PredValueInfoTy LHSVals; - ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals, - WantInteger, CxtI); + ComputeValueKnownInPredecessorsImpl(I->getOperand(0), BB, LHSVals, + WantInteger, RecursionSet, CxtI); for (const auto &LHSVal : LHSVals) { Constant *V = LHSVal.first; @@ -895,8 +899,8 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( Constant *FalseVal = getKnownConstant(SI->getFalseValue(), Preference); PredValueInfoTy Conds; if ((TrueVal || FalseVal) && - ComputeValueKnownInPredecessors(SI->getCondition(), BB, Conds, - WantInteger, CxtI)) { + ComputeValueKnownInPredecessorsImpl(SI->getCondition(), BB, Conds, + WantInteger, RecursionSet, CxtI)) { for (auto &C : Conds) { Constant *Cond = C.first; @@ -923,7 +927,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( } // If all else fails, see if LVI can figure out a constant value for us. - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -942,7 +946,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( /// Since we can pick an arbitrary destination, we pick the successor with the /// fewest predecessors. This should reduce the in-degree of the others. static unsigned GetBestDestForJumpOnUndef(BasicBlock *BB) { - TerminatorInst *BBTerm = BB->getTerminator(); + Instruction *BBTerm = BB->getTerminator(); unsigned MinSucc = 0; BasicBlock *TestBB = BBTerm->getSuccessor(MinSucc); // Compute the successor with the minimum number of predecessors. @@ -974,7 +978,7 @@ static bool hasAddressTakenAndUsed(BasicBlock *BB) { bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // If the block is trivially dead, just return and let the caller nuke it. // This simplifies other transformations. - if (DDT->pendingDeletedBB(BB) || + if (DTU->isBBPendingDeletion(BB) || (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock())) return false; @@ -983,15 +987,15 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // because now the condition in this block can be threaded through // predecessors of our predecessor block. if (BasicBlock *SinglePred = BB->getSinglePredecessor()) { - const TerminatorInst *TI = SinglePred->getTerminator(); - if (!TI->isExceptional() && TI->getNumSuccessors() == 1 && + const Instruction *TI = SinglePred->getTerminator(); + if (!TI->isExceptionalTerminator() && TI->getNumSuccessors() == 1 && SinglePred != BB && !hasAddressTakenAndUsed(BB)) { // If SinglePred was a loop header, BB becomes one. if (LoopHeaders.erase(SinglePred)) LoopHeaders.insert(BB); LVI->eraseBlock(SinglePred); - MergeBasicBlockIntoOnlyPred(BB, nullptr, DDT); + MergeBasicBlockIntoOnlyPred(BB, DTU); // Now that BB is merged into SinglePred (i.e. SinglePred Code followed by // BB code within one basic block `BB`), we need to invalidate the LVI @@ -1075,7 +1079,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { std::vector<DominatorTree::UpdateType> Updates; // Fold the branch/switch. 
- TerminatorInst *BBTerm = BB->getTerminator(); + Instruction *BBTerm = BB->getTerminator(); Updates.reserve(BBTerm->getNumSuccessors()); for (unsigned i = 0, e = BBTerm->getNumSuccessors(); i != e; ++i) { if (i == BestSucc) continue; @@ -1088,7 +1092,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { << "' folding undef terminator: " << *BBTerm << '\n'); BranchInst::Create(BBTerm->getSuccessor(BestSucc), BBTerm); BBTerm->eraseFromParent(); - DDT->applyUpdates(Updates); + DTU->applyUpdates(Updates); return true; } @@ -1100,7 +1104,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { << "' folding terminator: " << *BB->getTerminator() << '\n'); ++NumFolds; - ConstantFoldTerminator(BB, true, nullptr, DDT); + ConstantFoldTerminator(BB, true, nullptr, DTU); return true; } @@ -1127,7 +1131,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // threading is concerned. assert(CondBr->isConditional() && "Threading on unconditional terminator"); - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -1156,7 +1160,7 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { ConstantInt::getFalse(CondCmp->getType()); ReplaceFoldableUses(CondCmp, CI); } - DDT->deleteEdge(BB, ToRemoveSucc); + DTU->deleteEdgeRelaxed(BB, ToRemoveSucc); return true; } @@ -1167,6 +1171,9 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { } } + if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) + TryToUnfoldSelect(SI, BB); + // Check for some cases that are worth simplifying. Right now we want to look // for loads that are used by a switch or by the condition for the branch. If // we see one, check to see if it's partially redundant. If so, insert a PHI @@ -1240,7 +1247,7 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) { RemoveSucc->removePredecessor(BB); BranchInst::Create(KeepSucc, BI); BI->eraseFromParent(); - DDT->deleteEdge(BB, RemoveSucc); + DTU->deleteEdgeRelaxed(BB, RemoveSucc); return true; } CurrentBB = CurrentPred; @@ -1296,7 +1303,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { if (IsLoadCSE) { LoadInst *NLoadI = cast<LoadInst>(AvailableVal); - combineMetadataForCSE(NLoadI, LoadI); + combineMetadataForCSE(NLoadI, LoadI, false); }; // If the returned value is the load itself, replace with an undef. This can @@ -1486,7 +1493,7 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LoadI) { } for (LoadInst *PredLoadI : CSELoads) { - combineMetadataForCSE(PredLoadI, LoadI); + combineMetadataForCSE(PredLoadI, LoadI, true); } LoadI->replaceAllUsesWith(PN); @@ -1544,7 +1551,7 @@ FindMostPopularDest(BasicBlock *BB, // successor list. if (!SamePopularity.empty()) { SamePopularity.push_back(MostPopularDest); - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); for (unsigned i = 0; ; ++i) { assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); @@ -1664,10 +1671,10 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, } // Finally update the terminator. - TerminatorInst *Term = BB->getTerminator(); + Instruction *Term = BB->getTerminator(); BranchInst::Create(OnlyDest, Term); Term->eraseFromParent(); - DDT->applyUpdates(Updates); + DTU->applyUpdates(Updates); // If the condition is now dead due to the removal of the old terminator, // erase it. 
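Throughout these hunks the DeferredDominance object is replaced by a DomTreeUpdater driven in its lazy mode. A minimal sketch of that API as it is used above, with invented names (`removeEdgeLazily`, `BB`, `Succ`) and not code from the patch:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DomTreeUpdater.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;

void removeEdgeLazily(DominatorTree &DT, BasicBlock *BB, BasicBlock *Succ) {
  // Lazy strategy: CFG updates are queued rather than applied eagerly.
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  DTU.applyUpdates({{DominatorTree::Delete, BB, Succ}});
  // Pending updates are flushed on demand, e.g. before a dominance query;
  // getDomTree() is the replacement for the old DDT->flush() idiom.
  DominatorTree &Flushed = DTU.getDomTree();
  (void)Flushed;
}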
@@ -1945,7 +1952,7 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, << "' with cost: " << JumpThreadCost << ", across block:\n " << *BB << "\n"); - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -1974,7 +1981,7 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, // Clone the non-phi instructions of BB into NewBB, keeping track of the // mapping and using it to remap operands in the cloned instructions. - for (; !isa<TerminatorInst>(BI); ++BI) { + for (; !BI->isTerminator(); ++BI) { Instruction *New = BI->clone(); New->setName(BI->getName()); NewBB->getInstList().push_back(New); @@ -2001,7 +2008,7 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, // Update the terminator of PredBB to jump to NewBB instead of BB. This // eliminates predecessors from BB, which requires us to simplify any PHI // nodes in BB. - TerminatorInst *PredTerm = PredBB->getTerminator(); + Instruction *PredTerm = PredBB->getTerminator(); for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) if (PredTerm->getSuccessor(i) == BB) { BB->removePredecessor(PredBB, true); @@ -2009,7 +2016,7 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, } // Enqueue required DT updates. - DDT->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}, + DTU->applyUpdates({{DominatorTree::Insert, NewBB, SuccBB}, {DominatorTree::Insert, PredBB, NewBB}, {DominatorTree::Delete, PredBB, BB}}); @@ -2105,12 +2112,12 @@ BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB, BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); } - DDT->applyUpdates(Updates); + DTU->applyUpdates(Updates); return NewBBs[0]; } bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { - const TerminatorInst *TI = BB->getTerminator(); + const Instruction *TI = BB->getTerminator(); assert(TI->getNumSuccessors() > 1 && "not a split"); MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); @@ -2378,12 +2385,78 @@ bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred( // Remove the unconditional branch at the end of the PredBB block. OldPredBranch->eraseFromParent(); - DDT->applyUpdates(Updates); + DTU->applyUpdates(Updates); ++NumDupes; return true; } +// Pred is a predecessor of BB with an unconditional branch to BB. SI is +// a Select instruction in Pred. BB has other predecessors and SI is used in +// a PHI node in BB. SI has no other use. +// A new basic block, NewBB, is created and SI is converted to compare and +// conditional branch. SI is erased from parent. +void JumpThreadingPass::UnfoldSelectInstr(BasicBlock *Pred, BasicBlock *BB, + SelectInst *SI, PHINode *SIUse, + unsigned Idx) { + // Expand the select. + // + // Pred -- + // | v + // | NewBB + // | | + // |----- + // v + // BB + BranchInst *PredTerm = dyn_cast<BranchInst>(Pred->getTerminator()); + BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "select.unfold", + BB->getParent(), BB); + // Move the unconditional branch to NewBB. + PredTerm->removeFromParent(); + NewBB->getInstList().insert(NewBB->end(), PredTerm); + // Create a conditional branch and update PHI nodes. + BranchInst::Create(NewBB, BB, SI->getCondition(), Pred); + SIUse->setIncomingValue(Idx, SI->getFalseValue()); + SIUse->addIncoming(SI->getTrueValue(), NewBB); + + // The select is now dead. + SI->eraseFromParent(); + DTU->applyUpdates({{DominatorTree::Insert, NewBB, BB}, + {DominatorTree::Insert, Pred, NewBB}}); + + // Update any other PHI nodes in BB. 
+ for (BasicBlock::iterator BI = BB->begin(); + PHINode *Phi = dyn_cast<PHINode>(BI); ++BI) + if (Phi != SIUse) + Phi->addIncoming(Phi->getIncomingValueForBlock(Pred), NewBB); +} + +bool JumpThreadingPass::TryToUnfoldSelect(SwitchInst *SI, BasicBlock *BB) { + PHINode *CondPHI = dyn_cast<PHINode>(SI->getCondition()); + + if (!CondPHI || CondPHI->getParent() != BB) + return false; + + for (unsigned I = 0, E = CondPHI->getNumIncomingValues(); I != E; ++I) { + BasicBlock *Pred = CondPHI->getIncomingBlock(I); + SelectInst *PredSI = dyn_cast<SelectInst>(CondPHI->getIncomingValue(I)); + + // The second and third condition can be potentially relaxed. Currently + // the conditions help to simplify the code and allow us to reuse existing + // code, developed for TryToUnfoldSelect(CmpInst *, BasicBlock *) + if (!PredSI || PredSI->getParent() != Pred || !PredSI->hasOneUse()) + continue; + + BranchInst *PredTerm = dyn_cast<BranchInst>(Pred->getTerminator()); + if (!PredTerm || !PredTerm->isUnconditional()) + continue; + + UnfoldSelectInstr(Pred, BB, PredSI, CondPHI, I); + return true; + } + return false; +} + /// TryToUnfoldSelect - Look for blocks of the form /// bb1: /// %a = select @@ -2421,7 +2494,7 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { // Now check if one of the select values would allow us to constant fold the // terminator in BB. We don't do the transform if both sides fold, those // cases will be threaded in any case. - if (DDT->pending()) + if (DTU->hasPendingDomTreeUpdates()) LVI->disableDT(); else LVI->enableDT(); @@ -2434,34 +2507,7 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { if ((LHSFolds != LazyValueInfo::Unknown || RHSFolds != LazyValueInfo::Unknown) && LHSFolds != RHSFolds) { - // Expand the select. - // - // Pred -- - // | v - // | NewBB - // | | - // |----- - // v - // BB - BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "select.unfold", - BB->getParent(), BB); - // Move the unconditional branch to NewBB. - PredTerm->removeFromParent(); - NewBB->getInstList().insert(NewBB->end(), PredTerm); - // Create a conditional branch and update PHI nodes. - BranchInst::Create(NewBB, BB, SI->getCondition(), Pred); - CondLHS->setIncomingValue(I, SI->getFalseValue()); - CondLHS->addIncoming(SI->getTrueValue(), NewBB); - // The select is now dead. - SI->eraseFromParent(); - - DDT->applyUpdates({{DominatorTree::Insert, NewBB, BB}, - {DominatorTree::Insert, Pred, NewBB}}); - // Update any other PHI nodes in BB. - for (BasicBlock::iterator BI = BB->begin(); - PHINode *Phi = dyn_cast<PHINode>(BI); ++BI) - if (Phi != CondLHS) - Phi->addIncoming(Phi->getIncomingValueForBlock(Pred), NewBB); + UnfoldSelectInstr(Pred, BB, SI, CondLHS, I); return true; } } @@ -2533,7 +2579,7 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { if (!SI) continue; // Expand the select. - TerminatorInst *Term = + Instruction *Term = SplitBlockAndInsertIfThen(SI->getCondition(), SI, false); BasicBlock *SplitBB = SI->getParent(); BasicBlock *NewBB = Term->getParent(); @@ -2548,12 +2594,12 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { Updates.push_back({DominatorTree::Insert, BB, SplitBB}); Updates.push_back({DominatorTree::Insert, BB, NewBB}); Updates.push_back({DominatorTree::Insert, NewBB, SplitBB}); - // BB's successors were moved to SplitBB, update DDT accordingly. + // BB's successors were moved to SplitBB, update DTU accordingly. 
for (auto *Succ : successors(SplitBB)) { Updates.push_back({DominatorTree::Delete, BB, Succ}); Updates.push_back({DominatorTree::Insert, SplitBB, Succ}); } - DDT->applyUpdates(Updates); + DTU->applyUpdates(Updates); return true; } return false; @@ -2603,9 +2649,8 @@ bool JumpThreadingPass::ProcessGuards(BasicBlock *BB) { if (auto *BI = dyn_cast<BranchInst>(Parent->getTerminator())) for (auto &I : *BB) - if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>())) - if (ThreadGuard(BB, cast<IntrinsicInst>(&I), BI)) - return true; + if (isGuard(&I) && ThreadGuard(BB, cast<IntrinsicInst>(&I), BI)) + return true; return false; } @@ -2651,28 +2696,16 @@ bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard, // Duplicate all instructions before the guard and the guard itself to the // branch where implication is not proved. BasicBlock *GuardedBlock = DuplicateInstructionsInSplitBetween( - BB, PredGuardedBlock, AfterGuard, GuardedMapping); + BB, PredGuardedBlock, AfterGuard, GuardedMapping, *DTU); assert(GuardedBlock && "Could not create the guarded block?"); // Duplicate all instructions before the guard in the unguarded branch. // Since we have successfully duplicated the guarded block and this block // has fewer instructions, we expect it to succeed. BasicBlock *UnguardedBlock = DuplicateInstructionsInSplitBetween( - BB, PredUnguardedBlock, Guard, UnguardedMapping); + BB, PredUnguardedBlock, Guard, UnguardedMapping, *DTU); assert(UnguardedBlock && "Could not create the unguarded block?"); LLVM_DEBUG(dbgs() << "Moved guard " << *Guard << " to block " << GuardedBlock->getName() << "\n"); - // DuplicateInstructionsInSplitBetween inserts a new block "BB.split" between - // PredBB and BB. We need to perform two inserts and one delete for each of - // the above calls to update Dominators. - DDT->applyUpdates( - {// Guarded block split. - {DominatorTree::Delete, PredGuardedBlock, BB}, - {DominatorTree::Insert, PredGuardedBlock, GuardedBlock}, - {DominatorTree::Insert, GuardedBlock, BB}, - // Unguarded block split. - {DominatorTree::Delete, PredUnguardedBlock, BB}, - {DominatorTree::Insert, PredUnguardedBlock, UnguardedBlock}, - {DominatorTree::Insert, UnguardedBlock, BB}}); // Some instructions before the guard may still have uses. For them, we need // to create Phi nodes merging their copies in both guarded and unguarded // branches. Those instructions that have no uses can be just removed. 
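The new TryToUnfoldSelect(SwitchInst *, BasicBlock *) overload extends the existing select unfolding to switches whose condition is a phi fed by a select in a predecessor. Roughly the source shape it targets, as an illustrative sketch with an invented function name and not derived from the patch's tests:

// After the select below is unfolded into its own conditional branch, every
// incoming edge of the merge block carries a known constant, so jump
// threading can route each predecessor directly to its switch case.
int classify(bool p, bool q) {
  int key;
  if (q)
    key = 2;
  else
    key = p ? 0 : 1;   // becomes a select feeding the phi the switch tests
  switch (key) {
  case 0:
    return 10;
  case 1:
    return 20;
  default:
    return 30;
  }
}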
diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index c4ea43a43249..d204654c3915 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -31,6 +31,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LICM.h" +#include "llvm/ADT/SetOperations.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" @@ -38,16 +39,18 @@ #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" @@ -58,6 +61,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/PredIteratorCache.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -65,6 +69,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -73,6 +78,8 @@ using namespace llvm; #define DEBUG_TYPE "licm" +STATISTIC(NumCreatedBlocks, "Number of blocks created"); +STATISTIC(NumClonedBranches, "Number of branches cloned"); STATISTIC(NumSunk, "Number of instructions sunk out of loop"); STATISTIC(NumHoisted, "Number of instructions hoisted out of loop"); STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); @@ -84,51 +91,81 @@ static cl::opt<bool> DisablePromotion("disable-licm-promotion", cl::Hidden, cl::init(false), cl::desc("Disable memory promotion in LICM pass")); +static cl::opt<bool> ControlFlowHoisting( + "licm-control-flow-hoisting", cl::Hidden, cl::init(false), + cl::desc("Enable control flow (and PHI) hoisting in LICM")); + static cl::opt<uint32_t> MaxNumUsesTraversed( "licm-max-num-uses-traversed", cl::Hidden, cl::init(8), cl::desc("Max num uses visited for identifying load " "invariance in loop using invariant start (default = 8)")); +// Default value of zero implies we use the regular alias set tracker mechanism +// instead of the cross product using AA to identify aliasing of the memory +// location we are interested in. +static cl::opt<int> +LICMN2Theshold("licm-n2-threshold", cl::Hidden, cl::init(0), + cl::desc("How many instructions to cross product using AA")); + +// Experimental option to allow imprecision in LICM (use MemorySSA cap) in +// pathological cases, in exchange for faster compile. This is to be removed +// if MemorySSA starts to address the same issue. This flag applies only when +// LICM uses MemorySSA instead of AliasSetTracker. When the flag is disabled +// (default), LICM calls MemorySSAWalker's getClobberingMemoryAccess, which +// gets perfect accuracy.
When flag is enabled, LICM will call into MemorySSA's +// getDefiningAccess, which may not be precise, since optimizeUses is capped. +static cl::opt<bool> EnableLicmCap( + "enable-licm-cap", cl::init(false), cl::Hidden, + cl::desc("Enable imprecision in LICM (uses MemorySSA cap) in " + "pathological cases, in exchange for faster compile")); + static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, TargetTransformInfo *TTI, bool &FreeInLoop); -static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE); +static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, + BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE); static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, - const Loop *CurLoop, LoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE, bool FreeInLoop); + const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE, + bool FreeInLoop); static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE, const Instruction *CtxI = nullptr); -static bool pointerInvalidatedByLoop(Value *V, uint64_t Size, - const AAMDNodes &AAInfo, - AliasSetTracker *CurAST); -static Instruction * -CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, - const LoopInfo *LI, - const LoopSafetyInfo *SafetyInfo); +static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, + AliasSetTracker *CurAST, Loop *CurLoop, + AliasAnalysis *AA); +static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, + Loop *CurLoop); +static Instruction *CloneInstructionInExitBlock( + Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, + const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU); + +static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, + AliasSetTracker *AST, MemorySSAUpdater *MSSAU); + +static void moveInstructionBefore(Instruction &I, Instruction &Dest, + ICFLoopSafetyInfo &SafetyInfo); namespace { struct LoopInvariantCodeMotion { + using ASTrackerMapTy = DenseMap<Loop *, std::unique_ptr<AliasSetTracker>>; bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE, MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST); - DenseMap<Loop *, AliasSetTracker *> &getLoopToAliasSetMap() { - return LoopToAliasSetMap; - } + ASTrackerMapTy &getLoopToAliasSetMap() { return LoopToAliasSetMap; } private: - DenseMap<Loop *, AliasSetTracker *> LoopToAliasSetMap; + ASTrackerMapTy LoopToAliasSetMap; - AliasSetTracker *collectAliasInfoForLoop(Loop *L, LoopInfo *LI, - AliasAnalysis *AA); + std::unique_ptr<AliasSetTracker> + collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AliasAnalysis *AA); }; struct LegacyLICMPass : public LoopPass { @@ -142,8 +179,6 @@ struct LegacyLICMPass : public LoopPass { // If we have run LICM on a previous loop but now we are skipping // (because we've hit the opt-bisect limit), we need to clear the // loop alias information. 
- for (auto &LTAS : LICM.getLoopToAliasSetMap()) - delete LTAS.second; LICM.getLoopToAliasSetMap().clear(); return false; } @@ -173,8 +208,10 @@ struct LegacyLICMPass : public LoopPass { AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<LoopInfoWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); - if (EnableMSSALoopDependency) + if (EnableMSSALoopDependency) { AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } AU.addRequired<TargetTransformInfoWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -254,14 +291,22 @@ bool LoopInvariantCodeMotion::runOnLoop( assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); - AliasSetTracker *CurAST = collectAliasInfoForLoop(L, LI, AA); + std::unique_ptr<AliasSetTracker> CurAST; + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (!MSSA) { + LLVM_DEBUG(dbgs() << "LICM: Using Alias Set Tracker.\n"); + CurAST = collectAliasInfoForLoop(L, LI, AA); + } else { + LLVM_DEBUG(dbgs() << "LICM: Using MemorySSA. Promotion disabled.\n"); + MSSAU = make_unique<MemorySSAUpdater>(MSSA); + } // Get the preheader block to move instructions into... BasicBlock *Preheader = L->getLoopPreheader(); // Compute loop safety information. - LoopSafetyInfo SafetyInfo; - computeLoopSafetyInfo(&SafetyInfo, L); + ICFLoopSafetyInfo SafetyInfo(DT); + SafetyInfo.computeLoopSafetyInfo(L); // We want to visit all of the instructions in this loop... that are not parts // of our subloops (they have already had their invariants hoisted out of
+ if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || + !L->isLoopInvariant(AS.begin()->getValue())) + continue; + + assert( + !AS.empty() && + "Must alias set should have at least one pointer element in it!"); + + SmallSetVector<Value *, 8> PointerMustAliases; + for (const auto &ASI : AS) + PointerMustAliases.insert(ASI.getValue()); + + Promoted |= promoteLoopAccessesToScalars( + PointerMustAliases, ExitBlocks, InsertPts, PIC, LI, DT, TLI, L, + CurAST.get(), &SafetyInfo, ORE); + } } + // FIXME: Promotion initially disabled when using MemorySSA. // Once we have promoted values across the loop body we have to // recursively reform LCSSA as any nested loop may now have values defined @@ -351,10 +399,11 @@ bool LoopInvariantCodeMotion::runOnLoop( // If this loop is nested inside of another one, save the alias information // for when we process the outer loop. - if (L->getParentLoop() && !DeleteAST) - LoopToAliasSetMap[L] = CurAST; - else - delete CurAST; + if (CurAST.get() && L->getParentLoop() && !DeleteAST) + LoopToAliasSetMap[L] = std::move(CurAST); + + if (MSSAU.get() && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); if (Changed && SE) SE->forgetLoopDispositions(L); @@ -369,13 +418,16 @@ bool LoopInvariantCodeMotion::runOnLoop( bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, TargetTransformInfo *TTI, Loop *CurLoop, - AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, + AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, + ICFLoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && - CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && - "Unexpected input to sinkRegion"); + CurLoop != nullptr && SafetyInfo != nullptr && + "Unexpected input to sinkRegion."); + assert(((CurAST != nullptr) ^ (MSSAU != nullptr)) && + "Either AliasSetTracker or MemorySSA should be initialized."); // We want to visit children before parents. We will enque all the parents // before their children in the worklist and process the worklist in reverse @@ -399,8 +451,7 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, LLVM_DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); salvageDebugInfo(I); ++II; - CurAST->deleteValue(&I); - I.eraseFromParent(); + eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); Changed = true; continue; } @@ -412,21 +463,252 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // bool FreeInLoop = false; if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE)) { - if (sink(I, LI, DT, CurLoop, SafetyInfo, ORE, FreeInLoop)) { + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, ORE) && + !I.mayHaveSideEffects()) { + if (sink(I, LI, DT, CurLoop, SafetyInfo, MSSAU, ORE, FreeInLoop)) { if (!FreeInLoop) { ++II; - CurAST->deleteValue(&I); - I.eraseFromParent(); + eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); } Changed = true; } } } } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); return Changed; } +namespace { +// This is a helper class for hoistRegion to make it able to hoist control flow +// in order to be able to hoist phis. The way this works is that we initially +// start hoisting to the loop preheader, and when we see a loop invariant branch +// we make note of this. 
When we then come to hoist an instruction that's + conditional on such a branch, we duplicate the branch and the relevant control + flow, then hoist the instruction into the block corresponding to its original + block in the duplicated control flow. +class ControlFlowHoister { +private: + // Information about the loop we are hoisting from + LoopInfo *LI; + DominatorTree *DT; + Loop *CurLoop; + MemorySSAUpdater *MSSAU; + + // A map of blocks in the loop to the block their instructions will be hoisted + // to. + DenseMap<BasicBlock *, BasicBlock *> HoistDestinationMap; + + // The branches that we can hoist, mapped to the block that marks a + // convergence point of their control flow. + DenseMap<BranchInst *, BasicBlock *> HoistableBranches; + +public: + ControlFlowHoister(LoopInfo *LI, DominatorTree *DT, Loop *CurLoop, + MemorySSAUpdater *MSSAU) + : LI(LI), DT(DT), CurLoop(CurLoop), MSSAU(MSSAU) {} + + void registerPossiblyHoistableBranch(BranchInst *BI) { + // We can only hoist conditional branches with loop invariant operands. + if (!ControlFlowHoisting || !BI->isConditional() || + !CurLoop->hasLoopInvariantOperands(BI)) + return; + + // The branch destinations need to be in the loop, and we don't gain + // anything by duplicating conditional branches with duplicate successors, + // as it's essentially the same as an unconditional branch. + BasicBlock *TrueDest = BI->getSuccessor(0); + BasicBlock *FalseDest = BI->getSuccessor(1); + if (!CurLoop->contains(TrueDest) || !CurLoop->contains(FalseDest) || + TrueDest == FalseDest) + return; + + // We can hoist BI if one branch destination is the successor of the other, + // or both have a common successor, which we check by seeing if the + // intersection of their successors is non-empty. + // TODO: This could be expanded to allowing branches where both ends + // eventually converge to a single block. + SmallPtrSet<BasicBlock *, 4> TrueDestSucc, FalseDestSucc; + TrueDestSucc.insert(succ_begin(TrueDest), succ_end(TrueDest)); + FalseDestSucc.insert(succ_begin(FalseDest), succ_end(FalseDest)); + BasicBlock *CommonSucc = nullptr; + if (TrueDestSucc.count(FalseDest)) { + CommonSucc = FalseDest; + } else if (FalseDestSucc.count(TrueDest)) { + CommonSucc = TrueDest; + } else { + set_intersect(TrueDestSucc, FalseDestSucc); + // If there's one common successor use that. + if (TrueDestSucc.size() == 1) + CommonSucc = *TrueDestSucc.begin(); + // If there's more than one, pick whichever appears first in the block list + // (we can't use the value returned by TrueDestSucc.begin() as it's + // unpredictable which element gets returned). + else if (!TrueDestSucc.empty()) { + Function *F = TrueDest->getParent(); + auto IsSucc = [&](BasicBlock &BB) { return TrueDestSucc.count(&BB); }; + auto It = std::find_if(F->begin(), F->end(), IsSucc); + assert(It != F->end() && "Could not find successor in function"); + CommonSucc = &*It; + } + } + // The common successor has to be dominated by the branch, as otherwise + // there will be some other path to the successor that will not be + // controlled by this branch so any phi we hoist would be controlled by the + // wrong condition. This also takes care of avoiding hoisting of loop back + // edges. + // TODO: In some cases this could be relaxed if the successor is dominated + // by another block that's been hoisted and we can guarantee that the + // control flow has been replicated exactly.
+ if (CommonSucc && DT->dominates(BI, CommonSucc)) + HoistableBranches[BI] = CommonSucc; + } + + bool canHoistPHI(PHINode *PN) { + // The phi must have loop invariant operands. + if (!ControlFlowHoisting || !CurLoop->hasLoopInvariantOperands(PN)) + return false; + // We can hoist phis if the block they are in is the target of hoistable + // branches which cover all of the predecessors of the block. + SmallPtrSet<BasicBlock *, 8> PredecessorBlocks; + BasicBlock *BB = PN->getParent(); + for (BasicBlock *PredBB : predecessors(BB)) + PredecessorBlocks.insert(PredBB); + // If we have less predecessor blocks than predecessors then the phi will + // have more than one incoming value for the same block which we can't + // handle. + // TODO: This could be handled be erasing some of the duplicate incoming + // values. + if (PredecessorBlocks.size() != pred_size(BB)) + return false; + for (auto &Pair : HoistableBranches) { + if (Pair.second == BB) { + // Which blocks are predecessors via this branch depends on if the + // branch is triangle-like or diamond-like. + if (Pair.first->getSuccessor(0) == BB) { + PredecessorBlocks.erase(Pair.first->getParent()); + PredecessorBlocks.erase(Pair.first->getSuccessor(1)); + } else if (Pair.first->getSuccessor(1) == BB) { + PredecessorBlocks.erase(Pair.first->getParent()); + PredecessorBlocks.erase(Pair.first->getSuccessor(0)); + } else { + PredecessorBlocks.erase(Pair.first->getSuccessor(0)); + PredecessorBlocks.erase(Pair.first->getSuccessor(1)); + } + } + } + // PredecessorBlocks will now be empty if for every predecessor of BB we + // found a hoistable branch source. + return PredecessorBlocks.empty(); + } + + BasicBlock *getOrCreateHoistedBlock(BasicBlock *BB) { + if (!ControlFlowHoisting) + return CurLoop->getLoopPreheader(); + // If BB has already been hoisted, return that + if (HoistDestinationMap.count(BB)) + return HoistDestinationMap[BB]; + + // Check if this block is conditional based on a pending branch + auto HasBBAsSuccessor = + [&](DenseMap<BranchInst *, BasicBlock *>::value_type &Pair) { + return BB != Pair.second && (Pair.first->getSuccessor(0) == BB || + Pair.first->getSuccessor(1) == BB); + }; + auto It = std::find_if(HoistableBranches.begin(), HoistableBranches.end(), + HasBBAsSuccessor); + + // If not involved in a pending branch, hoist to preheader + BasicBlock *InitialPreheader = CurLoop->getLoopPreheader(); + if (It == HoistableBranches.end()) { + LLVM_DEBUG(dbgs() << "LICM using " << InitialPreheader->getName() + << " as hoist destination for " << BB->getName() + << "\n"); + HoistDestinationMap[BB] = InitialPreheader; + return InitialPreheader; + } + BranchInst *BI = It->first; + assert(std::find_if(++It, HoistableBranches.end(), HasBBAsSuccessor) == + HoistableBranches.end() && + "BB is expected to be the target of at most one branch"); + + LLVMContext &C = BB->getContext(); + BasicBlock *TrueDest = BI->getSuccessor(0); + BasicBlock *FalseDest = BI->getSuccessor(1); + BasicBlock *CommonSucc = HoistableBranches[BI]; + BasicBlock *HoistTarget = getOrCreateHoistedBlock(BI->getParent()); + + // Create hoisted versions of blocks that currently don't have them + auto CreateHoistedBlock = [&](BasicBlock *Orig) { + if (HoistDestinationMap.count(Orig)) + return HoistDestinationMap[Orig]; + BasicBlock *New = + BasicBlock::Create(C, Orig->getName() + ".licm", Orig->getParent()); + HoistDestinationMap[Orig] = New; + DT->addNewBlock(New, HoistTarget); + if (CurLoop->getParentLoop()) + CurLoop->getParentLoop()->addBasicBlockToLoop(New, *LI); + 
++NumCreatedBlocks; + LLVM_DEBUG(dbgs() << "LICM created " << New->getName() + << " as hoist destination for " << Orig->getName() + << "\n"); + return New; + }; + BasicBlock *HoistTrueDest = CreateHoistedBlock(TrueDest); + BasicBlock *HoistFalseDest = CreateHoistedBlock(FalseDest); + BasicBlock *HoistCommonSucc = CreateHoistedBlock(CommonSucc); + + // Link up these blocks with branches. + if (!HoistCommonSucc->getTerminator()) { + // The new common successor we've generated will branch to whatever that + // hoist target branched to. + BasicBlock *TargetSucc = HoistTarget->getSingleSuccessor(); + assert(TargetSucc && "Expected hoist target to have a single successor"); + HoistCommonSucc->moveBefore(TargetSucc); + BranchInst::Create(TargetSucc, HoistCommonSucc); + } + if (!HoistTrueDest->getTerminator()) { + HoistTrueDest->moveBefore(HoistCommonSucc); + BranchInst::Create(HoistCommonSucc, HoistTrueDest); + } + if (!HoistFalseDest->getTerminator()) { + HoistFalseDest->moveBefore(HoistCommonSucc); + BranchInst::Create(HoistCommonSucc, HoistFalseDest); + } + + // If BI is being cloned to what was originally the preheader then + // HoistCommonSucc will now be the new preheader. + if (HoistTarget == InitialPreheader) { + // Phis in the loop header now need to use the new preheader. + InitialPreheader->replaceSuccessorsPhiUsesWith(HoistCommonSucc); + if (MSSAU) + MSSAU->wireOldPredecessorsToNewImmediatePredecessor( + HoistTarget->getSingleSuccessor(), HoistCommonSucc, {HoistTarget}); + // The new preheader dominates the loop header. + DomTreeNode *PreheaderNode = DT->getNode(HoistCommonSucc); + DomTreeNode *HeaderNode = DT->getNode(CurLoop->getHeader()); + DT->changeImmediateDominator(HeaderNode, PreheaderNode); + // The preheader hoist destination is now the new preheader, with the + // exception of the hoist destination of this branch. + for (auto &Pair : HoistDestinationMap) + if (Pair.second == InitialPreheader && Pair.first != BI->getParent()) + Pair.second = HoistCommonSucc; + } + + // Now finally clone BI. + ReplaceInstWithInst( + HoistTarget->getTerminator(), + BranchInst::Create(HoistTrueDest, HoistFalseDest, BI->getCondition())); + ++NumClonedBranches; + + assert(CurLoop->getLoopPreheader() && + "Hoisting blocks should not have destroyed preheader"); + return HoistDestinationMap[BB]; + } +}; +} // namespace + /// Walk the specified region of the CFG (defined by all blocks dominated by /// the specified block, and that are in the current loop) in depth first /// order w.r.t the DominatorTree. This allows us to visit definitions before @@ -434,30 +716,34 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, /// bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, - AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, + AliasSetTracker *CurAST, MemorySSAUpdater *MSSAU, + ICFLoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr && - CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && - "Unexpected input to hoistRegion"); - - // We want to visit parents before children. We will enque all the parents - // before their children in the worklist and process the worklist in order. 
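The kind of loop ControlFlowHoister (behind the new -licm-control-flow-hoisting flag, off by default) is aimed at is one where a loop-invariant branch feeds a phi of loop-invariant values. A hedged source-level sketch of that shape, with an invented function name and not taken from the patch's tests:

// With control flow hoisting enabled, the invariant branch on `flag`, the two
// small blocks it selects between, and the resulting phi can all be hoisted
// ahead of the loop instead of being re-evaluated on every iteration.
int scaleSum(const int *a, int n, bool flag) {
  int sum = 0;
  for (int i = 0; i < n; ++i) {
    int k;
    if (flag)        // loop-invariant condition
      k = 10;
    else
      k = 20;        // k is a phi of loop-invariant values
    sum += a[i] * k;
  }
  return sum;
}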
- SmallVector<DomTreeNode *, 16> Worklist = collectChildrenInLoop(N, CurLoop); - + CurLoop != nullptr && SafetyInfo != nullptr && + "Unexpected input to hoistRegion."); + assert(((CurAST != nullptr) ^ (MSSAU != nullptr)) && + "Either AliasSetTracker or MemorySSA should be initialized."); + + ControlFlowHoister CFH(LI, DT, CurLoop, MSSAU); + + // Keep track of instructions that have been hoisted, as they may need to be + // re-hoisted if they end up not dominating all of their uses. + SmallVector<Instruction *, 16> HoistedInstructions; + + // For PHI hoisting to work we need to hoist blocks before their successors. + // We can do this by iterating through the blocks in the loop in reverse + // post-order. + LoopBlocksRPO Worklist(CurLoop); + Worklist.perform(LI); bool Changed = false; - for (DomTreeNode *DTN : Worklist) { - BasicBlock *BB = DTN->getBlock(); + for (BasicBlock *BB : Worklist) { // Only need to process the contents of this block if it is not part of a // subloop (which would already have been processed). if (inSubLoop(BB, CurLoop, LI)) continue; - // Keep track of whether the prefix of instructions visited so far are such - // that the next instruction visited is guaranteed to execute if the loop - // is entered. - bool IsMustExecute = CurLoop->getHeader() == BB; - for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { Instruction &I = *II++; // Try constant folding this instruction. If all the operands are @@ -467,12 +753,12 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, &I, I.getModule()->getDataLayout(), TLI)) { LLVM_DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C << '\n'); - CurAST->copyValue(&I, C); + if (CurAST) + CurAST->copyValue(&I, C); + // FIXME MSSA: Such replacements may make accesses unoptimized (D51960). I.replaceAllUsesWith(C); - if (isInstructionTriviallyDead(&I, TLI)) { - CurAST->deleteValue(&I); - I.eraseFromParent(); - } + if (isInstructionTriviallyDead(&I, TLI)) + eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); Changed = true; continue; } @@ -480,14 +766,18 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, // Try hoisting the instruction out to the preheader. We can only do // this if all of the operands of the instruction are loop invariant and // if it is safe to hoist the instruction. - // + // TODO: It may be safe to hoist if we are hoisting to a conditional block + // and we have accurately duplicated the control flow from the loop header + // to that block. 
if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) && - (IsMustExecute || - isSafeToExecuteUnconditionally( - I, DT, CurLoop, SafetyInfo, ORE, - CurLoop->getLoopPreheader()->getTerminator()))) { - Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE); + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, MSSAU, true, ORE) && + isSafeToExecuteUnconditionally( + I, DT, CurLoop, SafetyInfo, ORE, + CurLoop->getLoopPreheader()->getTerminator())) { + hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, + MSSAU, ORE); + HoistedInstructions.push_back(&I); + Changed = true; continue; } @@ -500,24 +790,101 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); + SafetyInfo->insertInstructionTo(ReciprocalDivisor, I.getParent()); ReciprocalDivisor->insertBefore(&I); auto Product = BinaryOperator::CreateFMul(I.getOperand(0), ReciprocalDivisor); Product->setFastMathFlags(I.getFastMathFlags()); + SafetyInfo->insertInstructionTo(Product, I.getParent()); Product->insertAfter(&I); I.replaceAllUsesWith(Product); - I.eraseFromParent(); + eraseInstruction(I, *SafetyInfo, CurAST, MSSAU); - hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); + hoist(*ReciprocalDivisor, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), + SafetyInfo, MSSAU, ORE); + HoistedInstructions.push_back(ReciprocalDivisor); Changed = true; continue; } - if (IsMustExecute) - IsMustExecute = isGuaranteedToTransferExecutionToSuccessor(&I); + using namespace PatternMatch; + if (((I.use_empty() && + match(&I, m_Intrinsic<Intrinsic::invariant_start>())) || + isGuard(&I)) && + CurLoop->hasLoopInvariantOperands(&I) && + SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop) && + SafetyInfo->doesNotWriteMemoryBefore(I, CurLoop)) { + hoist(I, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, + MSSAU, ORE); + HoistedInstructions.push_back(&I); + Changed = true; + continue; + } + + if (PHINode *PN = dyn_cast<PHINode>(&I)) { + if (CFH.canHoistPHI(PN)) { + // Redirect incoming blocks first to ensure that we create hoisted + // versions of those blocks before we hoist the phi. + for (unsigned int i = 0; i < PN->getNumIncomingValues(); ++i) + PN->setIncomingBlock( + i, CFH.getOrCreateHoistedBlock(PN->getIncomingBlock(i))); + hoist(*PN, DT, CurLoop, CFH.getOrCreateHoistedBlock(BB), SafetyInfo, + MSSAU, ORE); + assert(DT->dominates(PN, BB) && "Conditional PHIs not expected"); + Changed = true; + continue; + } + } + + // Remember possibly hoistable branches so we can actually hoist them + // later if needed. + if (BranchInst *BI = dyn_cast<BranchInst>(&I)) + CFH.registerPossiblyHoistableBranch(BI); + } + } + + // If we hoisted instructions to a conditional block they may not dominate + // their uses that weren't hoisted (such as phis where some operands are not + // loop invariant). If so make them unconditional by moving them to their + // immediate dominator. We iterate through the instructions in reverse order + // which ensures that when we rehoist an instruction we rehoist its operands, + // and also keep track of where in the block we are rehoisting to to make sure + // that we rehoist instructions before the instructions that use them. 
+ Instruction *HoistPoint = nullptr; + if (ControlFlowHoisting) { + for (Instruction *I : reverse(HoistedInstructions)) { + if (!llvm::all_of(I->uses(), + [&](Use &U) { return DT->dominates(I, U); })) { + BasicBlock *Dominator = + DT->getNode(I->getParent())->getIDom()->getBlock(); + if (!HoistPoint || !DT->dominates(HoistPoint->getParent(), Dominator)) { + if (HoistPoint) + assert(DT->dominates(Dominator, HoistPoint->getParent()) && + "New hoist point expected to dominate old hoist point"); + HoistPoint = Dominator->getTerminator(); + } + LLVM_DEBUG(dbgs() << "LICM rehoisting to " + << HoistPoint->getParent()->getName() + << ": " << *I << "\n"); + moveInstructionBefore(*I, *HoistPoint, *SafetyInfo); + HoistPoint = I; + Changed = true; + } } } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + + // Now that we've finished hoisting make sure that LI and DT are still + // valid. +#ifndef NDEBUG + if (Changed) { + assert(DT->verify(DominatorTree::VerificationLevel::Fast) && + "Dominator tree verification failed"); + LI->verify(*DT); + } +#endif return Changed; } @@ -575,13 +942,68 @@ static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, return false; } +namespace { +/// Return true if-and-only-if we know how to (mechanically) both hoist and +/// sink a given instruction out of a loop. Does not address legality +/// concerns such as aliasing or speculation safety. +bool isHoistableAndSinkableInst(Instruction &I) { + // Only these instructions are hoistable/sinkable. + return (isa<LoadInst>(I) || isa<StoreInst>(I) || + isa<CallInst>(I) || isa<FenceInst>(I) || + isa<BinaryOperator>(I) || isa<CastInst>(I) || + isa<SelectInst>(I) || isa<GetElementPtrInst>(I) || + isa<CmpInst>(I) || isa<InsertElementInst>(I) || + isa<ExtractElementInst>(I) || isa<ShuffleVectorInst>(I) || + isa<ExtractValueInst>(I) || isa<InsertValueInst>(I)); +} +/// Return true if all of the alias sets within this AST are known not to +/// contain a Mod, or if MSSA knows thare are no MemoryDefs in the loop. +bool isReadOnly(AliasSetTracker *CurAST, const MemorySSAUpdater *MSSAU, + const Loop *L) { + if (CurAST) { + for (AliasSet &AS : *CurAST) { + if (!AS.isForwardingAliasSet() && AS.isMod()) { + return false; + } + } + return true; + } else { /*MSSAU*/ + for (auto *BB : L->getBlocks()) + if (MSSAU->getMemorySSA()->getBlockDefs(BB)) + return false; + return true; + } +} + +/// Return true if I is the only Instruction with a MemoryAccess in L. +bool isOnlyMemoryAccess(const Instruction *I, const Loop *L, + const MemorySSAUpdater *MSSAU) { + for (auto *BB : L->getBlocks()) + if (auto *Accs = MSSAU->getMemorySSA()->getBlockAccesses(BB)) { + int NotAPhi = 0; + for (const auto &Acc : *Accs) { + if (isa<MemoryPhi>(&Acc)) + continue; + const auto *MUD = cast<MemoryUseOrDef>(&Acc); + if (MUD->getMemoryInst() != I || NotAPhi++ == 1) + return false; + } + } + return true; +} +} + bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop, AliasSetTracker *CurAST, - LoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU, + bool TargetExecutesOncePerLoop, OptimizationRemarkEmitter *ORE) { - // SafetyInfo is nullptr if we are checking for sinking from preheader to - // loop body. - const bool SinkingToLoopBody = !SafetyInfo; + // If we don't understand the instruction, bail early. + if (!isHoistableAndSinkableInst(I)) + return false; + + MemorySSA *MSSA = MSSAU ? 
MSSAU->getMemorySSA() : nullptr; + // Loads have extra constraints we have to verify before we can hoist them. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { if (!LI->isUnordered()) @@ -594,23 +1016,20 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (LI->getMetadata(LLVMContext::MD_invariant_load)) return true; - if (LI->isAtomic() && SinkingToLoopBody) - return false; // Don't sink unordered atomic loads to loop body. + if (LI->isAtomic() && !TargetExecutesOncePerLoop) + return false; // Don't risk duplicating unordered loads // This checks for an invariant.start dominating the load. if (isLoadInvariantInLoop(LI, DT, CurLoop)) return true; - // Don't hoist loads which have may-aliased stores in loop. - uint64_t Size = 0; - if (LI->getType()->isSized()) - Size = I.getModule()->getDataLayout().getTypeStoreSize(LI->getType()); - - AAMDNodes AAInfo; - LI->getAAMetadata(AAInfo); - - bool Invalidated = - pointerInvalidatedByLoop(LI->getOperand(0), Size, AAInfo, CurAST); + bool Invalidated; + if (CurAST) + Invalidated = pointerInvalidatedByLoop(MemoryLocation::get(LI), CurAST, + CurLoop, AA); + else + Invalidated = pointerInvalidatedByLoopWithMSSA( + MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(LI)), CurLoop); // Check loop-invariant address because this may also be a sinkable load // whose address is not necessarily loop-invariant. if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) @@ -631,6 +1050,11 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (CI->mayThrow()) return false; + using namespace PatternMatch; + if (match(CI, m_Intrinsic<Intrinsic::assume>())) + // Assumes don't actually alias anything or throw + return true; + // Handle simple cases by querying alias analysis. FunctionModRefBehavior Behavior = AA->getModRefBehavior(CI); if (Behavior == FMRB_DoesNotAccessMemory) @@ -640,23 +1064,26 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // it's arguments with arbitrary offsets. If we can prove there are no // writes to this memory in the loop, we can hoist or sink. if (AliasAnalysis::onlyAccessesArgPointees(Behavior)) { + // TODO: expand to writeable arguments for (Value *Op : CI->arg_operands()) - if (Op->getType()->isPointerTy() && - pointerInvalidatedByLoop(Op, MemoryLocation::UnknownSize, - AAMDNodes(), CurAST)) - return false; + if (Op->getType()->isPointerTy()) { + bool Invalidated; + if (CurAST) + Invalidated = pointerInvalidatedByLoop( + MemoryLocation(Op, LocationSize::unknown(), AAMDNodes()), + CurAST, CurLoop, AA); + else + Invalidated = pointerInvalidatedByLoopWithMSSA( + MSSA, cast<MemoryUse>(MSSA->getMemoryAccess(CI)), CurLoop); + if (Invalidated) + return false; + } return true; } + // If this call only reads from memory and there are no writes to memory // in the loop, we can hoist or sink the call as appropriate. - bool FoundMod = false; - for (AliasSet &AS : *CurAST) { - if (!AS.isForwardingAliasSet() && AS.isMod()) { - FoundMod = true; - break; - } - } - if (!FoundMod) + if (isReadOnly(CurAST, MSSAU, CurLoop)) return true; } @@ -664,25 +1091,63 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // sink the call. return false; + } else if (auto *FI = dyn_cast<FenceInst>(&I)) { + // Fences alias (most) everything to provide ordering. For the moment, + // just give up if there are any other memory operations in the loop. 
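The read-only call case handled earlier in this hunk (readonly call plus the new isReadOnly helper) corresponds to a familiar source pattern; a hand-written sketch with made-up names:

#include <cstring>

// The call only reads memory, its argument is loop-invariant, and the loop
// contains no stores, so once isReadOnly() agrees there are no Mod sets /
// MemoryDefs in the loop, the call can be evaluated once outside it.
unsigned count_at_least(const char *p, const unsigned *len, int n) {
  unsigned hits = 0;
  for (int i = 0; i < n; ++i)
    if (len[i] >= std::strlen(p))   // read-only libcall, invariant argument
      ++hits;
  return hits;
}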
+ if (CurAST) { + auto Begin = CurAST->begin(); + assert(Begin != CurAST->end() && "must contain FI"); + if (std::next(Begin) != CurAST->end()) + // constant memory for instance, TODO: handle better + return false; + auto *UniqueI = Begin->getUniqueInstruction(); + if (!UniqueI) + // other memory op, give up + return false; + (void)FI; // suppress unused variable warning + assert(UniqueI == FI && "AS must contain FI"); + return true; + } else // MSSAU + return isOnlyMemoryAccess(FI, CurLoop, MSSAU); + } else if (auto *SI = dyn_cast<StoreInst>(&I)) { + if (!SI->isUnordered()) + return false; // Don't sink/hoist volatile or ordered atomic store! + + // We can only hoist a store that we can prove writes a value which is not + // read or overwritten within the loop. For those cases, we fallback to + // load store promotion instead. TODO: We can extend this to cases where + // there is exactly one write to the location and that write dominates an + // arbitrary number of reads in the loop. + if (CurAST) { + auto &AS = CurAST->getAliasSetFor(MemoryLocation::get(SI)); + + if (AS.isRef() || !AS.isMustAlias()) + // Quick exit test, handled by the full path below as well. + return false; + auto *UniqueI = AS.getUniqueInstruction(); + if (!UniqueI) + // other memory op, give up + return false; + assert(UniqueI == SI && "AS must contain SI"); + return true; + } else { // MSSAU + if (isOnlyMemoryAccess(SI, CurLoop, MSSAU)) + return true; + if (!EnableLicmCap) { + auto *Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(SI); + if (MSSA->isLiveOnEntryDef(Source) || + !CurLoop->contains(Source->getBlock())) + return true; + } + return false; + } } - // Only these instructions are hoistable/sinkable. - if (!isa<BinaryOperator>(I) && !isa<CastInst>(I) && !isa<SelectInst>(I) && - !isa<GetElementPtrInst>(I) && !isa<CmpInst>(I) && - !isa<InsertElementInst>(I) && !isa<ExtractElementInst>(I) && - !isa<ShuffleVectorInst>(I) && !isa<ExtractValueInst>(I) && - !isa<InsertValueInst>(I)) - return false; - - // If we are checking for sinking from preheader to loop body it will be - // always safe as there is no speculative execution. - if (SinkingToLoopBody) - return true; + assert(!I.mayReadOrWriteMemory() && "unhandled aliasing"); - // TODO: Plumb the context instruction through to make hoisting and sinking - // more powerful. Hoisting of loads already works due to the special casing - // above. 
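A sketch of the new StoreInst path above (names are made up, and __restrict is a Clang/GCC extension used so the no-alias claim is actually provable):

// The store writes a loop-invariant value to a loop-invariant address and
// nothing else in the loop reads or writes that location, so as far as
// aliasing goes a single store outside the loop is equivalent; execution
// safety is still the caller's problem, as the final comment above notes.
void add_and_flag(int *__restrict a, int n, int c, int *__restrict done) {
  for (int i = 0; i < n; ++i) {
    a[i] += c;
    *done = 1;   // candidate for hoisting/sinking out of the loop
  }
}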
- return isSafeToExecuteUnconditionally(I, DT, CurLoop, SafetyInfo, nullptr); + // We've established mechanical ability and aliasing, it's up to the caller + // to check fault safety + return true; } /// Returns true if a PHINode is a trivially replaceable with an @@ -730,7 +1195,7 @@ static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, TargetTransformInfo *TTI, bool &FreeInLoop) { - const auto &BlockColors = SafetyInfo->BlockColors; + const auto &BlockColors = SafetyInfo->getBlockColors(); bool IsFree = isFreeInLoop(I, CurLoop, TTI); for (const User *U : I.users()) { const Instruction *UI = cast<Instruction>(U); @@ -759,13 +1224,12 @@ static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, return true; } -static Instruction * -CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, - const LoopInfo *LI, - const LoopSafetyInfo *SafetyInfo) { +static Instruction *CloneInstructionInExitBlock( + Instruction &I, BasicBlock &ExitBlock, PHINode &PN, const LoopInfo *LI, + const LoopSafetyInfo *SafetyInfo, MemorySSAUpdater *MSSAU) { Instruction *New; if (auto *CI = dyn_cast<CallInst>(&I)) { - const auto &BlockColors = SafetyInfo->BlockColors; + const auto &BlockColors = SafetyInfo->getBlockColors(); // Sinking call-sites need to be handled differently from other // instructions. The cloned call-site needs a funclet bundle operand @@ -798,6 +1262,21 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, if (!I.getName().empty()) New->setName(I.getName() + ".le"); + MemoryAccess *OldMemAcc; + if (MSSAU && (OldMemAcc = MSSAU->getMemorySSA()->getMemoryAccess(&I))) { + // Create a new MemoryAccess and let MemorySSA set its defining access. + MemoryAccess *NewMemAcc = MSSAU->createMemoryAccessInBB( + New, nullptr, New->getParent(), MemorySSA::Beginning); + if (NewMemAcc) { + if (auto *MemDef = dyn_cast<MemoryDef>(NewMemAcc)) + MSSAU->insertDef(MemDef, /*RenameUses=*/true); + else { + auto *MemUse = cast<MemoryUse>(NewMemAcc); + MSSAU->insertUse(MemUse); + } + } + } + // Build LCSSA PHI nodes for any in-loop operands. Note that this is // particularly cheap because we can rip off the PHI node that we're // replacing for the number and blocks of the predecessors. 
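The cloning machinery in CloneInstructionInExitBlock serves the sinking case; a hand-written sketch with illustrative names:

// The division result is only consumed after the loop, so a clone of the
// instruction can be placed in the exit block next to the LCSSA PHI
// instead of executing on every iteration.
int last_ratio(const int *a, const int *b, int n) {
  int r = 0;
  for (int i = 0; i < n; ++i)
    r = a[i] / b[i];   // value only used once the loop is done
  return r;
}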
@@ -820,10 +1299,28 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, return New; } +static void eraseInstruction(Instruction &I, ICFLoopSafetyInfo &SafetyInfo, + AliasSetTracker *AST, MemorySSAUpdater *MSSAU) { + if (AST) + AST->deleteValue(&I); + if (MSSAU) + MSSAU->removeMemoryAccess(&I); + SafetyInfo.removeInstruction(&I); + I.eraseFromParent(); +} + +static void moveInstructionBefore(Instruction &I, Instruction &Dest, + ICFLoopSafetyInfo &SafetyInfo) { + SafetyInfo.removeInstruction(&I); + SafetyInfo.insertInstructionTo(&I, Dest.getParent()); + I.moveBefore(&Dest); +} + static Instruction *sinkThroughTriviallyReplaceablePHI( PHINode *TPN, Instruction *I, LoopInfo *LI, SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies, - const LoopSafetyInfo *SafetyInfo, const Loop *CurLoop) { + const LoopSafetyInfo *SafetyInfo, const Loop *CurLoop, + MemorySSAUpdater *MSSAU) { assert(isTriviallyReplaceablePHI(*TPN, *I) && "Expect only trivially replaceable PHI"); BasicBlock *ExitBlock = TPN->getParent(); @@ -832,8 +1329,8 @@ static Instruction *sinkThroughTriviallyReplaceablePHI( if (It != SunkCopies.end()) New = It->second; else - New = SunkCopies[ExitBlock] = - CloneInstructionInExitBlock(*I, *ExitBlock, *TPN, LI, SafetyInfo); + New = SunkCopies[ExitBlock] = CloneInstructionInExitBlock( + *I, *ExitBlock, *TPN, LI, SafetyInfo, MSSAU); return New; } @@ -845,7 +1342,7 @@ static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) { // it require updating BlockColors for all offspring blocks accordingly. By // skipping such corner case, we can make updating BlockColors after splitting // predecessor fairly simple. - if (!SafetyInfo->BlockColors.empty() && BB->getFirstNonPHI()->isEHPad()) + if (!SafetyInfo->getBlockColors().empty() && BB->getFirstNonPHI()->isEHPad()) return false; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *BBPred = *PI; @@ -857,7 +1354,8 @@ static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) { static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, LoopInfo *LI, const Loop *CurLoop, - LoopSafetyInfo *SafetyInfo) { + LoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU) { #ifndef NDEBUG SmallVector<BasicBlock *, 32> ExitBlocks; CurLoop->getUniqueExitBlocks(ExitBlocks); @@ -899,7 +1397,7 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, // LE: // %p = phi [%p1, %LE.split], [%p2, %LE.split2] // - auto &BlockColors = SafetyInfo->BlockColors; + const auto &BlockColors = SafetyInfo->getBlockColors(); SmallSetVector<BasicBlock *, 8> PredBBs(pred_begin(ExitBB), pred_end(ExitBB)); while (!PredBBs.empty()) { BasicBlock *PredBB = *PredBBs.begin(); @@ -907,18 +1405,15 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, "Expect all predecessors are in the loop"); if (PN->getBasicBlockIndex(PredBB) >= 0) { BasicBlock *NewPred = SplitBlockPredecessors( - ExitBB, PredBB, ".split.loop.exit", DT, LI, true); + ExitBB, PredBB, ".split.loop.exit", DT, LI, MSSAU, true); // Since we do not allow splitting EH-block with BlockColors in // canSplitPredecessors(), we can simply assign predecessor's color to // the new block. - if (!BlockColors.empty()) { + if (!BlockColors.empty()) // Grab a reference to the ColorVector to be inserted before getting the // reference to the vector we are copying because inserting the new // element in BlockColors might cause the map to be reallocated. 
- ColorVector &ColorsForNewBlock = BlockColors[NewPred]; - ColorVector &ColorsForOldBlock = BlockColors[PredBB]; - ColorsForNewBlock = ColorsForOldBlock; - } + SafetyInfo->copyColors(NewPred, PredBB); } PredBBs.remove(PredBB); } @@ -930,8 +1425,9 @@ static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, /// position, and may either delete it or move it to outside of the loop. /// static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, - const Loop *CurLoop, LoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE, bool FreeInLoop) { + const Loop *CurLoop, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE, + bool FreeInLoop) { LLVM_DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "InstSunk", &I) @@ -983,7 +1479,7 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, // Split predecessors of the PHI so that we can make users trivially // replaceable. - splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop, SafetyInfo); + splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop, SafetyInfo, MSSAU); // Should rebuild the iterators, as they may be invalidated by // splitPredecessorsOfLoopExit(). @@ -1018,10 +1514,10 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, assert(ExitBlockSet.count(PN->getParent()) && "The LCSSA PHI is not in an exit block!"); // The PHI must be trivially replaceable. - Instruction *New = sinkThroughTriviallyReplaceablePHI(PN, &I, LI, SunkCopies, - SafetyInfo, CurLoop); + Instruction *New = sinkThroughTriviallyReplaceablePHI( + PN, &I, LI, SunkCopies, SafetyInfo, CurLoop, MSSAU); PN->replaceAllUsesWith(New); - PN->eraseFromParent(); + eraseInstruction(*PN, *SafetyInfo, nullptr, nullptr); Changed = true; } return Changed; @@ -1030,11 +1526,10 @@ static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, /// When an instruction is found to only use loop invariant operands that /// is safe to hoist, this instruction is called to do the dirty work. /// -static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE) { - auto *Preheader = CurLoop->getLoopPreheader(); - LLVM_DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I +static void hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, + BasicBlock *Dest, ICFLoopSafetyInfo *SafetyInfo, + MemorySSAUpdater *MSSAU, OptimizationRemarkEmitter *ORE) { + LLVM_DEBUG(dbgs() << "LICM hoisting to " << Dest->getName() << ": " << I << "\n"); ORE->emit([&]() { return OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) << "hoisting " @@ -1049,11 +1544,22 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, // The check on hasMetadataOtherThanDebugLoc is to prevent us from burning // time in isGuaranteedToExecute if we don't actually have anything to // drop. It is a compile time optimization, not required for correctness. - !isGuaranteedToExecute(I, DT, CurLoop, SafetyInfo)) + !SafetyInfo->isGuaranteedToExecute(I, DT, CurLoop)) I.dropUnknownNonDebugMetadata(); - // Move the new node to the Preheader, before its terminator. - I.moveBefore(Preheader->getTerminator()); + if (isa<PHINode>(I)) + // Move the new node to the end of the phi list in the destination block. + moveInstructionBefore(I, *Dest->getFirstNonPHI(), *SafetyInfo); + else + // Move the new node to the destination block, before its terminator. 
+ moveInstructionBefore(I, *Dest->getTerminator(), *SafetyInfo); + if (MSSAU) { + // If moving, I just moved a load or store, so update MemorySSA. + MemoryUseOrDef *OldMemAcc = cast_or_null<MemoryUseOrDef>( + MSSAU->getMemorySSA()->getMemoryAccess(&I)); + if (OldMemAcc) + MSSAU->moveToPlace(OldMemAcc, Dest, MemorySSA::End); + } // Do not retain debug locations when we are moving instructions to different // basic blocks, because we want to avoid jumpy line tables. Calls, however, @@ -1068,7 +1574,6 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, else if (isa<CallInst>(I)) ++NumMovedCalls; ++NumHoisted; - return true; } /// Only sink or hoist an instruction if it is not a trapping instruction, @@ -1084,7 +1589,7 @@ static bool isSafeToExecuteUnconditionally(Instruction &Inst, return true; bool GuaranteedToExecute = - isGuaranteedToExecute(Inst, DT, CurLoop, SafetyInfo); + SafetyInfo->isGuaranteedToExecute(Inst, DT, CurLoop); if (!GuaranteedToExecute) { auto *LI = dyn_cast<LoadInst>(&Inst); @@ -1113,6 +1618,7 @@ class LoopPromoter : public LoadAndStorePromoter { int Alignment; bool UnorderedAtomic; AAMDNodes AATags; + ICFLoopSafetyInfo &SafetyInfo; Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -1135,11 +1641,13 @@ public: SmallVectorImpl<BasicBlock *> &LEB, SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC, AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, - bool UnorderedAtomic, const AAMDNodes &AATags) + bool UnorderedAtomic, const AAMDNodes &AATags, + ICFLoopSafetyInfo &SafetyInfo) : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), LI(li), DL(std::move(dl)), Alignment(alignment), - UnorderedAtomic(UnorderedAtomic), AATags(AATags) {} + UnorderedAtomic(UnorderedAtomic), AATags(AATags), SafetyInfo(SafetyInfo) + {} bool isInstInList(Instruction *I, const SmallVectorImpl<Instruction *> &) const override { @@ -1176,7 +1684,10 @@ public: // Update alias analysis. AST.copyValue(LI, V); } - void instructionDeleted(Instruction *I) const override { AST.deleteValue(I); } + void instructionDeleted(Instruction *I) const override { + SafetyInfo.removeInstruction(I); + AST.deleteValue(I); + } }; @@ -1214,7 +1725,7 @@ bool llvm::promoteLoopAccessesToScalars( SmallVectorImpl<BasicBlock *> &ExitBlocks, SmallVectorImpl<Instruction *> &InsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, - Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, + Loop *CurLoop, AliasSetTracker *CurAST, ICFLoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { // Verify inputs. assert(LI != nullptr && DT != nullptr && CurLoop != nullptr && @@ -1277,7 +1788,7 @@ bool llvm::promoteLoopAccessesToScalars( const DataLayout &MDL = Preheader->getModule()->getDataLayout(); bool IsKnownThreadLocalObject = false; - if (SafetyInfo->MayThrow) { + if (SafetyInfo->anyBlockMayThrow()) { // If a loop can throw, we have to insert a store along each unwind edge. // That said, we can't actually make the unwind edge explicit. Therefore, // we have to prove that the store is dead along the unwind edge. We do @@ -1310,7 +1821,6 @@ bool llvm::promoteLoopAccessesToScalars( // If there is an non-load/store instruction in the loop, we can't promote // it. 
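For context on promoteLoopAccessesToScalars, the textbook promotion shape at the source level (illustrative names; __restrict is used so the must-alias set holds only the *sum accesses):

// Every iteration loads and stores *sum through a loop-invariant pointer,
// so the accesses can be promoted to a register, with one load in the
// preheader and one store placed on the loop exits.
void accumulate(const int *__restrict a, int n, int *__restrict sum) {
  for (int i = 0; i < n; ++i)
    *sum += a[i];   // becomes: tmp += a[i]; with *sum = tmp on loop exit
}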
if (LoadInst *Load = dyn_cast<LoadInst>(UI)) { - assert(!Load->isVolatile() && "AST broken"); if (!Load->isUnordered()) return false; @@ -1325,7 +1835,6 @@ bool llvm::promoteLoopAccessesToScalars( // pointer. if (UI->getOperand(1) != ASIV) continue; - assert(!Store->isVolatile() && "AST broken"); if (!Store->isUnordered()) return false; @@ -1344,7 +1853,7 @@ bool llvm::promoteLoopAccessesToScalars( if (!DereferenceableInPH || !SafeToInsertStore || (InstAlignment > Alignment)) { - if (isGuaranteedToExecute(*UI, DT, CurLoop, SafetyInfo)) { + if (SafetyInfo->isGuaranteedToExecute(*UI, DT, CurLoop)) { DereferenceableInPH = true; SafeToInsertStore = true; Alignment = std::max(Alignment, InstAlignment); @@ -1435,7 +1944,7 @@ bool llvm::promoteLoopAccessesToScalars( SSAUpdater SSA(&NewPHIs); LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks, InsertPts, PIC, *CurAST, *LI, DL, Alignment, - SawUnorderedAtomic, AATags); + SawUnorderedAtomic, AATags, *SafetyInfo); // Set up the preheader to have a definition of the value. It is the live-out // value from the preheader that uses in the loop will use. @@ -1455,7 +1964,7 @@ bool llvm::promoteLoopAccessesToScalars( // If the SSAUpdater didn't use the load in the preheader, just zap it now. if (PreheaderLoad->use_empty()) - PreheaderLoad->eraseFromParent(); + eraseInstruction(*PreheaderLoad, *SafetyInfo, CurAST, nullptr); return true; } @@ -1466,10 +1975,10 @@ bool llvm::promoteLoopAccessesToScalars( /// analysis such as cloneBasicBlockAnalysis, so the AST needs to be recomputed /// from scratch for every loop. Hook up with the helper functions when /// available in the new pass manager to avoid redundant computation. -AliasSetTracker * +std::unique_ptr<AliasSetTracker> LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, AliasAnalysis *AA) { - AliasSetTracker *CurAST = nullptr; + std::unique_ptr<AliasSetTracker> CurAST; SmallVector<Loop *, 4> RecomputeLoops; for (Loop *InnerL : L->getSubLoops()) { auto MapI = LoopToAliasSetMap.find(InnerL); @@ -1480,35 +1989,30 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, RecomputeLoops.push_back(InnerL); continue; } - AliasSetTracker *InnerAST = MapI->second; + std::unique_ptr<AliasSetTracker> InnerAST = std::move(MapI->second); - if (CurAST != nullptr) { + if (CurAST) { // What if InnerLoop was modified by other passes ? - CurAST->add(*InnerAST); - // Once we've incorporated the inner loop's AST into ours, we don't need // the subloop's anymore. - delete InnerAST; + CurAST->add(*InnerAST); } else { - CurAST = InnerAST; + CurAST = std::move(InnerAST); } LoopToAliasSetMap.erase(MapI); } - if (CurAST == nullptr) - CurAST = new AliasSetTracker(*AA); - - auto mergeLoop = [&](Loop *L) { - // Loop over the body of this loop, looking for calls, invokes, and stores. - for (BasicBlock *BB : L->blocks()) - CurAST->add(*BB); // Incorporate the specified basic block - }; + if (!CurAST) + CurAST = make_unique<AliasSetTracker>(*AA); // Add everything from the sub loops that are no longer directly available. for (Loop *InnerL : RecomputeLoops) - mergeLoop(InnerL); + for (BasicBlock *BB : InnerL->blocks()) + CurAST->add(*BB); - // And merge in this loop. - mergeLoop(L); + // And merge in this loop (without anything from inner loops). 
+ for (BasicBlock *BB : L->blocks()) + if (LI->getLoopFor(BB) == L) + CurAST->add(*BB); return CurAST; } @@ -1517,42 +2021,89 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, /// void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To, Loop *L) { - AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L); - if (!AST) + auto ASTIt = LICM.getLoopToAliasSetMap().find(L); + if (ASTIt == LICM.getLoopToAliasSetMap().end()) return; - AST->copyValue(From, To); + ASTIt->second->copyValue(From, To); } /// Simple Analysis hook. Delete value V from alias set /// void LegacyLICMPass::deleteAnalysisValue(Value *V, Loop *L) { - AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L); - if (!AST) + auto ASTIt = LICM.getLoopToAliasSetMap().find(L); + if (ASTIt == LICM.getLoopToAliasSetMap().end()) return; - AST->deleteValue(V); + ASTIt->second->deleteValue(V); } /// Simple Analysis hook. Delete value L from alias set map. /// void LegacyLICMPass::deleteAnalysisLoop(Loop *L) { - AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L); - if (!AST) + if (!LICM.getLoopToAliasSetMap().count(L)) return; - delete AST; LICM.getLoopToAliasSetMap().erase(L); } -/// Return true if the body of this loop may store into the memory -/// location pointed to by V. -/// -static bool pointerInvalidatedByLoop(Value *V, uint64_t Size, - const AAMDNodes &AAInfo, - AliasSetTracker *CurAST) { - // Check to see if any of the basic blocks in CurLoop invalidate *V. - return CurAST->getAliasSetForPointer(V, Size, AAInfo).isMod(); +static bool pointerInvalidatedByLoop(MemoryLocation MemLoc, + AliasSetTracker *CurAST, Loop *CurLoop, + AliasAnalysis *AA) { + // First check to see if any of the basic blocks in CurLoop invalidate *V. + bool isInvalidatedAccordingToAST = CurAST->getAliasSetFor(MemLoc).isMod(); + + if (!isInvalidatedAccordingToAST || !LICMN2Theshold) + return isInvalidatedAccordingToAST; + + // Check with a diagnostic analysis if we can refine the information above. + // This is to identify the limitations of using the AST. + // The alias set mechanism used by LICM has a major weakness in that it + // combines all things which may alias into a single set *before* asking + // modref questions. As a result, a single readonly call within a loop will + // collapse all loads and stores into a single alias set and report + // invalidation if the loop contains any store. For example, readonly calls + // with deopt states have this form and create a general alias set with all + // loads and stores. In order to get any LICM in loops containing possible + // deopt states we need a more precise invalidation of checking the mod ref + // info of each instruction within the loop and LI. This has a complexity of + // O(N^2), so currently, it is used only as a diagnostic tool since the + // default value of LICMN2Threshold is zero. + + // Don't look at nested loops. 
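A hand-written illustration of the AliasSetTracker weakness described in the comment above (the function name and the pure attribute are made up for exposition):

// `lookup` only reads memory, but it may read anything, so the AST folds
// the load of *scale, the store to out[i], and the call into one alias
// set. The store makes that set Mod, so the set-level query reports the
// invariant load of *scale as invalidated, even though a per-instruction
// mod-ref walk (the N^2 fallback here) or MemorySSA can prove nothing in
// the loop writes it.
extern int lookup(int key) __attribute__((pure));

void fill(int *__restrict out, const int *__restrict scale, int n) {
  for (int i = 0; i < n; ++i)
    out[i] = lookup(i) * *scale;
}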
+ if (CurLoop->begin() != CurLoop->end()) + return true; + + int N = 0; + for (BasicBlock *BB : CurLoop->getBlocks()) + for (Instruction &I : *BB) { + if (N >= LICMN2Theshold) { + LLVM_DEBUG(dbgs() << "Alasing N2 threshold exhausted for " + << *(MemLoc.Ptr) << "\n"); + return true; + } + N++; + auto Res = AA->getModRefInfo(&I, MemLoc); + if (isModSet(Res)) { + LLVM_DEBUG(dbgs() << "Aliasing failed on " << I << " for " + << *(MemLoc.Ptr) << "\n"); + return true; + } + } + LLVM_DEBUG(dbgs() << "Aliasing okay for " << *(MemLoc.Ptr) << "\n"); + return false; +} + +static bool pointerInvalidatedByLoopWithMSSA(MemorySSA *MSSA, MemoryUse *MU, + Loop *CurLoop) { + MemoryAccess *Source; + // See declaration of EnableLicmCap for usage details. + if (EnableLicmCap) + Source = MU->getDefiningAccess(); + else + Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(MU); + return !MSSA->isLiveOnEntryDef(Source) && + CurLoop->contains(Source->getBlock()); } /// Little predicate that returns true if the specified basic block is in diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index 06083a4f5086..d797c9dc9e72 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -78,6 +78,18 @@ using namespace llvm; #define LDIST_NAME "loop-distribute" #define DEBUG_TYPE LDIST_NAME +/// @{ +/// Metadata attribute names +static const char *const LLVMLoopDistributeFollowupAll = + "llvm.loop.distribute.followup_all"; +static const char *const LLVMLoopDistributeFollowupCoincident = + "llvm.loop.distribute.followup_coincident"; +static const char *const LLVMLoopDistributeFollowupSequential = + "llvm.loop.distribute.followup_sequential"; +static const char *const LLVMLoopDistributeFollowupFallback = + "llvm.loop.distribute.followup_fallback"; +/// @} + static cl::opt<bool> LDistVerify("loop-distribute-verify", cl::Hidden, cl::desc("Turn on DominatorTree and LoopInfo verification " @@ -186,7 +198,7 @@ public: /// Returns the loop where this partition ends up after distribution. /// If this partition is mapped to the original loop then use the block from /// the loop. - const Loop *getDistributedLoop() const { + Loop *getDistributedLoop() const { return ClonedLoop ? ClonedLoop : OrigLoop; } @@ -443,6 +455,9 @@ public: assert(&*OrigPH->begin() == OrigPH->getTerminator() && "preheader not empty"); + // Preserve the original loop ID for use after the transformation. + MDNode *OrigLoopID = L->getLoopID(); + // Create a loop for each partition except the last. Clone the original // loop before PH along with adding a preheader for the cloned loop. Then // update PH to point to the newly added preheader. @@ -457,9 +472,13 @@ public: Part->getVMap()[ExitBlock] = TopPH; Part->remapInstructions(); + setNewLoopID(OrigLoopID, Part); } Pred->getTerminator()->replaceUsesOfWith(OrigPH, TopPH); + // Also set a new loop ID for the last loop. + setNewLoopID(OrigLoopID, &PartitionContainer.back()); + // Now go in forward order and update the immediate dominator for the // preheaders with the exiting block of the previous loop. Dominance // within the loop is updated in cloneLoopWithPreheader. @@ -575,6 +594,19 @@ private: } } } + + /// Assign new LoopIDs for the partition's cloned loop. + void setNewLoopID(MDNode *OrigLoopID, InstPartition *Part) { + Optional<MDNode *> PartitionID = makeFollowupLoopID( + OrigLoopID, + {LLVMLoopDistributeFollowupAll, + Part->hasDepCycle() ? 
LLVMLoopDistributeFollowupSequential + : LLVMLoopDistributeFollowupCoincident}); + if (PartitionID.hasValue()) { + Loop *NewLoop = Part->getDistributedLoop(); + NewLoop->setLoopID(PartitionID.getValue()); + } + } }; /// For each memory instruction, this class maintains difference of the @@ -743,6 +775,9 @@ public: return fail("TooManySCEVRuntimeChecks", "too many SCEV run-time checks needed.\n"); + if (!IsForced.getValueOr(false) && hasDisableAllTransformsHint(L)) + return fail("HeuristicDisabled", "distribution heuristic disabled"); + LLVM_DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); // We're done forming the partitions set up the reverse mapping from // instructions to partitions. @@ -762,6 +797,8 @@ public: RtPtrChecking); if (!Pred.isAlwaysTrue() || !Checks.empty()) { + MDNode *OrigLoopID = L->getLoopID(); + LLVM_DEBUG(dbgs() << "\nPointers:\n"); LLVM_DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); LoopVersioning LVer(*LAI, L, LI, DT, SE, false); @@ -769,6 +806,17 @@ public: LVer.setSCEVChecks(LAI->getPSE().getUnionPredicate()); LVer.versionLoop(DefsUsedOutside); LVer.annotateLoopWithNoAlias(); + + // The unversioned loop will not be changed, so we inherit all attributes + // from the original loop, but remove the loop distribution metadata to + // avoid to distribute it again. + MDNode *UnversionedLoopID = + makeFollowupLoopID(OrigLoopID, + {LLVMLoopDistributeFollowupAll, + LLVMLoopDistributeFollowupFallback}, + "llvm.loop.distribute.", true) + .getValue(); + LVer.getNonVersionedLoop()->setLoopID(UnversionedLoopID); } // Create identical copies of the original loop for each partition and hook diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 653948717fb9..fbffa1920a84 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -26,7 +26,7 @@ // Future floating point idioms to recognize in -ffast-math mode: // fpowi // Future integer operation idioms to recognize: -// ctpop, ctlz, cttz +// ctpop // // Beware that isel's default lowering for ctpop is highly inefficient for // i64 and larger types when i64 is legal and the value has few bits set. It @@ -163,8 +163,9 @@ private: void collectStores(BasicBlock *BB); LegalStoreKind isLegalStore(StoreInst *SI); + enum class ForMemset { No, Yes }; bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount, - bool ForMemset); + ForMemset For); bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount); bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize, @@ -186,9 +187,10 @@ private: bool recognizePopcount(); void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, PHINode *CntPhi, Value *Var); - bool recognizeAndInsertCTLZ(); - void transformLoopToCountable(BasicBlock *PreCondBB, Instruction *CntInst, - PHINode *CntPhi, Value *Var, Instruction *DefX, + bool recognizeAndInsertFFS(); /// Find First Set: ctlz or cttz + void transformLoopToCountable(Intrinsic::ID IntrinID, BasicBlock *PreCondBB, + Instruction *CntInst, PHINode *CntPhi, + Value *Var, Instruction *DefX, const DebugLoc &DL, bool ZeroCheck, bool IsCntPhiUsedOutsideLoop); @@ -319,9 +321,9 @@ bool LoopIdiomRecognize::runOnCountableLoop() { // The following transforms hoist stores/memsets into the loop pre-header. // Give up if the loop has instructions may throw. 
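The store idioms this file recognizes have a simple source-level shape; a sketch (illustrative names):

void zero_bytes(unsigned char *p, int n) {
  for (int i = 0; i < n; ++i)
    p[i] = 0;                        // becomes memset(p, 0, n)
}

void fill_pattern(unsigned long long *p, int n) {
  for (int i = 0; i < n; ++i)
    p[i] = 0x0102030405060708ULL;    // non-splat store: memset_pattern16,
                                     // on targets that provide it
}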
- LoopSafetyInfo SafetyInfo; - computeLoopSafetyInfo(&SafetyInfo, CurLoop); - if (SafetyInfo.MayThrow) + SimpleLoopSafetyInfo SafetyInfo; + SafetyInfo.computeLoopSafetyInfo(CurLoop); + if (SafetyInfo.anyBlockMayThrow()) return MadeChange; // Scan all the blocks in the loop that are not in subloops. @@ -347,6 +349,9 @@ static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) { /// Note that we don't ever attempt to use memset_pattern8 or 4, because these /// just replicate their input array and then pass on to memset_pattern16. static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) { + // FIXME: This could check for UndefValue because it can be merged into any + // other valid pattern. + // If the value isn't a constant, we can't promote it to being in a constant // array. We could theoretically do a store to an alloca or something, but // that doesn't seem worthwhile. @@ -543,10 +548,10 @@ bool LoopIdiomRecognize::runOnLoopBlock( // optimized into a memset (memset_pattern). The latter most commonly happens // with structs and handunrolled loops. for (auto &SL : StoreRefsForMemset) - MadeChange |= processLoopStores(SL.second, BECount, true); + MadeChange |= processLoopStores(SL.second, BECount, ForMemset::Yes); for (auto &SL : StoreRefsForMemsetPattern) - MadeChange |= processLoopStores(SL.second, BECount, false); + MadeChange |= processLoopStores(SL.second, BECount, ForMemset::No); // Optimize the store into a memcpy, if it feeds an similarly strided load. for (auto &SI : StoreRefsForMemcpy) @@ -572,10 +577,9 @@ bool LoopIdiomRecognize::runOnLoopBlock( return MadeChange; } -/// processLoopStores - See if this store(s) can be promoted to a memset. +/// See if this store(s) can be promoted to a memset. bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, - const SCEV *BECount, - bool ForMemset) { + const SCEV *BECount, ForMemset For) { // Try to find consecutive stores that can be transformed into memsets. SetVector<StoreInst *> Heads, Tails; SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain; @@ -602,7 +606,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, Value *FirstSplatValue = nullptr; Constant *FirstPatternValue = nullptr; - if (ForMemset) + if (For == ForMemset::Yes) FirstSplatValue = isBytewiseValue(FirstStoredVal); else FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL); @@ -635,7 +639,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, Value *SecondSplatValue = nullptr; Constant *SecondPatternValue = nullptr; - if (ForMemset) + if (For == ForMemset::Yes) SecondSplatValue = isBytewiseValue(SecondStoredVal); else SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL); @@ -644,10 +648,14 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, "Expected either splat value or pattern value."); if (isConsecutiveAccess(SL[i], SL[k], *DL, *SE, false)) { - if (ForMemset) { + if (For == ForMemset::Yes) { + if (isa<UndefValue>(FirstSplatValue)) + FirstSplatValue = SecondSplatValue; if (FirstSplatValue != SecondSplatValue) continue; } else { + if (isa<UndefValue>(FirstPatternValue)) + FirstPatternValue = SecondPatternValue; if (FirstPatternValue != SecondPatternValue) continue; } @@ -772,12 +780,13 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, // Get the location that may be stored across the loop. 
Since the access is // strided positively through memory, we say that the modified location starts // at the pointer and has infinite size. - uint64_t AccessSize = MemoryLocation::UnknownSize; + LocationSize AccessSize = LocationSize::unknown(); // If the loop iterates a fixed number of times, we can refine the access size // to be exactly the size of the memset, which is (BECount+1)*StoreSize if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount)) - AccessSize = (BECst->getValue()->getZExtValue() + 1) * StoreSize; + AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * + StoreSize); // TODO: For this to be really effective, we have to dive into the pointer // operand in the store. Store to &A[i] of 100 will always return may alias @@ -921,10 +930,11 @@ bool LoopIdiomRecognize::processLoopStridedStore( Type *Int8PtrTy = DestInt8PtrTy; Module *M = TheStore->getModule(); + StringRef FuncName = "memset_pattern16"; Value *MSP = - M->getOrInsertFunction("memset_pattern16", Builder.getVoidTy(), + M->getOrInsertFunction(FuncName, Builder.getVoidTy(), Int8PtrTy, Int8PtrTy, IntPtr); - inferLibFuncAttributes(*M->getFunction("memset_pattern16"), *TLI); + inferLibFuncAttributes(M, FuncName, *TLI); // Otherwise we should form a memset_pattern16. PatternValue is known to be // an constant array of 16-bytes. Plop the value into a mergable global. @@ -1099,15 +1109,17 @@ bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset, } bool LoopIdiomRecognize::runOnNoncountableLoop() { - return recognizePopcount() || recognizeAndInsertCTLZ(); + return recognizePopcount() || recognizeAndInsertFFS(); } /// Check if the given conditional branch is based on the comparison between -/// a variable and zero, and if the variable is non-zero, the control yields to -/// the loop entry. If the branch matches the behavior, the variable involved -/// in the comparison is returned. This function will be called to see if the -/// precondition and postcondition of the loop are in desirable form. -static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry) { +/// a variable and zero, and if the variable is non-zero or zero (JmpOnZero is +/// true), the control yields to the loop entry. If the branch matches the +/// behavior, the variable involved in the comparison is returned. This function +/// will be called to see if the precondition and postcondition of the loop are +/// in desirable form. 
+static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry, + bool JmpOnZero = false) { if (!BI || !BI->isConditional()) return nullptr; @@ -1119,9 +1131,14 @@ static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry) { if (!CmpZero || !CmpZero->isZero()) return nullptr; + BasicBlock *TrueSucc = BI->getSuccessor(0); + BasicBlock *FalseSucc = BI->getSuccessor(1); + if (JmpOnZero) + std::swap(TrueSucc, FalseSucc); + ICmpInst::Predicate Pred = Cond->getPredicate(); - if ((Pred == ICmpInst::ICMP_NE && BI->getSuccessor(0) == LoopEntry) || - (Pred == ICmpInst::ICMP_EQ && BI->getSuccessor(1) == LoopEntry)) + if ((Pred == ICmpInst::ICMP_NE && TrueSucc == LoopEntry) || + (Pred == ICmpInst::ICMP_EQ && FalseSucc == LoopEntry)) return Cond->getOperand(0); return nullptr; @@ -1297,14 +1314,14 @@ static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB, /// /// loop-exit: /// \endcode -static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX, - Instruction *&CntInst, PHINode *&CntPhi, - Instruction *&DefX) { +static bool detectShiftUntilZeroIdiom(Loop *CurLoop, const DataLayout &DL, + Intrinsic::ID &IntrinID, Value *&InitX, + Instruction *&CntInst, PHINode *&CntPhi, + Instruction *&DefX) { BasicBlock *LoopEntry; Value *VarX = nullptr; DefX = nullptr; - PhiX = nullptr; CntInst = nullptr; CntPhi = nullptr; LoopEntry = *(CurLoop->block_begin()); @@ -1316,20 +1333,28 @@ static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX, else return false; - // step 2: detect instructions corresponding to "x.next = x >> 1" - if (!DefX || (DefX->getOpcode() != Instruction::AShr && - DefX->getOpcode() != Instruction::LShr)) + // step 2: detect instructions corresponding to "x.next = x >> 1 or x << 1" + if (!DefX || !DefX->isShift()) return false; + IntrinID = DefX->getOpcode() == Instruction::Shl ? Intrinsic::cttz : + Intrinsic::ctlz; ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1)); if (!Shft || !Shft->isOne()) return false; VarX = DefX->getOperand(0); // step 3: Check the recurrence of variable X - PhiX = getRecurrenceVar(VarX, DefX, LoopEntry); + PHINode *PhiX = getRecurrenceVar(VarX, DefX, LoopEntry); if (!PhiX) return false; + InitX = PhiX->getIncomingValueForBlock(CurLoop->getLoopPreheader()); + + // Make sure the initial value can't be negative otherwise the ashr in the + // loop might never reach zero which would make the loop infinite. + if (DefX->getOpcode() == Instruction::AShr && !isKnownNonNegative(InitX, DL)) + return false; + // step 4: Find the instruction which count the CTLZ: cnt.next = cnt + 1 // TODO: We can skip the step. If loop trip count is known (CTLZ), // then all uses of "cnt.next" could be optimized to the trip count @@ -1361,17 +1386,25 @@ static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX, return true; } -/// Recognize CTLZ idiom in a non-countable loop and convert the loop -/// to countable (with CTLZ trip count). -/// If CTLZ inserted as a new trip count returns true; otherwise, returns false. -bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { +/// Recognize CTLZ or CTTZ idiom in a non-countable loop and convert the loop +/// to countable (with CTLZ / CTTZ trip count). If CTLZ / CTTZ inserted as a new +/// trip count returns true; otherwise, returns false. +bool LoopIdiomRecognize::recognizeAndInsertFFS() { // Give up if the loop has multiple blocks or multiple backedges. 
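The shift-until-zero loops that recognizeAndInsertFFS now handles look roughly like this at the source level (illustrative; assumes 32-bit unsigned):

// A logical right shift maps to ctlz, a left shift to cttz, and either way
// the loop becomes countable once the intrinsic supplies the trip count.
int shifts_until_zero_lshr(unsigned x) {
  int cnt = 0;
  while (x) {       // for non-zero x the trip count is 32 - ctlz(x)
    x >>= 1;
    ++cnt;
  }
  return cnt;
}

int shifts_until_zero_shl(unsigned x) {
  int cnt = 0;
  while (x) {       // the Shl form uses cttz: cnt becomes 32 - cttz(x)
    x <<= 1;
    ++cnt;
  }
  return cnt;
}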
if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1) return false; - Instruction *CntInst, *DefX; - PHINode *CntPhi, *PhiX; - if (!detectCTLZIdiom(CurLoop, PhiX, CntInst, CntPhi, DefX)) + Intrinsic::ID IntrinID; + Value *InitX; + Instruction *DefX = nullptr; + PHINode *CntPhi = nullptr; + Instruction *CntInst = nullptr; + // Help decide if transformation is profitable. For ShiftUntilZero idiom, + // this is always 6. + size_t IdiomCanonicalSize = 6; + + if (!detectShiftUntilZeroIdiom(CurLoop, *DL, IntrinID, InitX, + CntInst, CntPhi, DefX)) return false; bool IsCntPhiUsedOutsideLoop = false; @@ -1398,12 +1431,6 @@ bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { // It is safe to assume Preheader exist as it was checked in // parent function RunOnLoop. BasicBlock *PH = CurLoop->getLoopPreheader(); - Value *InitX = PhiX->getIncomingValueForBlock(PH); - - // Make sure the initial value can't be negative otherwise the ashr in the - // loop might never reach zero which would make the loop infinite. - if (DefX->getOpcode() == Instruction::AShr && !isKnownNonNegative(InitX, *DL)) - return false; // If we are using the count instruction outside the loop, make sure we // have a zero check as a precondition. Without the check the loop would run @@ -1421,8 +1448,10 @@ bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { ZeroCheck = true; } - // Check if CTLZ intrinsic is profitable. Assume it is always profitable - // if we delete the loop (the loop has only 6 instructions): + // Check if CTLZ / CTTZ intrinsic is profitable. Assume it is always + // profitable if we delete the loop. + + // the loop has only 6 instructions: // %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ] // %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ] // %shr = ashr %n.addr.0, 1 @@ -1433,12 +1462,12 @@ bool LoopIdiomRecognize::recognizeAndInsertCTLZ() { const Value *Args[] = {InitX, ZeroCheck ? ConstantInt::getTrue(InitX->getContext()) : ConstantInt::getFalse(InitX->getContext())}; - if (CurLoop->getHeader()->size() != 6 && - TTI->getIntrinsicCost(Intrinsic::ctlz, InitX->getType(), Args) > - TargetTransformInfo::TCC_Basic) + if (CurLoop->getHeader()->size() != IdiomCanonicalSize && + TTI->getIntrinsicCost(IntrinID, InitX->getType(), Args) > + TargetTransformInfo::TCC_Basic) return false; - transformLoopToCountable(PH, CntInst, CntPhi, InitX, DefX, + transformLoopToCountable(IntrinID, PH, CntInst, CntPhi, InitX, DefX, DefX->getDebugLoc(), ZeroCheck, IsCntPhiUsedOutsideLoop); return true; @@ -1507,20 +1536,21 @@ static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val, return CI; } -static CallInst *createCTLZIntrinsic(IRBuilder<> &IRBuilder, Value *Val, - const DebugLoc &DL, bool ZeroCheck) { +static CallInst *createFFSIntrinsic(IRBuilder<> &IRBuilder, Value *Val, + const DebugLoc &DL, bool ZeroCheck, + Intrinsic::ID IID) { Value *Ops[] = {Val, ZeroCheck ? 
IRBuilder.getTrue() : IRBuilder.getFalse()}; Type *Tys[] = {Val->getType()}; Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent(); - Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctlz, Tys); + Value *Func = Intrinsic::getDeclaration(M, IID, Tys); CallInst *CI = IRBuilder.CreateCall(Func, Ops); CI->setDebugLoc(DL); return CI; } -/// Transform the following loop: +/// Transform the following loop (Using CTLZ, CTTZ is similar): /// loop: /// CntPhi = PHI [Cnt0, CntInst] /// PhiX = PHI [InitX, DefX] @@ -1552,19 +1582,19 @@ static CallInst *createCTLZIntrinsic(IRBuilder<> &IRBuilder, Value *Val, /// If LOOP_BODY is empty the loop will be deleted. /// If CntInst and DefX are not used in LOOP_BODY they will be removed. void LoopIdiomRecognize::transformLoopToCountable( - BasicBlock *Preheader, Instruction *CntInst, PHINode *CntPhi, Value *InitX, - Instruction *DefX, const DebugLoc &DL, bool ZeroCheck, - bool IsCntPhiUsedOutsideLoop) { + Intrinsic::ID IntrinID, BasicBlock *Preheader, Instruction *CntInst, + PHINode *CntPhi, Value *InitX, Instruction *DefX, const DebugLoc &DL, + bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) { BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); - // Step 1: Insert the CTLZ instruction at the end of the preheader block - // Count = BitWidth - CTLZ(InitX); - // If there are uses of CntPhi create: - // CountPrev = BitWidth - CTLZ(InitX >> 1); + // Step 1: Insert the CTLZ/CTTZ instruction at the end of the preheader block IRBuilder<> Builder(PreheaderBr); Builder.SetCurrentDebugLocation(DL); - Value *CTLZ, *Count, *CountPrev, *NewCount, *InitXNext; + Value *FFS, *Count, *CountPrev, *NewCount, *InitXNext; + // Count = BitWidth - CTLZ(InitX); + // If there are uses of CntPhi create: + // CountPrev = BitWidth - CTLZ(InitX >> 1); if (IsCntPhiUsedOutsideLoop) { if (DefX->getOpcode() == Instruction::AShr) InitXNext = @@ -1572,29 +1602,30 @@ void LoopIdiomRecognize::transformLoopToCountable( else if (DefX->getOpcode() == Instruction::LShr) InitXNext = Builder.CreateLShr(InitX, ConstantInt::get(InitX->getType(), 1)); + else if (DefX->getOpcode() == Instruction::Shl) // cttz + InitXNext = + Builder.CreateShl(InitX, ConstantInt::get(InitX->getType(), 1)); else llvm_unreachable("Unexpected opcode!"); } else InitXNext = InitX; - CTLZ = createCTLZIntrinsic(Builder, InitXNext, DL, ZeroCheck); + FFS = createFFSIntrinsic(Builder, InitXNext, DL, ZeroCheck, IntrinID); Count = Builder.CreateSub( - ConstantInt::get(CTLZ->getType(), - CTLZ->getType()->getIntegerBitWidth()), - CTLZ); + ConstantInt::get(FFS->getType(), + FFS->getType()->getIntegerBitWidth()), + FFS); if (IsCntPhiUsedOutsideLoop) { CountPrev = Count; Count = Builder.CreateAdd( CountPrev, ConstantInt::get(CountPrev->getType(), 1)); } - if (IsCntPhiUsedOutsideLoop) - NewCount = Builder.CreateZExtOrTrunc(CountPrev, - cast<IntegerType>(CntInst->getType())); - else - NewCount = Builder.CreateZExtOrTrunc(Count, - cast<IntegerType>(CntInst->getType())); - // If the CTLZ counter's initial value is not zero, insert Add Inst. + NewCount = Builder.CreateZExtOrTrunc( + IsCntPhiUsedOutsideLoop ? CountPrev : Count, + cast<IntegerType>(CntInst->getType())); + + // If the counter's initial value is not zero, insert Add Inst. 
Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader); ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal); if (!InitConst || !InitConst->isZero()) @@ -1630,8 +1661,7 @@ void LoopIdiomRecognize::transformLoopToCountable( LbCond->setOperand(1, ConstantInt::get(Ty, 0)); // Step 3: All the references to the original counter outside - // the loop are replaced with the NewCount -- the value returned from - // __builtin_ctlz(x). + // the loop are replaced with the NewCount if (IsCntPhiUsedOutsideLoop) CntPhi->replaceUsesOutsideBlock(NewCount, Body); else diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index 71859efbf4bd..6f7dc2429c09 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -22,8 +22,9 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" @@ -36,6 +37,7 @@ #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include <algorithm> #include <utility> @@ -47,8 +49,8 @@ using namespace llvm; STATISTIC(NumSimplified, "Number of redundant instructions simplified"); static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, - const TargetLibraryInfo &TLI) { + AssumptionCache &AC, const TargetLibraryInfo &TLI, + MemorySSAUpdater *MSSAU) { const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); SimplifyQuery SQ(DL, &TLI, &DT, &AC); @@ -75,9 +77,12 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, // iterate. LoopBlocksRPO RPOT(&L); RPOT.perform(&LI); + MemorySSA *MSSA = MSSAU ? MSSAU->getMemorySSA() : nullptr; bool Changed = false; for (;;) { + if (MSSAU && VerifyMemorySSA) + MSSA->verifyMemorySSA(); for (BasicBlock *BB : RPOT) { for (Instruction &I : *BB) { if (auto *PI = dyn_cast<PHINode>(&I)) @@ -129,6 +134,12 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, ToSimplify->insert(UserI); } + if (MSSAU) + if (Instruction *SimpleI = dyn_cast_or_null<Instruction>(V)) + if (MemoryAccess *MA = MSSA->getMemoryAccess(&I)) + if (MemoryAccess *ReplacementMA = MSSA->getMemoryAccess(SimpleI)) + MA->replaceAllUsesWith(ReplacementMA); + assert(I.use_empty() && "Should always have replaced all uses!"); if (isInstructionTriviallyDead(&I, &TLI)) DeadInsts.push_back(&I); @@ -141,9 +152,12 @@ static bool simplifyLoopInst(Loop &L, DominatorTree &DT, LoopInfo &LI, // iteration over all instructions in all the loop blocks. if (!DeadInsts.empty()) { Changed = true; - RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, &TLI); + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, &TLI, MSSAU); } + if (MSSAU && VerifyMemorySSA) + MSSA->verifyMemorySSA(); + // If we never found a PHI that needs to be simplified in the next // iteration, we're done. 
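The kind of in-loop redundancy LoopInstSimplify cleans up, as a small sketch (illustrative names):

// The multiply and the add of zero fold on the first sweep over the blocks
// in RPO, and the next sweep can then simplify anything that used them,
// for example a PHI whose incoming values have become identical.
int walk(const int *a, int n) {
  int s = 0;
  for (int i = 0; i < n; ++i) {
    int j = i * 1;         // simplifies to i
    s = (s + 0) + a[j];    // (s + 0) simplifies to s
  }
  return s;
}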
if (Next->empty()) @@ -180,8 +194,15 @@ public: *L->getHeader()->getParent()); const TargetLibraryInfo &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + MemorySSA *MSSA = nullptr; + Optional<MemorySSAUpdater> MSSAU; + if (EnableMSSALoopDependency) { + MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSAU = MemorySSAUpdater(MSSA); + } - return simplifyLoopInst(*L, DT, LI, AC, TLI); + return simplifyLoopInst(*L, DT, LI, AC, TLI, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -189,6 +210,10 @@ public: AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.setPreservesCFG(); + if (EnableMSSALoopDependency) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } getLoopAnalysisUsage(AU); } }; @@ -198,7 +223,13 @@ public: PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { - if (!simplifyLoopInst(L, AR.DT, AR.LI, AR.AC, AR.TLI)) + Optional<MemorySSAUpdater> MSSAU; + if (AR.MSSA) { + MSSAU = MemorySSAUpdater(AR.MSSA); + AR.MSSA->verifyMemorySSA(); + } + if (!simplifyLoopInst(L, AR.DT, AR.LI, AR.AC, AR.TLI, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); @@ -212,6 +243,7 @@ INITIALIZE_PASS_BEGIN(LoopInstSimplifyLegacyPass, "loop-instsimplify", "Simplify instructions in loops", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(LoopInstSimplifyLegacyPass, "loop-instsimplify", "Simplify instructions in loops", false, false) diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index 2978165ed8a9..766e39b439a0 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -17,9 +17,9 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" @@ -271,7 +271,7 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, return true; } -static void populateWorklist(Loop &L, SmallVector<LoopVector, 8> &V) { +static LoopVector populateWorklist(Loop &L) { LLVM_DEBUG(dbgs() << "Calling populateWorklist on Func: " << L.getHeader()->getParent()->getName() << " Loop: %" << L.getHeader()->getName() << '\n'); @@ -282,16 +282,15 @@ static void populateWorklist(Loop &L, SmallVector<LoopVector, 8> &V) { // The current loop has multiple subloops in it hence it is not tightly // nested. // Discard all loops above it added into Worklist. 
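What the tightly-nested requirement in populateWorklist means in practice, sketched with made-up names and a fixed row length:

// The j/i pair qualifies because the i-loop is the only statement inside
// the j-loop; any code between the two `for` lines would make
// populateWorklist give up on the nest.
void add_columns(int (*a)[1024], int (*b)[1024], int n) {
  for (int j = 0; j < 1024; ++j)     // outer
    for (int i = 0; i < n; ++i)      // inner: a[i][j] strides by a full row
      a[i][j] += b[i][j];
  // After interchange the i-loop becomes outermost and both accesses are
  // unit-stride in the inner loop, which is the locality win the
  // profitability check looks for.
}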
- if (Vec->size() != 1) { - LoopList.clear(); - return; - } + if (Vec->size() != 1) + return {}; + LoopList.push_back(CurrentLoop); CurrentLoop = Vec->front(); Vec = &CurrentLoop->getSubLoops(); } LoopList.push_back(CurrentLoop); - V.push_back(std::move(LoopList)); + return LoopList; } static PHINode *getInductionVariable(Loop *L, ScalarEvolution *SE) { @@ -327,10 +326,8 @@ namespace { class LoopInterchangeLegality { public: LoopInterchangeLegality(Loop *Outer, Loop *Inner, ScalarEvolution *SE, - LoopInfo *LI, DominatorTree *DT, bool PreserveLCSSA, OptimizationRemarkEmitter *ORE) - : OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), - PreserveLCSSA(PreserveLCSSA), ORE(ORE) {} + : OuterLoop(Outer), InnerLoop(Inner), SE(SE), ORE(ORE) {} /// Check if the loops can be interchanged. bool canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId, @@ -342,29 +339,33 @@ public: bool currentLimitations(); - bool hasInnerLoopReduction() { return InnerLoopHasReduction; } + const SmallPtrSetImpl<PHINode *> &getOuterInnerReductions() const { + return OuterInnerReductions; + } private: bool tightlyNested(Loop *Outer, Loop *Inner); - bool containsUnsafeInstructionsInHeader(BasicBlock *BB); - bool areAllUsesReductions(Instruction *Ins, Loop *L); - bool containsUnsafeInstructionsInLatch(BasicBlock *BB); + bool containsUnsafeInstructions(BasicBlock *BB); + + /// Discover induction and reduction PHIs in the header of \p L. Induction + /// PHIs are added to \p Inductions, reductions are added to + /// OuterInnerReductions. When the outer loop is passed, the inner loop needs + /// to be passed as \p InnerLoop. bool findInductionAndReductions(Loop *L, SmallVector<PHINode *, 8> &Inductions, - SmallVector<PHINode *, 8> &Reductions); + Loop *InnerLoop); Loop *OuterLoop; Loop *InnerLoop; ScalarEvolution *SE; - LoopInfo *LI; - DominatorTree *DT; - bool PreserveLCSSA; /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; - bool InnerLoopHasReduction = false; + /// Set of reduction PHIs taking part of a reduction across the inner and + /// outer loop. + SmallPtrSet<PHINode *, 4> OuterInnerReductions; }; /// LoopInterchangeProfitability checks if it is profitable to interchange the @@ -398,10 +399,9 @@ public: LoopInterchangeTransform(Loop *Outer, Loop *Inner, ScalarEvolution *SE, LoopInfo *LI, DominatorTree *DT, BasicBlock *LoopNestExit, - bool InnerLoopContainsReductions) + const LoopInterchangeLegality &LIL) : OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), - LoopExit(LoopNestExit), - InnerLoopHasReduction(InnerLoopContainsReductions) {} + LoopExit(LoopNestExit), LIL(LIL) {} /// Interchange OuterLoop and InnerLoop. bool transform(); @@ -416,8 +416,6 @@ private: bool adjustLoopLinks(); void adjustLoopPreheaders(); bool adjustLoopBranches(); - void updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, - BasicBlock *NewPred); Loop *OuterLoop; Loop *InnerLoop; @@ -428,41 +426,34 @@ private: LoopInfo *LI; DominatorTree *DT; BasicBlock *LoopExit; - bool InnerLoopHasReduction; + + const LoopInterchangeLegality &LIL; }; // Main LoopInterchange Pass. -struct LoopInterchange : public FunctionPass { +struct LoopInterchange : public LoopPass { static char ID; ScalarEvolution *SE = nullptr; LoopInfo *LI = nullptr; DependenceInfo *DI = nullptr; DominatorTree *DT = nullptr; - bool PreserveLCSSA; /// Interface to emit optimization remarks. 
OptimizationRemarkEmitter *ORE; - LoopInterchange() : FunctionPass(ID) { + LoopInterchange() : LoopPass(ID) { initializeLoopInterchangePass(*PassRegistry::getPassRegistry()); } void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<AAResultsWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<LoopInfoWrapperPass>(); AU.addRequired<DependenceAnalysisWrapperPass>(); - AU.addRequiredID(LoopSimplifyID); - AU.addRequiredID(LCSSAID); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); + getLoopAnalysisUsage(AU); } - bool runOnFunction(Function &F) override { - if (skipFunction(F)) + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L) || L->getParentLoop()) return false; SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); @@ -470,21 +461,8 @@ struct LoopInterchange : public FunctionPass { DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); - PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); - - // Build up a worklist of loop pairs to analyze. - SmallVector<LoopVector, 8> Worklist; - - for (Loop *L : *LI) - populateWorklist(*L, Worklist); - LLVM_DEBUG(dbgs() << "Worklist size = " << Worklist.size() << "\n"); - bool Changed = true; - while (!Worklist.empty()) { - LoopVector LoopList = Worklist.pop_back_val(); - Changed = processLoopList(LoopList, F); - } - return Changed; + return processLoopList(populateWorklist(*L)); } bool isComputableLoopNest(LoopVector LoopList) { @@ -512,7 +490,7 @@ struct LoopInterchange : public FunctionPass { return LoopList.size() - 1; } - bool processLoopList(LoopVector LoopList, Function &F) { + bool processLoopList(LoopVector LoopList) { bool Changed = false; unsigned LoopNestDepth = LoopList.size(); if (LoopNestDepth < 2) { @@ -580,8 +558,7 @@ struct LoopInterchange : public FunctionPass { Loop *InnerLoop = LoopList[InnerLoopId]; Loop *OuterLoop = LoopList[OuterLoopId]; - LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, LI, DT, - PreserveLCSSA, ORE); + LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, ORE); if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) { LLVM_DEBUG(dbgs() << "Not interchanging loops. 
Cannot prove legality.\n"); return false; @@ -600,8 +577,8 @@ struct LoopInterchange : public FunctionPass { << "Loop interchanged with enclosing loop."; }); - LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT, - LoopNestExit, LIL.hasInnerLoopReduction()); + LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT, LoopNestExit, + LIL); LIT.transform(); LLVM_DEBUG(dbgs() << "Loops interchanged.\n"); LoopsInterchanged++; @@ -611,42 +588,12 @@ struct LoopInterchange : public FunctionPass { } // end anonymous namespace -bool LoopInterchangeLegality::areAllUsesReductions(Instruction *Ins, Loop *L) { - return llvm::none_of(Ins->users(), [=](User *U) -> bool { - auto *UserIns = dyn_cast<PHINode>(U); - RecurrenceDescriptor RD; - return !UserIns || !RecurrenceDescriptor::isReductionPHI(UserIns, L, RD); +bool LoopInterchangeLegality::containsUnsafeInstructions(BasicBlock *BB) { + return any_of(*BB, [](const Instruction &I) { + return I.mayHaveSideEffects() || I.mayReadFromMemory(); }); } -bool LoopInterchangeLegality::containsUnsafeInstructionsInHeader( - BasicBlock *BB) { - for (Instruction &I : *BB) { - // Load corresponding to reduction PHI's are safe while concluding if - // tightly nested. - if (LoadInst *L = dyn_cast<LoadInst>(&I)) { - if (!areAllUsesReductions(L, InnerLoop)) - return true; - } else if (I.mayHaveSideEffects() || I.mayReadFromMemory()) - return true; - } - return false; -} - -bool LoopInterchangeLegality::containsUnsafeInstructionsInLatch( - BasicBlock *BB) { - for (Instruction &I : *BB) { - // Stores corresponding to reductions are safe while concluding if tightly - // nested. - if (StoreInst *L = dyn_cast<StoreInst>(&I)) { - if (!isa<PHINode>(L->getOperand(0))) - return true; - } else if (I.mayHaveSideEffects() || I.mayReadFromMemory()) - return true; - } - return false; -} - bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); @@ -662,15 +609,16 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { if (!OuterLoopHeaderBI) return false; - for (BasicBlock *Succ : OuterLoopHeaderBI->successors()) - if (Succ != InnerLoopPreHeader && Succ != OuterLoopLatch) + for (BasicBlock *Succ : successors(OuterLoopHeaderBI)) + if (Succ != InnerLoopPreHeader && Succ != InnerLoop->getHeader() && + Succ != OuterLoopLatch) return false; LLVM_DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n"); // We do not have any basic block in between now make sure the outer header // and outer loop latch doesn't contain any unsafe instructions. - if (containsUnsafeInstructionsInHeader(OuterLoopHeader) || - containsUnsafeInstructionsInLatch(OuterLoopLatch)) + if (containsUnsafeInstructions(OuterLoopHeader) || + containsUnsafeInstructions(OuterLoopLatch)) return false; LLVM_DEBUG(dbgs() << "Loops are perfectly nested\n"); @@ -702,9 +650,36 @@ bool LoopInterchangeLegality::isLoopStructureUnderstood( return true; } +// If SV is a LCSSA PHI node with a single incoming value, return the incoming +// value. +static Value *followLCSSA(Value *SV) { + PHINode *PHI = dyn_cast<PHINode>(SV); + if (!PHI) + return SV; + + if (PHI->getNumIncomingValues() != 1) + return SV; + return followLCSSA(PHI->getIncomingValue(0)); +} + +// Check V's users to see if it is involved in a reduction in L. 
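At source level, the pattern these helpers are trying to recognize is a single reduction carried across both loops: after promotion to SSA, the running sum becomes a PHI in each loop header, and the inner loop's final value reaches the outer PHI through an LCSSA PHI in the inner exit block. A hand-written example of that shape, for illustration only:

int sumNest(int A[100][100]) {
  int Sum = 0;                      // becomes the outer-loop header PHI
  for (int i = 0; i < 100; ++i)
    for (int j = 0; j < 100; ++j)   // the same value is a PHI in the inner header
      Sum += A[i][j];
  return Sum;
}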
+static PHINode *findInnerReductionPhi(Loop *L, Value *V) { + for (Value *User : V->users()) { + if (PHINode *PHI = dyn_cast<PHINode>(User)) { + if (PHI->getNumIncomingValues() == 1) + continue; + RecurrenceDescriptor RD; + if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD)) + return PHI; + return nullptr; + } + } + + return nullptr; +} + bool LoopInterchangeLegality::findInductionAndReductions( - Loop *L, SmallVector<PHINode *, 8> &Inductions, - SmallVector<PHINode *, 8> &Reductions) { + Loop *L, SmallVector<PHINode *, 8> &Inductions, Loop *InnerLoop) { if (!L->getLoopLatch() || !L->getLoopPredecessor()) return false; for (PHINode &PHI : L->getHeader()->phis()) { @@ -712,12 +687,33 @@ bool LoopInterchangeLegality::findInductionAndReductions( InductionDescriptor ID; if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) Inductions.push_back(&PHI); - else if (RecurrenceDescriptor::isReductionPHI(&PHI, L, RD)) - Reductions.push_back(&PHI); else { - LLVM_DEBUG( - dbgs() << "Failed to recognize PHI as an induction or reduction.\n"); - return false; + // PHIs in inner loops need to be part of a reduction in the outer loop, + // discovered when checking the PHIs of the outer loop earlier. + if (!InnerLoop) { + if (OuterInnerReductions.find(&PHI) == OuterInnerReductions.end()) { + LLVM_DEBUG(dbgs() << "Inner loop PHI is not part of reductions " + "across the outer loop.\n"); + return false; + } + } else { + assert(PHI.getNumIncomingValues() == 2 && + "Phis in loop header should have exactly 2 incoming values"); + // Check if we have a PHI node in the outer loop that has a reduction + // result from the inner loop as an incoming value. + Value *V = followLCSSA(PHI.getIncomingValueForBlock(L->getLoopLatch())); + PHINode *InnerRedPhi = findInnerReductionPhi(InnerLoop, V); + if (!InnerRedPhi || + !llvm::any_of(InnerRedPhi->incoming_values(), + [&PHI](Value *V) { return V == &PHI; })) { + LLVM_DEBUG( + dbgs() + << "Failed to recognize PHI as an induction or reduction.\n"); + return false; + } + OuterInnerReductions.insert(&PHI); + OuterInnerReductions.insert(InnerRedPhi); + } } } return true; @@ -766,81 +762,64 @@ bool LoopInterchangeLegality::currentLimitations() { PHINode *InnerInductionVar; SmallVector<PHINode *, 8> Inductions; - SmallVector<PHINode *, 8> Reductions; - if (!findInductionAndReductions(InnerLoop, Inductions, Reductions)) { + if (!findInductionAndReductions(OuterLoop, Inductions, InnerLoop)) { LLVM_DEBUG( - dbgs() << "Only inner loops with induction or reduction PHI nodes " + dbgs() << "Only outer loops with induction or reduction PHI nodes " << "are supported currently.\n"); ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with induction or reduction PHI nodes can be" - " interchange currently."; + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with induction or reduction PHI nodes can be" + " interchanged currently."; }); return true; } // TODO: Currently we handle only loops with 1 induction variable. if (Inductions.size() != 1) { - LLVM_DEBUG( - dbgs() << "We currently only support loops with 1 induction variable." 
- << "Failed to interchange due to current limitation\n"); + LLVM_DEBUG(dbgs() << "Loops with more than 1 induction variables are not " + << "supported currently.\n"); ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "MultiInductionInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with 1 induction variable can be " + return OptimizationRemarkMissed(DEBUG_TYPE, "MultiIndutionOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with 1 induction variable can be " "interchanged currently."; }); return true; } - if (Reductions.size() > 0) - InnerLoopHasReduction = true; - InnerInductionVar = Inductions.pop_back_val(); - Reductions.clear(); - if (!findInductionAndReductions(OuterLoop, Inductions, Reductions)) { + Inductions.clear(); + if (!findInductionAndReductions(InnerLoop, Inductions, nullptr)) { LLVM_DEBUG( - dbgs() << "Only outer loops with induction or reduction PHI nodes " + dbgs() << "Only inner loops with induction or reduction PHI nodes " << "are supported currently.\n"); ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Only outer loops with induction or reduction PHI nodes can be" - " interchanged currently."; + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with induction or reduction PHI nodes can be" + " interchange currently."; }); return true; } - // Outer loop cannot have reduction because then loops will not be tightly - // nested. - if (!Reductions.empty()) { - LLVM_DEBUG(dbgs() << "Outer loops with reductions are not supported " - << "currently.\n"); - ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "ReductionsOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Outer loops with reductions cannot be interchangeed " - "currently."; - }); - return true; - } // TODO: Currently we handle only loops with 1 induction variable. if (Inductions.size() != 1) { - LLVM_DEBUG(dbgs() << "Loops with more than 1 induction variables are not " - << "supported currently.\n"); + LLVM_DEBUG( + dbgs() << "We currently only support loops with 1 induction variable." + << "Failed to interchange due to current limitation\n"); ORE->emit([&]() { - return OptimizationRemarkMissed(DEBUG_TYPE, "MultiIndutionOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Only outer loops with 1 induction variable can be " + return OptimizationRemarkMissed(DEBUG_TYPE, "MultiInductionInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with 1 induction variable can be " "interchanged currently."; }); return true; } + InnerInductionVar = Inductions.pop_back_val(); // TODO: Triangular loops are not handled for now. if (!isLoopStructureUnderstood(InnerInductionVar)) { @@ -1016,28 +995,6 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, return false; } - // Create unique Preheaders if we already do not have one. - BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); - - // Create a unique outer preheader - - // 1) If OuterLoop preheader is not present. - // 2) If OuterLoop Preheader is same as OuterLoop Header - // 3) If OuterLoop Preheader is same as Header of the previous loop. - // 4) If OuterLoop Preheader is Entry node. 
- if (!OuterLoopPreHeader || OuterLoopPreHeader == OuterLoop->getHeader() || - isa<PHINode>(OuterLoopPreHeader->begin()) || - !OuterLoopPreHeader->getUniquePredecessor()) { - OuterLoopPreHeader = - InsertPreheaderForLoop(OuterLoop, DT, LI, PreserveLCSSA); - } - - if (!InnerLoopPreHeader || InnerLoopPreHeader == InnerLoop->getHeader() || - InnerLoopPreHeader == OuterLoop->getHeader()) { - InnerLoopPreHeader = - InsertPreheaderForLoop(InnerLoop, DT, LI, PreserveLCSSA); - } - // TODO: The loops could not be interchanged due to current limitations in the // transform module. if (currentLimitations()) { @@ -1258,6 +1215,10 @@ void LoopInterchangeTransform::restructureLoops( // outer loop. NewOuter->addBlockEntry(OrigOuterPreHeader); LI->changeLoopFor(OrigOuterPreHeader, NewOuter); + + // Tell SE that we move the loops around. + SE->forgetLoop(NewOuter); + SE->forgetLoop(NewInner); } bool LoopInterchangeTransform::transform() { @@ -1319,9 +1280,8 @@ static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { FromBB->getTerminator()->getIterator()); } -void LoopInterchangeTransform::updateIncomingBlock(BasicBlock *CurrBlock, - BasicBlock *OldPred, - BasicBlock *NewPred) { +static void updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, + BasicBlock *NewPred) { for (PHINode &PHI : CurrBlock->phis()) { unsigned Num = PHI.getNumIncomingValues(); for (unsigned i = 0; i < Num; ++i) { @@ -1336,7 +1296,7 @@ void LoopInterchangeTransform::updateIncomingBlock(BasicBlock *CurrBlock, static void updateSuccessor(BranchInst *BI, BasicBlock *OldBB, BasicBlock *NewBB, std::vector<DominatorTree::UpdateType> &DTUpdates) { - assert(llvm::count_if(BI->successors(), + assert(llvm::count_if(successors(BI), [OldBB](BasicBlock *BB) { return BB == OldBB; }) < 2 && "BI must jump to OldBB at most once."); for (unsigned i = 0, e = BI->getNumSuccessors(); i < e; ++i) { @@ -1352,17 +1312,77 @@ static void updateSuccessor(BranchInst *BI, BasicBlock *OldBB, } } +// Move Lcssa PHIs to the right place. +static void moveLCSSAPhis(BasicBlock *InnerExit, BasicBlock *InnerLatch, + BasicBlock *OuterLatch) { + SmallVector<PHINode *, 8> LcssaInnerExit; + for (PHINode &P : InnerExit->phis()) + LcssaInnerExit.push_back(&P); + + SmallVector<PHINode *, 8> LcssaInnerLatch; + for (PHINode &P : InnerLatch->phis()) + LcssaInnerLatch.push_back(&P); + + // Lcssa PHIs for values used outside the inner loop are in InnerExit. + // If a PHI node has users outside of InnerExit, it has a use outside the + // interchanged loop and we have to preserve it. We move these to + // InnerLatch, which will become the new exit block for the innermost + // loop after interchanging. For PHIs only used in InnerExit, we can just + // replace them with the incoming value. + for (PHINode *P : LcssaInnerExit) { + bool hasUsersOutside = false; + for (auto UI = P->use_begin(), E = P->use_end(); UI != E;) { + Use &U = *UI; + ++UI; + auto *Usr = cast<Instruction>(U.getUser()); + if (Usr->getParent() != InnerExit) { + hasUsersOutside = true; + continue; + } + U.set(P->getIncomingValueForBlock(InnerLatch)); + } + if (hasUsersOutside) + P->moveBefore(InnerLatch->getFirstNonPHI()); + else + P->eraseFromParent(); + } + + // If the inner loop latch contains LCSSA PHIs, those come from a child loop + // and we have to move them to the new inner latch. + for (PHINode *P : LcssaInnerLatch) + P->moveBefore(InnerExit->getFirstNonPHI()); + + // Now adjust the incoming blocks for the LCSSA PHIs. 
+ // For PHIs moved from Inner's exit block, we need to replace Inner's latch + // with the new latch. + updateIncomingBlock(InnerLatch, InnerLatch, OuterLatch); +} + bool LoopInterchangeTransform::adjustLoopBranches() { LLVM_DEBUG(dbgs() << "adjustLoopBranches called\n"); std::vector<DominatorTree::UpdateType> DTUpdates; + BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); + BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); + + assert(OuterLoopPreHeader != OuterLoop->getHeader() && + InnerLoopPreHeader != InnerLoop->getHeader() && OuterLoopPreHeader && + InnerLoopPreHeader && "Guaranteed by loop-simplify form"); + // Ensure that both preheaders do not contain PHI nodes and have single + // predecessors. This allows us to move them easily. We use + // InsertPreHeaderForLoop to create an 'extra' preheader, if the existing + // preheaders do not satisfy those conditions. + if (isa<PHINode>(OuterLoopPreHeader->begin()) || + !OuterLoopPreHeader->getUniquePredecessor()) + OuterLoopPreHeader = InsertPreheaderForLoop(OuterLoop, DT, LI, true); + if (InnerLoopPreHeader == OuterLoop->getHeader()) + InnerLoopPreHeader = InsertPreheaderForLoop(InnerLoop, DT, LI, true); + // Adjust the loop preheader BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); - BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); BasicBlock *OuterLoopPredecessor = OuterLoopPreHeader->getUniquePredecessor(); BasicBlock *InnerLoopLatchPredecessor = InnerLoopLatch->getUniquePredecessor(); @@ -1417,17 +1437,6 @@ bool LoopInterchangeTransform::adjustLoopBranches() { updateSuccessor(InnerLoopLatchPredecessorBI, InnerLoopLatch, InnerLoopLatchSuccessor, DTUpdates); - // Adjust PHI nodes in InnerLoopLatchSuccessor. Update all uses of PHI with - // the value and remove this PHI node from inner loop. - SmallVector<PHINode *, 8> LcssaVec; - for (PHINode &P : InnerLoopLatchSuccessor->phis()) - LcssaVec.push_back(&P); - - for (PHINode *P : LcssaVec) { - Value *Incoming = P->getIncomingValueForBlock(InnerLoopLatch); - P->replaceAllUsesWith(Incoming); - P->eraseFromParent(); - } if (OuterLoopLatchBI->getSuccessor(0) == OuterLoopHeader) OuterLoopLatchSuccessor = OuterLoopLatchBI->getSuccessor(1); @@ -1439,12 +1448,15 @@ bool LoopInterchangeTransform::adjustLoopBranches() { updateSuccessor(OuterLoopLatchBI, OuterLoopLatchSuccessor, InnerLoopLatch, DTUpdates); - updateIncomingBlock(OuterLoopLatchSuccessor, OuterLoopLatch, InnerLoopLatch); - DT->applyUpdates(DTUpdates); restructureLoops(OuterLoop, InnerLoop, InnerLoopPreHeader, OuterLoopPreHeader); + moveLCSSAPhis(InnerLoopLatchSuccessor, InnerLoopLatch, OuterLoopLatch); + // For PHIs in the exit block of the outer loop, outer's latch has been + // replaced by Inners'. + updateIncomingBlock(OuterLoopLatchSuccessor, OuterLoopLatch, InnerLoopLatch); + // Now update the reduction PHIs in the inner and outer loop headers. 
SmallVector<PHINode *, 4> InnerLoopPHIs, OuterLoopPHIs; for (PHINode &PHI : drop_begin(InnerLoopHeader->phis(), 1)) @@ -1452,26 +1464,21 @@ bool LoopInterchangeTransform::adjustLoopBranches() { for (PHINode &PHI : drop_begin(OuterLoopHeader->phis(), 1)) OuterLoopPHIs.push_back(cast<PHINode>(&PHI)); - for (PHINode *PHI : OuterLoopPHIs) - PHI->moveBefore(InnerLoopHeader->getFirstNonPHI()); + auto &OuterInnerReductions = LIL.getOuterInnerReductions(); + (void)OuterInnerReductions; - // Move the PHI nodes from the inner loop header to the outer loop header. - // We have to deal with one kind of PHI nodes: - // 1) PHI nodes that are part of inner loop-only reductions. - // We only have to move the PHI node and update the incoming blocks. + // Now move the remaining reduction PHIs from outer to inner loop header and + // vice versa. The PHI nodes must be part of a reduction across the inner and + // outer loop and all the remains to do is and updating the incoming blocks. + for (PHINode *PHI : OuterLoopPHIs) { + PHI->moveBefore(InnerLoopHeader->getFirstNonPHI()); + assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() && + "Expected a reduction PHI node"); + } for (PHINode *PHI : InnerLoopPHIs) { PHI->moveBefore(OuterLoopHeader->getFirstNonPHI()); - for (BasicBlock *InBB : PHI->blocks()) { - if (InnerLoop->contains(InBB)) - continue; - - assert(!isa<PHINode>(PHI->getIncomingValueForBlock(InBB)) && - "Unexpected incoming PHI node, reductions in outer loop are not " - "supported yet"); - PHI->replaceAllUsesWith(PHI->getIncomingValueForBlock(InBB)); - PHI->eraseFromParent(); - break; - } + assert(OuterInnerReductions.find(PHI) != OuterInnerReductions.end() && + "Expected a reduction PHI node"); } // Update the incoming blocks for moved PHI nodes. @@ -1514,13 +1521,8 @@ char LoopInterchange::ID = 0; INITIALIZE_PASS_BEGIN(LoopInterchange, "loop-interchange", "Interchanges loops for cache reuse", false, false) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopSimplify) -INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopInterchange, "loop-interchange", diff --git a/lib/Transforms/Scalar/LoopPassManager.cpp b/lib/Transforms/Scalar/LoopPassManager.cpp index 10f6fcdcfdb7..774ad7b945a0 100644 --- a/lib/Transforms/Scalar/LoopPassManager.cpp +++ b/lib/Transforms/Scalar/LoopPassManager.cpp @@ -30,12 +30,26 @@ PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, if (DebugLogging) dbgs() << "Starting Loop pass manager run.\n"; + // Request PassInstrumentation from analysis manager, will use it to run + // instrumenting callbacks for the passes later. + PassInstrumentation PI = AM.getResult<PassInstrumentationAnalysis>(L, AR); for (auto &Pass : Passes) { if (DebugLogging) dbgs() << "Running pass: " << Pass->name() << " on " << L; + // Check the PassInstrumentation's BeforePass callbacks before running the + // pass, skip its execution completely if asked to (callback returns false). 
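A stripped-down model of the instrumentation protocol described here: the before-pass callback may veto execution, and after the pass runs either the normal or the invalidated variant of the after-pass callback fires, so a deleted loop is never handed back to instrumentation. The MiniPass and MiniPI types are invented for this sketch and assume every pass has its Run callback set.

#include <functional>
#include <string>
#include <vector>

struct MiniPass {
  std::string Name;
  // Returns true if the IR unit it ran on is still valid afterwards.
  std::function<bool()> Run;
};

struct MiniPI {
  std::function<bool(const MiniPass &)> BeforePass;            // may skip the pass
  std::function<void(const MiniPass &)> AfterPass;             // unit still valid
  std::function<void(const MiniPass &)> AfterPassInvalidated;  // unit was deleted
};

void runPipeline(const std::vector<MiniPass> &Passes, const MiniPI &PI) {
  for (const MiniPass &P : Passes) {
    if (PI.BeforePass && !PI.BeforePass(P))
      continue;                        // instrumentation asked to skip this pass
    bool UnitStillValid = P.Run();
    if (!UnitStillValid) {
      if (PI.AfterPassInvalidated)
        PI.AfterPassInvalidated(P);    // never report on a deleted unit
    } else if (PI.AfterPass) {
      PI.AfterPass(P);
    }
  }
}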
+ if (!PI.runBeforePass<Loop>(*Pass, L)) + continue; + PreservedAnalyses PassPA = Pass->run(L, AM, AR, U); + // do not pass deleted Loop into the instrumentation + if (U.skipCurrentLoop()) + PI.runAfterPassInvalidated<Loop>(*Pass); + else + PI.runAfterPass<Loop>(*Pass, L); + // If the loop was deleted, abort the run and return to the outer walk. if (U.skipCurrentLoop()) { PA.intersect(std::move(PassPA)); diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index cbb6594cf8f4..5983c804c0c1 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -178,7 +178,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -196,6 +198,9 @@ #define DEBUG_TYPE "loop-predication" +STATISTIC(TotalConsidered, "Number of guards considered"); +STATISTIC(TotalWidened, "Number of checks widened"); + using namespace llvm; static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation", @@ -574,6 +579,8 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, LLVM_DEBUG(dbgs() << "Processing guard:\n"); LLVM_DEBUG(Guard->dump()); + TotalConsidered++; + IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator())); // The guard condition is expected to be in form of: @@ -615,6 +622,8 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, if (NumWidened == 0) return false; + TotalWidened += NumWidened; + // Emit the new guard condition Builder.SetInsertPoint(Guard); Value *LastCheck = nullptr; @@ -812,9 +821,8 @@ bool LoopPredication::runOnLoop(Loop *Loop) { SmallVector<IntrinsicInst *, 4> Guards; for (const auto BB : L->blocks()) for (auto &I : *BB) - if (auto *II = dyn_cast<IntrinsicInst>(&I)) - if (II->getIntrinsicID() == Intrinsic::experimental_guard) - Guards.push_back(II); + if (isGuard(&I)) + Guards.push_back(cast<IntrinsicInst>(&I)); if (Guards.empty()) return false; diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index eeaad39dc1d1..fd22128f7fe6 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -15,6 +15,8 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Support/Debug.h" @@ -40,12 +42,19 @@ PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); - bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, SQ, - false, Threshold, false); + Optional<MemorySSAUpdater> MSSAU; + if (AR.MSSA) + MSSAU = MemorySSAUpdater(AR.MSSA); + bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, + MSSAU.hasValue() ? 
MSSAU.getPointer() : nullptr, + SQ, false, Threshold, false); if (!Changed) return PreservedAnalyses::all(); + if (AR.MSSA && VerifyMemorySSA) + AR.MSSA->verifyMemorySSA(); + return getLoopPassPreservedAnalyses(); } @@ -68,6 +77,10 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + if (EnableMSSALoopDependency) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } getLoopAnalysisUsage(AU); } @@ -84,8 +97,14 @@ public: auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); - return LoopRotation(L, LI, TTI, AC, DT, SE, SQ, false, MaxHeaderSize, - false); + Optional<MemorySSAUpdater> MSSAU; + if (EnableMSSALoopDependency) { + MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSAU = MemorySSAUpdater(MSSA); + } + return LoopRotation(L, LI, TTI, AC, DT, SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr, SQ, + false, MaxHeaderSize, false); } }; } @@ -96,6 +115,7 @@ INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", false, false) diff --git a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp index 2b83d3dc5f1b..2e5927f9a068 100644 --- a/lib/Transforms/Scalar/LoopSimplifyCFG.cpp +++ b/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -24,9 +24,12 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -38,9 +41,527 @@ using namespace llvm; #define DEBUG_TYPE "loop-simplifycfg" -static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, - ScalarEvolution &SE) { +static cl::opt<bool> EnableTermFolding("enable-loop-simplifycfg-term-folding", + cl::init(false)); + +STATISTIC(NumTerminatorsFolded, + "Number of terminators folded to unconditional branches"); +STATISTIC(NumLoopBlocksDeleted, + "Number of loop blocks deleted"); +STATISTIC(NumLoopExitsDeleted, + "Number of loop exiting edges deleted"); + +/// If \p BB is a switch or a conditional branch, but only one of its successors +/// can be reached from this block in runtime, return this successor. Otherwise, +/// return nullptr. +static BasicBlock *getOnlyLiveSuccessor(BasicBlock *BB) { + Instruction *TI = BB->getTerminator(); + if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { + if (BI->isUnconditional()) + return nullptr; + if (BI->getSuccessor(0) == BI->getSuccessor(1)) + return BI->getSuccessor(0); + ConstantInt *Cond = dyn_cast<ConstantInt>(BI->getCondition()); + if (!Cond) + return nullptr; + return Cond->isZero() ? 
BI->getSuccessor(1) : BI->getSuccessor(0); + } + + if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { + auto *CI = dyn_cast<ConstantInt>(SI->getCondition()); + if (!CI) + return nullptr; + for (auto Case : SI->cases()) + if (Case.getCaseValue() == CI) + return Case.getCaseSuccessor(); + return SI->getDefaultDest(); + } + + return nullptr; +} + +namespace { +/// Helper class that can turn branches and switches with constant conditions +/// into unconditional branches. +class ConstantTerminatorFoldingImpl { +private: + Loop &L; + LoopInfo &LI; + DominatorTree &DT; + ScalarEvolution &SE; + MemorySSAUpdater *MSSAU; + + // Whether or not the current loop has irreducible CFG. + bool HasIrreducibleCFG = false; + // Whether or not the current loop will still exist after terminator constant + // folding will be done. In theory, there are two ways how it can happen: + // 1. Loop's latch(es) become unreachable from loop header; + // 2. Loop's header becomes unreachable from method entry. + // In practice, the second situation is impossible because we only modify the + // current loop and its preheader and do not affect preheader's reachibility + // from any other block. So this variable set to true means that loop's latch + // has become unreachable from loop header. + bool DeleteCurrentLoop = false; + + // The blocks of the original loop that will still be reachable from entry + // after the constant folding. + SmallPtrSet<BasicBlock *, 8> LiveLoopBlocks; + // The blocks of the original loop that will become unreachable from entry + // after the constant folding. + SmallVector<BasicBlock *, 8> DeadLoopBlocks; + // The exits of the original loop that will still be reachable from entry + // after the constant folding. + SmallPtrSet<BasicBlock *, 8> LiveExitBlocks; + // The exits of the original loop that will become unreachable from entry + // after the constant folding. + SmallVector<BasicBlock *, 8> DeadExitBlocks; + // The blocks that will still be a part of the current loop after folding. + SmallPtrSet<BasicBlock *, 8> BlocksInLoopAfterFolding; + // The blocks that have terminators with constant condition that can be + // folded. Note: fold candidates should be in L but not in any of its + // subloops to avoid complex LI updates. + SmallVector<BasicBlock *, 8> FoldCandidates; + + void dump() const { + dbgs() << "Constant terminator folding for loop " << L << "\n"; + dbgs() << "After terminator constant-folding, the loop will"; + if (!DeleteCurrentLoop) + dbgs() << " not"; + dbgs() << " be destroyed\n"; + auto PrintOutVector = [&](const char *Message, + const SmallVectorImpl<BasicBlock *> &S) { + dbgs() << Message << "\n"; + for (const BasicBlock *BB : S) + dbgs() << "\t" << BB->getName() << "\n"; + }; + auto PrintOutSet = [&](const char *Message, + const SmallPtrSetImpl<BasicBlock *> &S) { + dbgs() << Message << "\n"; + for (const BasicBlock *BB : S) + dbgs() << "\t" << BB->getName() << "\n"; + }; + PrintOutVector("Blocks in which we can constant-fold terminator:", + FoldCandidates); + PrintOutSet("Live blocks from the original loop:", LiveLoopBlocks); + PrintOutVector("Dead blocks from the original loop:", DeadLoopBlocks); + PrintOutSet("Live exit blocks:", LiveExitBlocks); + PrintOutVector("Dead exit blocks:", DeadExitBlocks); + if (!DeleteCurrentLoop) + PrintOutSet("The following blocks will still be part of the loop:", + BlocksInLoopAfterFolding); + } + + /// Whether or not the current loop has irreducible CFG. 
+ bool hasIrreducibleCFG(LoopBlocksDFS &DFS) { + assert(DFS.isComplete() && "DFS is expected to be finished"); + // Index of a basic block in RPO traversal. + DenseMap<const BasicBlock *, unsigned> RPO; + unsigned Current = 0; + for (auto I = DFS.beginRPO(), E = DFS.endRPO(); I != E; ++I) + RPO[*I] = Current++; + + for (auto I = DFS.beginRPO(), E = DFS.endRPO(); I != E; ++I) { + BasicBlock *BB = *I; + for (auto *Succ : successors(BB)) + if (L.contains(Succ) && !LI.isLoopHeader(Succ) && RPO[BB] > RPO[Succ]) + // If an edge goes from a block with greater order number into a block + // with lesses number, and it is not a loop backedge, then it can only + // be a part of irreducible non-loop cycle. + return true; + } + return false; + } + + /// Fill all information about status of blocks and exits of the current loop + /// if constant folding of all branches will be done. + void analyze() { + LoopBlocksDFS DFS(&L); + DFS.perform(&LI); + assert(DFS.isComplete() && "DFS is expected to be finished"); + + // TODO: The algorithm below relies on both RPO and Postorder traversals. + // When the loop has only reducible CFG inside, then the invariant "all + // predecessors of X are processed before X in RPO" is preserved. However + // an irreducible loop can break this invariant (e.g. latch does not have to + // be the last block in the traversal in this case, and the algorithm relies + // on this). We can later decide to support such cases by altering the + // algorithms, but so far we just give up analyzing them. + if (hasIrreducibleCFG(DFS)) { + HasIrreducibleCFG = true; + return; + } + + // Collect live and dead loop blocks and exits. + LiveLoopBlocks.insert(L.getHeader()); + for (auto I = DFS.beginRPO(), E = DFS.endRPO(); I != E; ++I) { + BasicBlock *BB = *I; + + // If a loop block wasn't marked as live so far, then it's dead. + if (!LiveLoopBlocks.count(BB)) { + DeadLoopBlocks.push_back(BB); + continue; + } + + BasicBlock *TheOnlySucc = getOnlyLiveSuccessor(BB); + + // If a block has only one live successor, it's a candidate on constant + // folding. Only handle blocks from current loop: branches in child loops + // are skipped because if they can be folded, they should be folded during + // the processing of child loops. + if (TheOnlySucc && LI.getLoopFor(BB) == &L) + FoldCandidates.push_back(BB); + + // Handle successors. + for (BasicBlock *Succ : successors(BB)) + if (!TheOnlySucc || TheOnlySucc == Succ) { + if (L.contains(Succ)) + LiveLoopBlocks.insert(Succ); + else + LiveExitBlocks.insert(Succ); + } + } + + // Sanity check: amount of dead and live loop blocks should match the total + // number of blocks in loop. + assert(L.getNumBlocks() == LiveLoopBlocks.size() + DeadLoopBlocks.size() && + "Malformed block sets?"); + + // Now, all exit blocks that are not marked as live are dead. + SmallVector<BasicBlock *, 8> ExitBlocks; + L.getExitBlocks(ExitBlocks); + for (auto *ExitBlock : ExitBlocks) + if (!LiveExitBlocks.count(ExitBlock)) + DeadExitBlocks.push_back(ExitBlock); + + // Whether or not the edge From->To will still be present in graph after the + // folding. + auto IsEdgeLive = [&](BasicBlock *From, BasicBlock *To) { + if (!LiveLoopBlocks.count(From)) + return false; + BasicBlock *TheOnlySucc = getOnlyLiveSuccessor(From); + return !TheOnlySucc || TheOnlySucc == To; + }; + + // The loop will not be destroyed if its latch is live. + DeleteCurrentLoop = !IsEdgeLive(L.getLoopLatch(), L.getHeader()); + + // If we are going to delete the current loop completely, no extra analysis + // is needed. 
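A self-contained restatement of the RPO check above, with blocks as plain integers: given each block's position in a reverse post-order traversal and a marker for natural-loop headers, any edge that goes backwards in RPO without targeting such a header betrays an irreducible cycle. Container and parameter names here are arbitrary.

#include <vector>

bool hasIrreducibleEdge(const std::vector<std::vector<int>> &Succs,
                        const std::vector<int> &RPO,
                        const std::vector<bool> &IsLoopHeader) {
  for (int BB = 0, E = (int)Succs.size(); BB != E; ++BB)
    for (int Succ : Succs[BB])
      if (!IsLoopHeader[Succ] && RPO[BB] > RPO[Succ])
        return true; // retreating edge that is not a loop backedge
  return false;
}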
+ if (DeleteCurrentLoop) + return; + + // Otherwise, we should check which blocks will still be a part of the + // current loop after the transform. + BlocksInLoopAfterFolding.insert(L.getLoopLatch()); + // If the loop is live, then we should compute what blocks are still in + // loop after all branch folding has been done. A block is in loop if + // it has a live edge to another block that is in the loop; by definition, + // latch is in the loop. + auto BlockIsInLoop = [&](BasicBlock *BB) { + return any_of(successors(BB), [&](BasicBlock *Succ) { + return BlocksInLoopAfterFolding.count(Succ) && IsEdgeLive(BB, Succ); + }); + }; + for (auto I = DFS.beginPostorder(), E = DFS.endPostorder(); I != E; ++I) { + BasicBlock *BB = *I; + if (BlockIsInLoop(BB)) + BlocksInLoopAfterFolding.insert(BB); + } + + // Sanity check: header must be in loop. + assert(BlocksInLoopAfterFolding.count(L.getHeader()) && + "Header not in loop?"); + assert(BlocksInLoopAfterFolding.size() <= LiveLoopBlocks.size() && + "All blocks that stay in loop should be live!"); + } + + /// We need to preserve static reachibility of all loop exit blocks (this is) + /// required by loop pass manager. In order to do it, we make the following + /// trick: + /// + /// preheader: + /// <preheader code> + /// br label %loop_header + /// + /// loop_header: + /// ... + /// br i1 false, label %dead_exit, label %loop_block + /// ... + /// + /// We cannot simply remove edge from the loop to dead exit because in this + /// case dead_exit (and its successors) may become unreachable. To avoid that, + /// we insert the following fictive preheader: + /// + /// preheader: + /// <preheader code> + /// switch i32 0, label %preheader-split, + /// [i32 1, label %dead_exit_1], + /// [i32 2, label %dead_exit_2], + /// ... + /// [i32 N, label %dead_exit_N], + /// + /// preheader-split: + /// br label %loop_header + /// + /// loop_header: + /// ... + /// br i1 false, label %dead_exit_N, label %loop_block + /// ... + /// + /// Doing so, we preserve static reachibility of all dead exits and can later + /// remove edges from the loop to these blocks. + void handleDeadExits() { + // If no dead exits, nothing to do. + if (DeadExitBlocks.empty()) + return; + + // Construct split preheader and the dummy switch to thread edges from it to + // dead exits. + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + BasicBlock *Preheader = L.getLoopPreheader(); + BasicBlock *NewPreheader = Preheader->splitBasicBlock( + Preheader->getTerminator(), + Twine(Preheader->getName()).concat("-split")); + DTU.deleteEdge(Preheader, L.getHeader()); + DTU.insertEdge(NewPreheader, L.getHeader()); + DTU.insertEdge(Preheader, NewPreheader); + IRBuilder<> Builder(Preheader->getTerminator()); + SwitchInst *DummySwitch = + Builder.CreateSwitch(Builder.getInt32(0), NewPreheader); + Preheader->getTerminator()->eraseFromParent(); + + unsigned DummyIdx = 1; + for (BasicBlock *BB : DeadExitBlocks) { + SmallVector<Instruction *, 4> DeadPhis; + for (auto &PN : BB->phis()) + DeadPhis.push_back(&PN); + + // Eliminate all Phis from dead exits. 
+ for (Instruction *PN : DeadPhis) { + PN->replaceAllUsesWith(UndefValue::get(PN->getType())); + PN->eraseFromParent(); + } + assert(DummyIdx != 0 && "Too many dead exits!"); + DummySwitch->addCase(Builder.getInt32(DummyIdx++), BB); + DTU.insertEdge(Preheader, BB); + ++NumLoopExitsDeleted; + } + + assert(L.getLoopPreheader() == NewPreheader && "Malformed CFG?"); + if (Loop *OuterLoop = LI.getLoopFor(Preheader)) { + OuterLoop->addBasicBlockToLoop(NewPreheader, LI); + + // When we break dead edges, the outer loop may become unreachable from + // the current loop. We need to fix loop info accordingly. For this, we + // find the most nested loop that still contains L and remove L from all + // loops that are inside of it. + Loop *StillReachable = nullptr; + for (BasicBlock *BB : LiveExitBlocks) { + Loop *BBL = LI.getLoopFor(BB); + if (BBL && BBL->contains(L.getHeader())) + if (!StillReachable || + BBL->getLoopDepth() > StillReachable->getLoopDepth()) + StillReachable = BBL; + } + + // Okay, our loop is no longer in the outer loop (and maybe not in some of + // its parents as well). Make the fixup. + if (StillReachable != OuterLoop) { + LI.changeLoopFor(NewPreheader, StillReachable); + for (Loop *NotContaining = OuterLoop; NotContaining != StillReachable; + NotContaining = NotContaining->getParentLoop()) { + NotContaining->removeBlockFromLoop(NewPreheader); + for (auto *BB : L.blocks()) + NotContaining->removeBlockFromLoop(BB); + } + OuterLoop->removeChildLoop(&L); + if (StillReachable) + StillReachable->addChildLoop(&L); + else + LI.addTopLevelLoop(&L); + } + } + } + + /// Delete loop blocks that have become unreachable after folding. Make all + /// relevant updates to DT and LI. + void deleteDeadLoopBlocks() { + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + if (MSSAU) { + SmallPtrSet<BasicBlock *, 8> DeadLoopBlocksSet(DeadLoopBlocks.begin(), + DeadLoopBlocks.end()); + MSSAU->removeBlocks(DeadLoopBlocksSet); + } + for (auto *BB : DeadLoopBlocks) { + assert(BB != L.getHeader() && + "Header of the current loop cannot be dead!"); + LLVM_DEBUG(dbgs() << "Deleting dead loop block " << BB->getName() + << "\n"); + if (LI.isLoopHeader(BB)) { + assert(LI.getLoopFor(BB) != &L && "Attempt to remove current loop!"); + LI.erase(LI.getLoopFor(BB)); + } + LI.removeBlock(BB); + DeleteDeadBlock(BB, &DTU); + ++NumLoopBlocksDeleted; + } + } + + /// Constant-fold terminators of blocks acculumated in FoldCandidates into the + /// unconditional branches. + void foldTerminators() { + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + + for (BasicBlock *BB : FoldCandidates) { + assert(LI.getLoopFor(BB) == &L && "Should be a loop block!"); + BasicBlock *TheOnlySucc = getOnlyLiveSuccessor(BB); + assert(TheOnlySucc && "Should have one live successor!"); + + LLVM_DEBUG(dbgs() << "Replacing terminator of " << BB->getName() + << " with an unconditional branch to the block " + << TheOnlySucc->getName() << "\n"); + + SmallPtrSet<BasicBlock *, 2> DeadSuccessors; + // Remove all BB's successors except for the live one. + unsigned TheOnlySuccDuplicates = 0; + for (auto *Succ : successors(BB)) + if (Succ != TheOnlySucc) { + DeadSuccessors.insert(Succ); + // If our successor lies in a different loop, we don't want to remove + // the one-input Phi because it is a LCSSA Phi. 
+ bool PreserveLCSSAPhi = !L.contains(Succ); + Succ->removePredecessor(BB, PreserveLCSSAPhi); + if (MSSAU) + MSSAU->removeEdge(BB, Succ); + } else + ++TheOnlySuccDuplicates; + + assert(TheOnlySuccDuplicates > 0 && "Should be!"); + // If TheOnlySucc was BB's successor more than once, after transform it + // will be its successor only once. Remove redundant inputs from + // TheOnlySucc's Phis. + bool PreserveLCSSAPhi = !L.contains(TheOnlySucc); + for (unsigned Dup = 1; Dup < TheOnlySuccDuplicates; ++Dup) + TheOnlySucc->removePredecessor(BB, PreserveLCSSAPhi); + if (MSSAU && TheOnlySuccDuplicates > 1) + MSSAU->removeDuplicatePhiEdgesBetween(BB, TheOnlySucc); + + IRBuilder<> Builder(BB->getContext()); + Instruction *Term = BB->getTerminator(); + Builder.SetInsertPoint(Term); + Builder.CreateBr(TheOnlySucc); + Term->eraseFromParent(); + + for (auto *DeadSucc : DeadSuccessors) + DTU.deleteEdge(BB, DeadSucc); + + ++NumTerminatorsFolded; + } + } + +public: + ConstantTerminatorFoldingImpl(Loop &L, LoopInfo &LI, DominatorTree &DT, + ScalarEvolution &SE, + MemorySSAUpdater *MSSAU) + : L(L), LI(LI), DT(DT), SE(SE), MSSAU(MSSAU) {} + bool run() { + assert(L.getLoopLatch() && "Should be single latch!"); + + // Collect all available information about status of blocks after constant + // folding. + analyze(); + + LLVM_DEBUG(dbgs() << "In function " << L.getHeader()->getParent()->getName() + << ": "); + + if (HasIrreducibleCFG) { + LLVM_DEBUG(dbgs() << "Loops with irreducible CFG are not supported!\n"); + return false; + } + + // Nothing to constant-fold. + if (FoldCandidates.empty()) { + LLVM_DEBUG( + dbgs() << "No constant terminator folding candidates found in loop " + << L.getHeader()->getName() << "\n"); + return false; + } + + // TODO: Support deletion of the current loop. + if (DeleteCurrentLoop) { + LLVM_DEBUG( + dbgs() + << "Give up constant terminator folding in loop " + << L.getHeader()->getName() + << ": we don't currently support deletion of the current loop.\n"); + return false; + } + + // TODO: Support blocks that are not dead, but also not in loop after the + // folding. + if (BlocksInLoopAfterFolding.size() + DeadLoopBlocks.size() != + L.getNumBlocks()) { + LLVM_DEBUG( + dbgs() << "Give up constant terminator folding in loop " + << L.getHeader()->getName() + << ": we don't currently" + " support blocks that are not dead, but will stop " + "being a part of the loop after constant-folding.\n"); + return false; + } + + SE.forgetTopmostLoop(&L); + // Dump analysis results. + LLVM_DEBUG(dump()); + + LLVM_DEBUG(dbgs() << "Constant-folding " << FoldCandidates.size() + << " terminators in loop " << L.getHeader()->getName() + << "\n"); + + // Make the actual transforms. + handleDeadExits(); + foldTerminators(); + + if (!DeadLoopBlocks.empty()) { + LLVM_DEBUG(dbgs() << "Deleting " << DeadLoopBlocks.size() + << " dead blocks in loop " << L.getHeader()->getName() + << "\n"); + deleteDeadLoopBlocks(); + } + +#ifndef NDEBUG + // Make sure that we have preserved all data structures after the transform. + DT.verify(); + assert(DT.isReachableFromEntry(L.getHeader())); + LI.verify(DT); +#endif + + return true; + } +}; +} // namespace + +/// Turn branches and switches with known constant conditions into unconditional +/// branches. +static bool constantFoldTerminators(Loop &L, DominatorTree &DT, LoopInfo &LI, + ScalarEvolution &SE, + MemorySSAUpdater *MSSAU) { + if (!EnableTermFolding) + return false; + + // To keep things simple, only process loops with single latch. 
We + // canonicalize most loops to this form. We can support multi-latch if needed. + if (!L.getLoopLatch()) + return false; + + ConstantTerminatorFoldingImpl BranchFolder(L, LI, DT, SE, MSSAU); + return BranchFolder.run(); +} + +static bool mergeBlocksIntoPredecessors(Loop &L, DominatorTree &DT, + LoopInfo &LI, MemorySSAUpdater *MSSAU) { bool Changed = false; + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); // Copy blocks into a temporary array to avoid iterator invalidation issues // as we remove them. SmallVector<WeakTrackingVH, 16> Blocks(L.blocks()); @@ -57,19 +578,38 @@ static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, continue; // Merge Succ into Pred and delete it. - MergeBlockIntoPredecessor(Succ, &DT, &LI); + MergeBlockIntoPredecessor(Succ, &DTU, &LI, MSSAU); - SE.forgetLoop(&L); Changed = true; } return Changed; } +static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI, + ScalarEvolution &SE, MemorySSAUpdater *MSSAU) { + bool Changed = false; + + // Constant-fold terminators with known constant conditions. + Changed |= constantFoldTerminators(L, DT, LI, SE, MSSAU); + + // Eliminate unconditional branches by merging blocks into their predecessors. + Changed |= mergeBlocksIntoPredecessors(L, DT, LI, MSSAU); + + if (Changed) + SE.forgetTopmostLoop(&L); + + return Changed; +} + PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { - if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE)) + Optional<MemorySSAUpdater> MSSAU; + if (EnableMSSALoopDependency && AR.MSSA) + MSSAU = MemorySSAUpdater(AR.MSSA); + if (!simplifyLoopCFG(L, AR.DT, AR.LI, AR.SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) return PreservedAnalyses::all(); return getLoopPassPreservedAnalyses(); @@ -90,10 +630,22 @@ public: DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - return simplifyLoopCFG(*L, DT, LI, SE); + Optional<MemorySSAUpdater> MSSAU; + if (EnableMSSALoopDependency) { + MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSAU = MemorySSAUpdater(MSSA); + if (VerifyMemorySSA) + MSSA->verifyMemorySSA(); + } + return simplifyLoopCFG(*L, DT, LI, SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); } void getAnalysisUsage(AnalysisUsage &AU) const override { + if (EnableMSSALoopDependency) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } AU.addPreserved<DependenceAnalysisWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -104,6 +656,7 @@ char LoopSimplifyCFGLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", "Simplify loop CFG", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", "Simplify loop CFG", false, false) diff --git a/lib/Transforms/Scalar/LoopSink.cpp b/lib/Transforms/Scalar/LoopSink.cpp index 760177c9c5e9..2f7ad2126ed3 100644 --- a/lib/Transforms/Scalar/LoopSink.cpp +++ b/lib/Transforms/Scalar/LoopSink.cpp @@ -152,6 +152,14 @@ findBBsToSinkInto(const Loop &L, const SmallPtrSetImpl<BasicBlock *> &UseBBs, } } + // Can't sink into blocks that have no valid insertion point. 
+ for (BasicBlock *BB : BBsToSinkInto) { + if (BB->getFirstInsertionPt() == BB->end()) { + BBsToSinkInto.clear(); + break; + } + } + // If the total frequency of BBsToSinkInto is larger than preheader frequency, // do not sink. if (adjustedSumFreq(BBsToSinkInto, BFI) > @@ -194,17 +202,22 @@ static bool sinkInstruction(Loop &L, Instruction &I, if (BBsToSinkInto.empty()) return false; + // Return if any of the candidate blocks to sink into is non-cold. + if (BBsToSinkInto.size() > 1) { + for (auto *BB : BBsToSinkInto) + if (!LoopBlockNumber.count(BB)) + return false; + } + // Copy the final BBs into a vector and sort them using the total ordering // of the loop block numbers as iterating the set doesn't give a useful // order. No need to stable sort as the block numbers are a total ordering. SmallVector<BasicBlock *, 2> SortedBBsToSinkInto; SortedBBsToSinkInto.insert(SortedBBsToSinkInto.begin(), BBsToSinkInto.begin(), BBsToSinkInto.end()); - llvm::sort(SortedBBsToSinkInto.begin(), SortedBBsToSinkInto.end(), - [&](BasicBlock *A, BasicBlock *B) { - return LoopBlockNumber.find(A)->second < - LoopBlockNumber.find(B)->second; - }); + llvm::sort(SortedBBsToSinkInto, [&](BasicBlock *A, BasicBlock *B) { + return LoopBlockNumber.find(A)->second < LoopBlockNumber.find(B)->second; + }); BasicBlock *MoveBB = *SortedBBsToSinkInto.begin(); // FIXME: Optimize the efficiency for cloned value replacement. The current @@ -267,6 +280,7 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, // Compute alias set. for (BasicBlock *BB : L.blocks()) CurAST.add(*BB); + CurAST.add(*Preheader); // Sort loop's basic blocks by frequency SmallVector<BasicBlock *, 10> ColdLoopBBs; @@ -290,7 +304,7 @@ static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, // No need to check for instruction's operands are loop invariant. assert(L.hasLoopInvariantOperands(I) && "Insts in a loop's preheader should have loop invariant operands!"); - if (!canSinkOrHoistInst(*I, &AA, &DT, &L, &CurAST, nullptr)) + if (!canSinkOrHoistInst(*I, &AA, &DT, &L, &CurAST, nullptr, false)) continue; if (sinkInstruction(L, *I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI)) Changed = true; diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index fa83b48210bc..773ffb9df0a2 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -155,6 +155,11 @@ static cl::opt<bool> FilterSameScaledReg( cl::desc("Narrow LSR search space by filtering non-optimal formulae" " with the same ScaledReg and Scale")); +static cl::opt<unsigned> ComplexityLimit( + "lsr-complexity-limit", cl::Hidden, + cl::init(std::numeric_limits<uint16_t>::max()), + cl::desc("LSR search space complexity limit")); + #ifndef NDEBUG // Stress test IV chain generation. static cl::opt<bool> StressIVChain( @@ -1487,7 +1492,7 @@ bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const { SmallVector<const SCEV *, 4> Key = F.BaseRegs; if (F.ScaledReg) Key.push_back(F.ScaledReg); // Unstable sort by host order ok, because this is only used for uniquifying. - llvm::sort(Key.begin(), Key.end()); + llvm::sort(Key); return Uniquifier.count(Key); } @@ -1511,7 +1516,7 @@ bool LSRUse::InsertFormula(const Formula &F, const Loop &L) { SmallVector<const SCEV *, 4> Key = F.BaseRegs; if (F.ScaledReg) Key.push_back(F.ScaledReg); // Unstable sort by host order ok, because this is only used for uniquifying. 
- llvm::sort(Key.begin(), Key.end()); + llvm::sort(Key); if (!Uniquifier.insert(Key).second) return false; @@ -3638,32 +3643,62 @@ void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx, void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx, Formula Base) { // This method is only interesting on a plurality of registers. - if (Base.BaseRegs.size() + (Base.Scale == 1) <= 1) + if (Base.BaseRegs.size() + (Base.Scale == 1) + + (Base.UnfoldedOffset != 0) <= 1) return; // Flatten the representation, i.e., reg1 + 1*reg2 => reg1 + reg2, before // processing the formula. Base.unscale(); - Formula F = Base; - F.BaseRegs.clear(); SmallVector<const SCEV *, 4> Ops; + Formula NewBase = Base; + NewBase.BaseRegs.clear(); + Type *CombinedIntegerType = nullptr; for (const SCEV *BaseReg : Base.BaseRegs) { if (SE.properlyDominates(BaseReg, L->getHeader()) && - !SE.hasComputableLoopEvolution(BaseReg, L)) + !SE.hasComputableLoopEvolution(BaseReg, L)) { + if (!CombinedIntegerType) + CombinedIntegerType = SE.getEffectiveSCEVType(BaseReg->getType()); Ops.push_back(BaseReg); + } else - F.BaseRegs.push_back(BaseReg); + NewBase.BaseRegs.push_back(BaseReg); } - if (Ops.size() > 1) { - const SCEV *Sum = SE.getAddExpr(Ops); + + // If no register is relevant, we're done. + if (Ops.size() == 0) + return; + + // Utility function for generating the required variants of the combined + // registers. + auto GenerateFormula = [&](const SCEV *Sum) { + Formula F = NewBase; + // TODO: If Sum is zero, it probably means ScalarEvolution missed an // opportunity to fold something. For now, just ignore such cases // rather than proceed with zero in a register. - if (!Sum->isZero()) { - F.BaseRegs.push_back(Sum); - F.canonicalize(*L); - (void)InsertFormula(LU, LUIdx, F); - } + if (Sum->isZero()) + return; + + F.BaseRegs.push_back(Sum); + F.canonicalize(*L); + (void)InsertFormula(LU, LUIdx, F); + }; + + // If we collected at least two registers, generate a formula combining them. + if (Ops.size() > 1) { + SmallVector<const SCEV *, 4> OpsCopy(Ops); // Don't let SE modify Ops. + GenerateFormula(SE.getAddExpr(OpsCopy)); + } + + // If we have an unfolded offset, generate a formula combining it with the + // registers collected. + if (NewBase.UnfoldedOffset) { + assert(CombinedIntegerType && "Missing a type for the unfolded offset"); + Ops.push_back(SE.getConstant(CombinedIntegerType, NewBase.UnfoldedOffset, + true)); + NewBase.UnfoldedOffset = 0; + GenerateFormula(SE.getAddExpr(Ops)); } } @@ -4238,7 +4273,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { Key.push_back(F.ScaledReg); // Unstable sort by host order ok, because this is only used for // uniquifying. - llvm::sort(Key.begin(), Key.end()); + llvm::sort(Key); std::pair<BestFormulaeTy::const_iterator, bool> P = BestFormulae.insert(std::make_pair(Key, FIdx)); @@ -4281,9 +4316,6 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { }); } -// This is a rough guess that seems to work fairly well. -static const size_t ComplexityLimit = std::numeric_limits<uint16_t>::max(); - /// Estimate the worst-case number of solutions the solver might have to /// consider. It almost never considers this many solutions because it prune the /// search space, but the pruning isn't always sufficient. 
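A toy model of the combination logic in the GenerateCombinations hunk above, with plain integers standing in for SCEV expressions: it produces the same two kinds of candidates, one that merges the loop-invariant registers into a single combined register, and one that additionally folds the previously unfolded constant offset into that register. ToyFormula and the function name are invented for the sketch.

#include <numeric>
#include <vector>

struct ToyFormula {
  std::vector<long> BaseRegs;
  long UnfoldedOffset = 0;
};

std::vector<ToyFormula> toyGenerateCombinations(
    const std::vector<long> &VariantRegs,
    const std::vector<long> &InvariantRegs, long UnfoldedOffset) {
  std::vector<ToyFormula> Out;
  if (InvariantRegs.empty())           // nothing to combine
    return Out;

  auto Emit = [&](long Sum, long RemainingOffset) {
    if (Sum == 0)
      return;                          // don't put a zero in a register
    ToyFormula F;
    F.BaseRegs = VariantRegs;
    F.BaseRegs.push_back(Sum);         // the single combined register
    F.UnfoldedOffset = RemainingOffset;
    Out.push_back(F);
  };

  long Sum = std::accumulate(InvariantRegs.begin(), InvariantRegs.end(), 0L);
  if (InvariantRegs.size() > 1)
    Emit(Sum, UnfoldedOffset);         // combine the registers only
  if (UnfoldedOffset != 0)
    Emit(Sum + UnfoldedOffset, 0);     // also fold the offset into the register
  return Out;
}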
diff --git a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp index 86c99aed4417..da46210b6fdd 100644 --- a/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp @@ -56,6 +56,20 @@ using namespace llvm; #define DEBUG_TYPE "loop-unroll-and-jam" +/// @{ +/// Metadata attribute names +static const char *const LLVMLoopUnrollAndJamFollowupAll = + "llvm.loop.unroll_and_jam.followup_all"; +static const char *const LLVMLoopUnrollAndJamFollowupInner = + "llvm.loop.unroll_and_jam.followup_inner"; +static const char *const LLVMLoopUnrollAndJamFollowupOuter = + "llvm.loop.unroll_and_jam.followup_outer"; +static const char *const LLVMLoopUnrollAndJamFollowupRemainderInner = + "llvm.loop.unroll_and_jam.followup_remainder_inner"; +static const char *const LLVMLoopUnrollAndJamFollowupRemainderOuter = + "llvm.loop.unroll_and_jam.followup_remainder_outer"; +/// @} + static cl::opt<bool> AllowUnrollAndJam("allow-unroll-and-jam", cl::Hidden, cl::desc("Allows loops to be unroll-and-jammed.")); @@ -112,11 +126,6 @@ static bool HasUnrollAndJamEnablePragma(const Loop *L) { return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.enable"); } -// Returns true if the loop has an unroll_and_jam(disable) pragma. -static bool HasUnrollAndJamDisablePragma(const Loop *L) { - return GetUnrollMetadataForLoop(L, "llvm.loop.unroll_and_jam.disable"); -} - // If loop has an unroll_and_jam_count pragma return the (necessarily // positive) value from the pragma. Otherwise return 0. static unsigned UnrollAndJamCountPragmaValue(const Loop *L) { @@ -149,7 +158,26 @@ static bool computeUnrollAndJamCount( OptimizationRemarkEmitter *ORE, unsigned OuterTripCount, unsigned OuterTripMultiple, unsigned OuterLoopSize, unsigned InnerTripCount, unsigned InnerLoopSize, TargetTransformInfo::UnrollingPreferences &UP) { - // Check for explicit Count from the "unroll-and-jam-count" option. + // First up use computeUnrollCount from the loop unroller to get a count + // for unrolling the outer loop, plus any loops requiring explicit + // unrolling we leave to the unroller. This uses UP.Threshold / + // UP.PartialThreshold / UP.MaxCount to come up with sensible loop values. + // We have already checked that the loop has no unroll.* pragmas. + unsigned MaxTripCount = 0; + bool UseUpperBound = false; + bool ExplicitUnroll = computeUnrollCount( + L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, + OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); + if (ExplicitUnroll || UseUpperBound) { + // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it + // for the unroller instead. + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; explicit count set by " + "computeUnrollCount\n"); + UP.Count = 0; + return false; + } + + // Override with any explicit Count from the "unroll-and-jam-count" option. bool UserUnrollCount = UnrollAndJamCount.getNumOccurrences() > 0; if (UserUnrollCount) { UP.Count = UnrollAndJamCount; @@ -174,80 +202,76 @@ static bool computeUnrollAndJamCount( return true; } - // Use computeUnrollCount from the loop unroller to get a sensible count - // for the unrolling the outer loop. This uses UP.Threshold / - // UP.PartialThreshold / UP.MaxCount to come up with sensible loop values. - // We have already checked that the loop has no unroll.* pragmas. 
- unsigned MaxTripCount = 0; - bool UseUpperBound = false; - bool ExplicitUnroll = computeUnrollCount( - L, TTI, DT, LI, SE, EphValues, ORE, OuterTripCount, MaxTripCount, - OuterTripMultiple, OuterLoopSize, UP, UseUpperBound); - if (ExplicitUnroll || UseUpperBound) { - // If the user explicitly set the loop as unrolled, dont UnJ it. Leave it - // for the unroller instead. - UP.Count = 0; - return false; - } - bool PragmaEnableUnroll = HasUnrollAndJamEnablePragma(L); - ExplicitUnroll = PragmaCount > 0 || PragmaEnableUnroll || UserUnrollCount; + bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount; + bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount; // If the loop has an unrolling pragma, we want to be more aggressive with // unrolling limits. - if (ExplicitUnroll && OuterTripCount != 0) + if (ExplicitUnrollAndJam) UP.UnrollAndJamInnerLoopThreshold = PragmaUnrollAndJamThreshold; if (!UP.AllowRemainder && getUnrollAndJammedLoopSize(InnerLoopSize, UP) >= UP.UnrollAndJamInnerLoopThreshold) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't create remainder and " + "inner loop too large\n"); UP.Count = 0; return false; } + // We have a sensible limit for the outer loop, now adjust it for the inner + // loop and UP.UnrollAndJamInnerLoopThreshold. If the outer limit was set + // explicitly, we want to stick to it. + if (!ExplicitUnrollAndJamCount && UP.AllowRemainder) { + while (UP.Count != 0 && getUnrollAndJammedLoopSize(InnerLoopSize, UP) >= + UP.UnrollAndJamInnerLoopThreshold) + UP.Count--; + } + + // If we are explicitly unroll and jamming, we are done. Otherwise there are a + // number of extra performance heuristics to check. + if (ExplicitUnrollAndJam) + return true; + // If the inner loop count is known and small, leave the entire loop nest to // be the unroller - if (!ExplicitUnroll && InnerTripCount && - InnerLoopSize * InnerTripCount < UP.Threshold) { + if (InnerTripCount && InnerLoopSize * InnerTripCount < UP.Threshold) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; small inner loop count is " + "being left for the unroller\n"); UP.Count = 0; return false; } - // We have a sensible limit for the outer loop, now adjust it for the inner - // loop and UP.UnrollAndJamInnerLoopThreshold. - while (UP.Count != 0 && UP.AllowRemainder && - getUnrollAndJammedLoopSize(InnerLoopSize, UP) >= - UP.UnrollAndJamInnerLoopThreshold) - UP.Count--; - - if (!ExplicitUnroll) { - // Check for situations where UnJ is likely to be unprofitable. Including - // subloops with more than 1 block. - if (SubLoop->getBlocks().size() != 1) { - UP.Count = 0; - return false; - } + // Check for situations where UnJ is likely to be unprofitable. Including + // subloops with more than 1 block. + if (SubLoop->getBlocks().size() != 1) { + LLVM_DEBUG( + dbgs() << "Won't unroll-and-jam; More than one inner loop block\n"); + UP.Count = 0; + return false; + } - // Limit to loops where there is something to gain from unrolling and - // jamming the loop. In this case, look for loads that are invariant in the - // outer loop and can become shared. - unsigned NumInvariant = 0; - for (BasicBlock *BB : SubLoop->getBlocks()) { - for (Instruction &I : *BB) { - if (auto *Ld = dyn_cast<LoadInst>(&I)) { - Value *V = Ld->getPointerOperand(); - const SCEV *LSCEV = SE.getSCEVAtScope(V, L); - if (SE.isLoopInvariant(LSCEV, L)) - NumInvariant++; - } + // Limit to loops where there is something to gain from unrolling and + // jamming the loop. 
In this case, look for loads that are invariant in the + // outer loop and can become shared. + unsigned NumInvariant = 0; + for (BasicBlock *BB : SubLoop->getBlocks()) { + for (Instruction &I : *BB) { + if (auto *Ld = dyn_cast<LoadInst>(&I)) { + Value *V = Ld->getPointerOperand(); + const SCEV *LSCEV = SE.getSCEVAtScope(V, L); + if (SE.isLoopInvariant(LSCEV, L)) + NumInvariant++; } } - if (NumInvariant == 0) { - UP.Count = 0; - return false; - } + } + if (NumInvariant == 0) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; No loop invariant loads\n"); + UP.Count = 0; + return false; } - return ExplicitUnroll; + return false; } static LoopUnrollResult @@ -284,13 +308,16 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, << L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); + TransformationMode EnableMode = hasUnrollAndJamTransformation(L); + if (EnableMode & TM_Disable) + return LoopUnrollResult::Unmodified; + // A loop with any unroll pragma (enabling/disabling/count/etc) is left for // the unroller, so long as it does not explicitly have unroll_and_jam // metadata. This means #pragma nounroll will disable unroll and jam as well // as unrolling - if (HasUnrollAndJamDisablePragma(L) || - (HasAnyUnrollPragma(L, "llvm.loop.unroll.") && - !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam."))) { + if (HasAnyUnrollPragma(L, "llvm.loop.unroll.") && + !HasAnyUnrollPragma(L, "llvm.loop.unroll_and_jam.")) { LLVM_DEBUG(dbgs() << " Disabled due to pragma.\n"); return LoopUnrollResult::Unmodified; } @@ -329,6 +356,19 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, return LoopUnrollResult::Unmodified; } + // Save original loop IDs for after the transformation. + MDNode *OrigOuterLoopID = L->getLoopID(); + MDNode *OrigSubLoopID = SubLoop->getLoopID(); + + // To assign the loop id of the epilogue, assign it before unrolling it so it + // is applied to every inner loop of the epilogue. We later apply the loop ID + // for the jammed inner loop. + Optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID( + OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, + LLVMLoopUnrollAndJamFollowupRemainderInner}); + if (NewInnerEpilogueLoopID.hasValue()) + SubLoop->setLoopID(NewInnerEpilogueLoopID.getValue()); + // Find trip count and trip multiple unsigned OuterTripCount = SE.getSmallConstantTripCount(L, Latch); unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, Latch); @@ -344,9 +384,39 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, if (OuterTripCount && UP.Count > OuterTripCount) UP.Count = OuterTripCount; - LoopUnrollResult UnrollResult = - UnrollAndJamLoop(L, UP.Count, OuterTripCount, OuterTripMultiple, - UP.UnrollRemainder, LI, &SE, &DT, &AC, &ORE); + Loop *EpilogueOuterLoop = nullptr; + LoopUnrollResult UnrollResult = UnrollAndJamLoop( + L, UP.Count, OuterTripCount, OuterTripMultiple, UP.UnrollRemainder, LI, + &SE, &DT, &AC, &ORE, &EpilogueOuterLoop); + + // Assign new loop attributes. 
+ if (EpilogueOuterLoop) { + Optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID( + OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, + LLVMLoopUnrollAndJamFollowupRemainderOuter}); + if (NewOuterEpilogueLoopID.hasValue()) + EpilogueOuterLoop->setLoopID(NewOuterEpilogueLoopID.getValue()); + } + + Optional<MDNode *> NewInnerLoopID = + makeFollowupLoopID(OrigOuterLoopID, {LLVMLoopUnrollAndJamFollowupAll, + LLVMLoopUnrollAndJamFollowupInner}); + if (NewInnerLoopID.hasValue()) + SubLoop->setLoopID(NewInnerLoopID.getValue()); + else + SubLoop->setLoopID(OrigSubLoopID); + + if (UnrollResult == LoopUnrollResult::PartiallyUnrolled) { + Optional<MDNode *> NewOuterLoopID = makeFollowupLoopID( + OrigOuterLoopID, + {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupOuter}); + if (NewOuterLoopID.hasValue()) { + L->setLoopID(NewOuterLoopID.getValue()); + + // Do not setLoopAlreadyUnrolled if a followup was given. + return UnrollResult; + } + } // If loop has an unroll count pragma or unrolled by explicitly set count // mark loop as unrolled to prevent unrolling beyond that requested. diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index e955821effa0..38b80f48ed0e 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -540,7 +540,7 @@ static Optional<EstimatedUnrollCost> analyzeLoopUnrollCost( } } - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); // Add in the live successors by first checking whether we have a terminator // that may be simplified based on the values simplified by this call. @@ -661,11 +661,6 @@ static bool HasUnrollEnablePragma(const Loop *L) { return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.enable"); } -// Returns true if the loop has an unroll(disable) pragma. -static bool HasUnrollDisablePragma(const Loop *L) { - return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.disable"); -} - // Returns true if the loop has a runtime unroll(disable) pragma. static bool HasRuntimeUnrollDisablePragma(const Loop *L) { return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable"); @@ -713,12 +708,19 @@ static uint64_t getUnrolledLoopSize( // Returns true if unroll count was set explicitly. // Calculates unroll count and writes it to UP.Count. +// Unless IgnoreUser is true, will also use metadata and command-line options +// that are specific to the LoopUnroll pass (which, for instance, are +// irrelevant for the LoopUnrollAndJam pass). +// FIXME: This function is used by LoopUnroll and LoopUnrollAndJam, but consumes +// many LoopUnroll-specific options. The shared functionality should be +// refactored into its own function. bool llvm::computeUnrollCount( Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const SmallPtrSetImpl<const Value *> &EphValues, OptimizationRemarkEmitter *ORE, unsigned &TripCount, unsigned MaxTripCount, unsigned &TripMultiple, unsigned LoopSize, TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) { + // Check for explicit Count. // 1st priority is unroll count set by "unroll-count" option. bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0; @@ -801,7 +803,7 @@ bool llvm::computeUnrollCount( } } - // 4th priority is loop peeling + // 4th priority is loop peeling.
computePeelCount(L, LoopSize, UP, TripCount, SE); if (UP.PeelCount) { UP.Runtime = false; @@ -963,13 +965,15 @@ static LoopUnrollResult tryToUnrollLoop( Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, const TargetTransformInfo &TTI, AssumptionCache &AC, OptimizationRemarkEmitter &ORE, bool PreserveLCSSA, int OptLevel, - Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, - Optional<bool> ProvidedAllowPartial, Optional<bool> ProvidedRuntime, - Optional<bool> ProvidedUpperBound, Optional<bool> ProvidedAllowPeeling) { + bool OnlyWhenForced, Optional<unsigned> ProvidedCount, + Optional<unsigned> ProvidedThreshold, Optional<bool> ProvidedAllowPartial, + Optional<bool> ProvidedRuntime, Optional<bool> ProvidedUpperBound, + Optional<bool> ProvidedAllowPeeling) { LLVM_DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); - if (HasUnrollDisablePragma(L)) + TransformationMode TM = hasUnrollTransformation(L); + if (TM & TM_Disable) return LoopUnrollResult::Unmodified; if (!L->isLoopSimplifyForm()) { LLVM_DEBUG( @@ -977,6 +981,11 @@ static LoopUnrollResult tryToUnrollLoop( return LoopUnrollResult::Unmodified; } + // When automatic unrolling is disabled, do not unroll unless overridden for + // this loop. + if (OnlyWhenForced && !(TM & TM_Enable)) + return LoopUnrollResult::Unmodified; + unsigned NumInlineCandidates; bool NotDuplicatable; bool Convergent; @@ -1066,14 +1075,39 @@ static LoopUnrollResult tryToUnrollLoop( if (TripCount && UP.Count > TripCount) UP.Count = TripCount; + // Save loop properties before it is transformed. + MDNode *OrigLoopID = L->getLoopID(); + // Unroll the loop. + Loop *RemainderLoop = nullptr; LoopUnrollResult UnrollResult = UnrollLoop( L, UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder, - LI, &SE, &DT, &AC, &ORE, PreserveLCSSA); + LI, &SE, &DT, &AC, &ORE, PreserveLCSSA, &RemainderLoop); if (UnrollResult == LoopUnrollResult::Unmodified) return LoopUnrollResult::Unmodified; + if (RemainderLoop) { + Optional<MDNode *> RemainderLoopID = + makeFollowupLoopID(OrigLoopID, {LLVMLoopUnrollFollowupAll, + LLVMLoopUnrollFollowupRemainder}); + if (RemainderLoopID.hasValue()) + RemainderLoop->setLoopID(RemainderLoopID.getValue()); + } + + if (UnrollResult != LoopUnrollResult::FullyUnrolled) { + Optional<MDNode *> NewLoopID = + makeFollowupLoopID(OrigLoopID, {LLVMLoopUnrollFollowupAll, + LLVMLoopUnrollFollowupUnrolled}); + if (NewLoopID.hasValue()) { + L->setLoopID(NewLoopID.getValue()); + + // Do not setLoopAlreadyUnrolled if loop attributes have been specified + // explicitly. + return UnrollResult; + } + } + // If loop has an unroll count pragma or unrolled by explicitly set count // mark loop as unrolled to prevent unrolling beyond that requested. // If the loop was peeled, we already "used up" the profile information @@ -1092,6 +1126,12 @@ public: static char ID; // Pass ID, replacement for typeid int OptLevel; + + /// If false, use a cost model to determine whether unrolling of a loop is + /// profitable. If true, only loops that explicitly request unrolling via + /// metadata are considered. All other loops are skipped.
+ bool OnlyWhenForced; + Optional<unsigned> ProvidedCount; Optional<unsigned> ProvidedThreshold; Optional<bool> ProvidedAllowPartial; @@ -1099,15 +1139,16 @@ public: Optional<bool> ProvidedUpperBound; Optional<bool> ProvidedAllowPeeling; - LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None, + LoopUnroll(int OptLevel = 2, bool OnlyWhenForced = false, + Optional<unsigned> Threshold = None, Optional<unsigned> Count = None, Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, Optional<bool> UpperBound = None, Optional<bool> AllowPeeling = None) - : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)), - ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), - ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound), - ProvidedAllowPeeling(AllowPeeling) { + : LoopPass(ID), OptLevel(OptLevel), OnlyWhenForced(OnlyWhenForced), + ProvidedCount(std::move(Count)), ProvidedThreshold(Threshold), + ProvidedAllowPartial(AllowPartial), ProvidedRuntime(Runtime), + ProvidedUpperBound(UpperBound), ProvidedAllowPeeling(AllowPeeling) { initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); } @@ -1130,8 +1171,8 @@ public: bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); LoopUnrollResult Result = tryToUnrollLoop( - L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, ProvidedCount, - ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, + L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, OnlyWhenForced, + ProvidedCount, ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, ProvidedAllowPeeling); if (Result == LoopUnrollResult::FullyUnrolled) @@ -1161,14 +1202,16 @@ INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) -Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, - int AllowPartial, int Runtime, int UpperBound, +Pass *llvm::createLoopUnrollPass(int OptLevel, bool OnlyWhenForced, + int Threshold, int Count, int AllowPartial, + int Runtime, int UpperBound, int AllowPeeling) { // TODO: It would make more sense for this function to take the optionals // directly, but that's dangerous since it would silently break out of tree // callers. return new LoopUnroll( - OptLevel, Threshold == -1 ? None : Optional<unsigned>(Threshold), + OptLevel, OnlyWhenForced, + Threshold == -1 ? None : Optional<unsigned>(Threshold), Count == -1 ? None : Optional<unsigned>(Count), AllowPartial == -1 ? None : Optional<bool>(AllowPartial), Runtime == -1 ? None : Optional<bool>(Runtime), @@ -1176,8 +1219,8 @@ Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, AllowPeeling == -1 ? 
None : Optional<bool>(AllowPeeling)); } -Pass *llvm::createSimpleLoopUnrollPass(int OptLevel) { - return createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0, 0); +Pass *llvm::createSimpleLoopUnrollPass(int OptLevel, bool OnlyWhenForced) { + return createLoopUnrollPass(OptLevel, OnlyWhenForced, -1, -1, 0, 0, 0, 0); } PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, @@ -1207,7 +1250,8 @@ PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, bool Changed = tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, - /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, + /*PreserveLCSSA*/ true, OptLevel, OnlyWhenForced, + /*Count*/ None, /*Threshold*/ None, /*AllowPartial*/ false, /*Runtime*/ false, /*UpperBound*/ false, /*AllowPeeling*/ false) != LoopUnrollResult::Unmodified; @@ -1333,23 +1377,21 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, Loop *ParentL = L.getParentLoop(); #endif - // The API here is quite complex to call, but there are only two interesting - // states we support: partial and full (or "simple") unrolling. However, to - // enable these things we actually pass "None" in for the optional to avoid - // providing an explicit choice. - Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam, - AllowPeeling; // Check if the profile summary indicates that the profiled application // has a huge working set size, in which case we disable peeling to avoid // bloating it further. + Optional<bool> LocalAllowPeeling = UnrollOpts.AllowPeeling; if (PSI && PSI->hasHugeWorkingSetSize()) - AllowPeeling = false; + LocalAllowPeeling = false; std::string LoopName = L.getName(); - LoopUnrollResult Result = - tryToUnrollLoop(&L, DT, &LI, SE, TTI, AC, ORE, - /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, - /*Threshold*/ None, AllowPartialParam, RuntimeParam, - UpperBoundParam, AllowPeeling); + // The API here is quite complex to call and we allow to select some + // flavors of unrolling during construction time (by setting UnrollOpts). + LoopUnrollResult Result = tryToUnrollLoop( + &L, DT, &LI, SE, TTI, AC, ORE, + /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, UnrollOpts.OnlyWhenForced, + /*Count*/ None, + /*Threshold*/ None, UnrollOpts.AllowPartial, UnrollOpts.AllowRuntime, + UnrollOpts.AllowUpperBound, LocalAllowPeeling); Changed |= Result != LoopUnrollResult::Unmodified; // The parent must not be damaged by unrolling! 
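Both factory functions above now thread the OnlyWhenForced flag through to tryToUnrollLoop, so a pipeline can restrict unrolling to loops that explicitly ask for it via transformation metadata. A minimal sketch against the legacy pass manager, assuming only the signatures shown in this patch (the surrounding setup function is hypothetical):

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/Scalar.h"

// Add loop unrolling so that only loops carrying explicit llvm.loop.unroll.*
// metadata (TM_Enable) are transformed; the cost model is skipped otherwise.
static void addForcedOnlyUnrolling(llvm::legacy::PassManagerBase &PM,
                                   int OptLevel) {
  PM.add(llvm::createSimpleLoopUnrollPass(OptLevel, /*OnlyWhenForced=*/true));
}

Under the new pass manager the same switch is reached through the UnrollOpts member consulted in LoopUnrollPass::run() above.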
diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index 6aad077ff19e..4a089dfa7dbf 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -28,18 +28,19 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LegacyDivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" @@ -65,8 +66,10 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> @@ -180,11 +183,13 @@ namespace { Loop *currentLoop = nullptr; DominatorTree *DT = nullptr; + MemorySSA *MSSA = nullptr; + std::unique_ptr<MemorySSAUpdater> MSSAU; BasicBlock *loopHeader = nullptr; BasicBlock *loopPreheader = nullptr; bool SanitizeMemory; - LoopSafetyInfo SafetyInfo; + SimpleLoopSafetyInfo SafetyInfo; // LoopBlocks contains all of the basic blocks of the loop, including the // preheader of the loop, the body of the loop, and the exit blocks of the @@ -214,8 +219,12 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + if (EnableMSSALoopDependency) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } if (hasBranchDivergence) - AU.addRequired<DivergenceAnalysis>(); + AU.addRequired<LegacyDivergenceAnalysis>(); getLoopAnalysisUsage(AU); } @@ -237,11 +246,11 @@ namespace { bool TryTrivialLoopUnswitch(bool &Changed); bool UnswitchIfProfitable(Value *LoopCond, Constant *Val, - TerminatorInst *TI = nullptr); + Instruction *TI = nullptr); void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, - BasicBlock *ExitBlock, TerminatorInst *TI); + BasicBlock *ExitBlock, Instruction *TI); void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L, - TerminatorInst *TI); + Instruction *TI); void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, Constant *Val, bool isEqual); @@ -249,8 +258,7 @@ namespace { void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, - BranchInst *OldBranch, - TerminatorInst *TI); + BranchInst *OldBranch, Instruction *TI); void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); @@ -383,7 +391,8 @@ INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) 
+INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) @@ -515,20 +524,33 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); LPM = &LPM_Ref; DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + if (EnableMSSALoopDependency) { + MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSAU = make_unique<MemorySSAUpdater>(MSSA); + assert(DT && "Cannot update MemorySSA without a valid DomTree."); + } currentLoop = L; Function *F = currentLoop->getHeader()->getParent(); SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory); if (SanitizeMemory) - computeLoopSafetyInfo(&SafetyInfo, L); + SafetyInfo.computeLoopSafetyInfo(L); + + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); bool Changed = false; do { assert(currentLoop->isLCSSAForm(*DT)); + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); redoLoop = false; Changed |= processCurrentLoop(); } while(redoLoop); + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); + return Changed; } @@ -690,7 +712,7 @@ bool LoopUnswitch::processCurrentLoop() { // loop. for (Loop::block_iterator I = currentLoop->block_begin(), E = currentLoop->block_end(); I != E; ++I) { - TerminatorInst *TI = (*I)->getTerminator(); + Instruction *TI = (*I)->getTerminator(); // Unswitching on a potentially uninitialized predicate is not // MSan-friendly. Limit this to the cases when the original predicate is @@ -699,7 +721,7 @@ bool LoopUnswitch::processCurrentLoop() { // This is a workaround for the discrepancy between LLVM IR and MSan // semantics. See PR28054 for more details. if (SanitizeMemory && - !isGuaranteedToExecute(*TI, DT, currentLoop, &SafetyInfo)) + !SafetyInfo.isGuaranteedToExecute(*TI, DT, currentLoop)) continue; if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { @@ -853,7 +875,7 @@ static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) { /// simplify the loop. If we decide that this is profitable, /// unswitch the loop, reprocess the pieces, then return true. bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, - TerminatorInst *TI) { + Instruction *TI) { // Check to see if it would be profitable to unswitch current loop. if (!BranchesInfo.CostAllowsUnswitching()) { LLVM_DEBUG(dbgs() << "NOT unswitching loop %" @@ -864,7 +886,7 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, return false; } if (hasBranchDivergence && - getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) { + getAnalysis<LegacyDivergenceAnalysis>().isDivergent(LoopCond)) { LLVM_DEBUG(dbgs() << "NOT unswitching loop %" << currentLoop->getHeader()->getName() << " at non-trivial condition '" << *Val @@ -908,7 +930,7 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, BranchInst *OldBranch, - TerminatorInst *TI) { + Instruction *TI) { assert(OldBranch->isUnconditional() && "Preheader is not split correctly"); assert(TrueDest != FalseDest && "Branch targets should be different"); // Insert a conditional branch on LIC to the two preheaders. 
The original @@ -952,13 +974,16 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, if (OldBranchSucc != TrueDest && OldBranchSucc != FalseDest) { Updates.push_back({DominatorTree::Delete, OldBranchParent, OldBranchSucc}); } - DT->applyUpdates(Updates); + + if (MSSAU) + MSSAU->applyUpdates(Updates, *DT); } // If either edge is critical, split it. This helps preserve LoopSimplify // form for enclosing loops. - auto Options = CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA(); + auto Options = + CriticalEdgeSplittingOptions(DT, LI, MSSAU.get()).setPreserveLCSSA(); SplitCriticalEdge(BI, 0, Options); SplitCriticalEdge(BI, 1, Options); } @@ -970,7 +995,7 @@ void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, /// outside of the loop and updating loop info. void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, BasicBlock *ExitBlock, - TerminatorInst *TI) { + Instruction *TI) { LLVM_DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" << loopHeader->getName() << " [" << L->getBlocks().size() << " blocks] in Function " @@ -984,7 +1009,7 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, // First step, split the preheader, so that we know that there is a safe place // to insert the conditional branch. We will change loopPreheader to have a // conditional branch on Cond. - BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI); + BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get()); // Now that we have a place to insert the conditional branch, create a place // to branch to: this is the exit block out of the loop that we should @@ -995,7 +1020,8 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, // without actually branching to it (the exit block should be dominated by the // loop header, not the preheader). assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); - BasicBlock *NewExit = SplitBlock(ExitBlock, &ExitBlock->front(), DT, LI); + BasicBlock *NewExit = + SplitBlock(ExitBlock, &ExitBlock->front(), DT, LI, MSSAU.get()); // Okay, now we have a position to branch from and a position to branch to, // insert the new conditional branch. @@ -1015,6 +1041,7 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, // particular value, rewrite the loop with this info. We know that this will // at least eliminate the old branch. RewriteLoopBodyWithConditionConstant(L, Cond, Val, false); + ++NumTrivial; } @@ -1026,7 +1053,7 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, /// condition. bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { BasicBlock *CurrentBB = currentLoop->getHeader(); - TerminatorInst *CurrentTerm = CurrentBB->getTerminator(); + Instruction *CurrentTerm = CurrentBB->getTerminator(); LLVMContext &Context = CurrentBB->getContext(); // If loop header has only one reachable successor (currently via an @@ -1190,7 +1217,7 @@ void LoopUnswitch::SplitExitEdges(Loop *L, // Although SplitBlockPredecessors doesn't preserve loop-simplify in // general, if we call it on all predecessors of all exits then it does. - SplitBlockPredecessors(ExitBlock, Preds, ".us-lcssa", DT, LI, + SplitBlockPredecessors(ExitBlock, Preds, ".us-lcssa", DT, LI, MSSAU.get(), /*PreserveLCSSA*/ true); } } @@ -1199,7 +1226,7 @@ void LoopUnswitch::SplitExitEdges(Loop *L, /// Split it into loop versions and test the condition outside of either loop. 
/// Return the loops created as Out1/Out2. void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, - Loop *L, TerminatorInst *TI) { + Loop *L, Instruction *TI) { Function *F = loopHeader->getParent(); LLVM_DEBUG(dbgs() << "loop-unswitch: Unswitching loop %" << loopHeader->getName() << " [" << L->getBlocks().size() @@ -1216,7 +1243,8 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, // First step, split the preheader and exit blocks, and add these blocks to // the LoopBlocks list. - BasicBlock *NewPreheader = SplitEdge(loopPreheader, loopHeader, DT, LI); + BasicBlock *NewPreheader = + SplitEdge(loopPreheader, loopHeader, DT, LI, MSSAU.get()); LoopBlocks.push_back(NewPreheader); // We want the loop to come after the preheader, but before the exit blocks. @@ -1318,10 +1346,24 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] && "Preheader splitting did not work correctly!"); + if (MSSAU) { + // Update MemorySSA after cloning, and before splitting to unreachables, + // since that invalidates the 1:1 mapping of clones in VMap. + LoopBlocksRPO LBRPO(L); + LBRPO.perform(LI); + MSSAU->updateForClonedLoop(LBRPO, ExitBlocks, VMap); + } + // Emit the new branch that selects between the two versions of this loop. EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR, TI); LPM->deleteSimpleAnalysisValue(OldBR, L); + if (MSSAU) { + // Update MemoryPhis in Exit blocks. + MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMap, *DT); + if (VerifyMemorySSA) + MSSA->verifyMemorySSA(); + } // The OldBr was replaced by a new one and removed (but not erased) by // EmitPreheaderBranchOnCondition. It is no longer needed, so delete it. @@ -1347,6 +1389,9 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop && LICHandle && !isa<Constant>(LICHandle)) RewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, true); + + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); } /// Remove all instances of I from the worklist vector specified. @@ -1485,7 +1530,7 @@ void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC, // and hooked up so as to preserve the loop structure, because // trying to update it is complicated. So instead we preserve the // loop structure and put the block on a dead code path. - SplitEdge(Switch, SISucc, DT, LI); + SplitEdge(Switch, SISucc, DT, LI, MSSAU.get()); // Compute the successors instead of relying on the return value // of SplitEdge, since it may have split the switch successor // after PHI nodes. @@ -1539,6 +1584,8 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { Worklist.push_back(Use); LPM->deleteSimpleAnalysisValue(I, L); RemoveFromWorklist(I, Worklist); + if (MSSAU) + MSSAU->removeMemoryAccess(I); I->eraseFromParent(); ++NumSimplify; continue; @@ -1578,6 +1625,8 @@ void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) { // Move all of the successor contents from Succ to Pred. 
Pred->getInstList().splice(BI->getIterator(), Succ->getInstList(), Succ->begin(), Succ->end()); + if (MSSAU) + MSSAU->moveAllAfterMergeBlocks(Succ, Pred, BI); LPM->deleteSimpleAnalysisValue(BI, L); RemoveFromWorklist(BI, Worklist); BI->eraseFromParent(); diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp index 06e86081e8a0..83861b98fbd8 100644 --- a/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -360,10 +360,11 @@ bool LoopVersioningLICM::legalLoopMemoryAccesses() { bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) { assert(I != nullptr && "Null instruction found!"); // Check function call safety - if (isa<CallInst>(I) && !AA->doesNotAccessMemory(CallSite(I))) { - LLVM_DEBUG(dbgs() << " Unsafe call site found.\n"); - return false; - } + if (auto *Call = dyn_cast<CallBase>(I)) + if (!AA->doesNotAccessMemory(Call)) { + LLVM_DEBUG(dbgs() << " Unsafe call site found.\n"); + return false; + } // Avoid loops with possiblity of throw if (I->mayThrow()) { LLVM_DEBUG(dbgs() << " May throw instruction found in loop body\n"); @@ -594,6 +595,11 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { if (skipLoop(L)) return false; + + // Do not do the transformation if disabled by metadata. + if (hasLICMVersioningTransformation(L) & TM_Disable) + return false; + // Get Analysis information. AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); @@ -628,6 +634,8 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { // Set Loop Versioning metaData for version loop. addStringMetadataToLoop(LVer.getVersionedLoop(), LICMVersioningMetaData); // Set "llvm.mem.parallel_loop_access" metaData to versioned loop. + // FIXME: "llvm.mem.parallel_loop_access" annotates memory access + // instructions, not loops. addStringMetadataToLoop(LVer.getVersionedLoop(), "llvm.mem.parallel_loop_access"); // Update version loop with aggressive aliasing assumption. 
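The call-safety check above now inspects CallBase directly instead of wrapping the instruction in a CallSite. A minimal self-contained sketch of that test, with hypothetical function and variable names (only dyn_cast<CallBase> and doesNotAccessMemory are taken from the patch):

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/Casting.h"

// Mirrors the "Unsafe call site found" bailout: any call or invoke that may
// read or write memory makes the loop ineligible for versioning.
static bool callSafeForVersioning(const llvm::Instruction &I,
                                  llvm::AAResults &AA) {
  if (const auto *Call = llvm::dyn_cast<llvm::CallBase>(&I))
    return AA.doesNotAccessMemory(Call); // readnone calls are fine
  return true; // not a call at all
}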
diff --git a/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp index 070114a84cc5..4867b33d671f 100644 --- a/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp +++ b/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -15,25 +15,19 @@ #include "llvm/Transforms/Scalar/LowerGuardIntrinsic.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/GuardUtils.h" using namespace llvm; -static cl::opt<uint32_t> PredicatePassBranchWeight( - "guards-predicate-pass-branch-weight", cl::Hidden, cl::init(1 << 20), - cl::desc("The probability of a guard failing is assumed to be the " - "reciprocal of this value (default = 1 << 20)")); - namespace { struct LowerGuardIntrinsicLegacyPass : public FunctionPass { static char ID; @@ -46,45 +40,6 @@ struct LowerGuardIntrinsicLegacyPass : public FunctionPass { }; } -static void MakeGuardControlFlowExplicit(Function *DeoptIntrinsic, - CallInst *CI) { - OperandBundleDef DeoptOB(*CI->getOperandBundle(LLVMContext::OB_deopt)); - SmallVector<Value *, 4> Args(std::next(CI->arg_begin()), CI->arg_end()); - - auto *CheckBB = CI->getParent(); - auto *DeoptBlockTerm = - SplitBlockAndInsertIfThen(CI->getArgOperand(0), CI, true); - - auto *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); - - // SplitBlockAndInsertIfThen inserts control flow that branches to - // DeoptBlockTerm if the condition is true. We want the opposite. - CheckBI->swapSuccessors(); - - CheckBI->getSuccessor(0)->setName("guarded"); - CheckBI->getSuccessor(1)->setName("deopt"); - - if (auto *MD = CI->getMetadata(LLVMContext::MD_make_implicit)) - CheckBI->setMetadata(LLVMContext::MD_make_implicit, MD); - - MDBuilder MDB(CI->getContext()); - CheckBI->setMetadata(LLVMContext::MD_prof, - MDB.createBranchWeights(PredicatePassBranchWeight, 1)); - - IRBuilder<> B(DeoptBlockTerm); - auto *DeoptCall = B.CreateCall(DeoptIntrinsic, Args, {DeoptOB}, ""); - - if (DeoptIntrinsic->getReturnType()->isVoidTy()) { - B.CreateRetVoid(); - } else { - DeoptCall->setName("deoptcall"); - B.CreateRet(DeoptCall); - } - - DeoptCall->setCallingConv(CI->getCallingConv()); - DeoptBlockTerm->eraseFromParent(); -} - static bool lowerGuardIntrinsic(Function &F) { // Check if we can cheaply rule out the possibility of not having any work to // do. 
@@ -95,10 +50,8 @@ static bool lowerGuardIntrinsic(Function &F) { SmallVector<CallInst *, 8> ToLower; for (auto &I : instructions(F)) - if (auto *CI = dyn_cast<CallInst>(&I)) - if (auto *F = CI->getCalledFunction()) - if (F->getIntrinsicID() == Intrinsic::experimental_guard) - ToLower.push_back(CI); + if (isGuard(&I)) + ToLower.push_back(cast<CallInst>(&I)); if (ToLower.empty()) return false; @@ -108,7 +61,7 @@ static bool lowerGuardIntrinsic(Function &F) { DeoptIntrinsic->setCallingConv(GuardDecl->getCallingConv()); for (auto *CI : ToLower) { - MakeGuardControlFlowExplicit(DeoptIntrinsic, CI); + makeGuardControlFlowExplicit(DeoptIntrinsic, CI); CI->eraseFromParent(); } diff --git a/lib/Transforms/Scalar/MakeGuardsExplicit.cpp b/lib/Transforms/Scalar/MakeGuardsExplicit.cpp new file mode 100644 index 000000000000..1ba3994eba0e --- /dev/null +++ b/lib/Transforms/Scalar/MakeGuardsExplicit.cpp @@ -0,0 +1,120 @@ +//===- MakeGuardsExplicit.cpp - Turn guard intrinsics into guard branches -===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass lowers the @llvm.experimental.guard intrinsic to the new form of +// guard represented as widenable explicit branch to the deopt block. The +// difference between this pass and LowerGuardIntrinsic is that after this pass +// the guard represented as intrinsic: +// +// call void(i1, ...) @llvm.experimental.guard(i1 %old_cond) [ "deopt"() ] +// +// transforms to a guard represented as widenable explicit branch: +// +// %widenable_cond = call i1 @llvm.experimental.widenable.condition() +// br i1 (%old_cond & %widenable_cond), label %guarded, label %deopt +// +// Here: +// - The semantics of @llvm.experimental.widenable.condition allows to replace +// %widenable_cond with the construction (%widenable_cond & %any_other_cond) +// without loss of correctness; +// - %guarded is the lower part of old guard intrinsic's parent block split by +// the intrinsic call; +// - %deopt is a block containing a sole call to @llvm.experimental.deoptimize +// intrinsic. +// +// Therefore, this branch preserves the property of widenability. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/MakeGuardsExplicit.h" +#include "llvm/Analysis/GuardUtils.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/GuardUtils.h" + +using namespace llvm; + +namespace { +struct MakeGuardsExplicitLegacyPass : public FunctionPass { + static char ID; + MakeGuardsExplicitLegacyPass() : FunctionPass(ID) { + initializeMakeGuardsExplicitLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; +}; +} + +static void turnToExplicitForm(CallInst *Guard, Function *DeoptIntrinsic) { + // Replace the guard with an explicit branch (just like in GuardWidening). + BasicBlock *BB = Guard->getParent(); + makeGuardControlFlowExplicit(DeoptIntrinsic, Guard); + BranchInst *ExplicitGuard = cast<BranchInst>(BB->getTerminator()); + assert(ExplicitGuard->isConditional() && "Must be!"); + + // We want the guard to be expressed as explicit control flow, but still be + // widenable. 
For that, we add Widenable Condition intrinsic call to the + // guard's condition. + IRBuilder<> B(ExplicitGuard); + auto *WidenableCondition = + B.CreateIntrinsic(Intrinsic::experimental_widenable_condition, + {}, {}, nullptr, "widenable_cond"); + WidenableCondition->setCallingConv(Guard->getCallingConv()); + auto *NewCond = + B.CreateAnd(ExplicitGuard->getCondition(), WidenableCondition); + NewCond->setName("exiplicit_guard_cond"); + ExplicitGuard->setCondition(NewCond); + Guard->eraseFromParent(); +} + +static bool explicifyGuards(Function &F) { + // Check if we can cheaply rule out the possibility of not having any work to + // do. + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (!GuardDecl || GuardDecl->use_empty()) + return false; + + SmallVector<CallInst *, 8> GuardIntrinsics; + for (auto &I : instructions(F)) + if (isGuard(&I)) + GuardIntrinsics.push_back(cast<CallInst>(&I)); + + if (GuardIntrinsics.empty()) + return false; + + auto *DeoptIntrinsic = Intrinsic::getDeclaration( + F.getParent(), Intrinsic::experimental_deoptimize, {F.getReturnType()}); + DeoptIntrinsic->setCallingConv(GuardDecl->getCallingConv()); + + for (auto *Guard : GuardIntrinsics) + turnToExplicitForm(Guard, DeoptIntrinsic); + + return true; +} + +bool MakeGuardsExplicitLegacyPass::runOnFunction(Function &F) { + return explicifyGuards(F); +} + +char MakeGuardsExplicitLegacyPass::ID = 0; +INITIALIZE_PASS(MakeGuardsExplicitLegacyPass, "make-guards-explicit", + "Lower the guard intrinsic to explicit control flow form", + false, false) + +PreservedAnalyses MakeGuardsExplicitPass::run(Function &F, + FunctionAnalysisManager &) { + if (explicifyGuards(F)) + return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 3b74421a47a0..ced923d6973d 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -398,7 +398,7 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, MemsetRanges Ranges(DL); BasicBlock::iterator BI(StartInst); - for (++BI; !isa<TerminatorInst>(BI); ++BI) { + for (++BI; !BI->isTerminator(); ++BI) { if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) { // If the instruction is readnone, ignore it, otherwise bail out. We // don't even allow readonly here because we don't want something like: @@ -413,7 +413,10 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, if (!NextStore->isSimple()) break; // Check to see if this stored value is of the same byte-splattable value. - if (ByteVal != isBytewiseValue(NextStore->getOperand(0))) + Value *StoredByte = isBytewiseValue(NextStore->getOperand(0)); + if (isa<UndefValue>(ByteVal) && StoredByte) + ByteVal = StoredByte; + if (ByteVal != StoredByte) break; // Check to see if this store is to a constant offset from the start ptr. @@ -543,8 +546,8 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, // Memory locations of lifted instructions. SmallVector<MemoryLocation, 8> MemLocs{StoreLoc}; - // Lifted callsites. - SmallVector<ImmutableCallSite, 8> CallSites; + // Lifted calls. 
+ SmallVector<const CallBase *, 8> Calls; const MemoryLocation LoadLoc = MemoryLocation::get(LI); @@ -562,10 +565,9 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, }); if (!NeedLift) - NeedLift = - llvm::any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { - return isModOrRefSet(AA.getModRefInfo(C, CS)); - }); + NeedLift = llvm::any_of(Calls, [C, &AA](const CallBase *Call) { + return isModOrRefSet(AA.getModRefInfo(C, Call)); + }); } if (!NeedLift) @@ -576,12 +578,12 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, // none of them may modify its source. if (isModSet(AA.getModRefInfo(C, LoadLoc))) return false; - else if (auto CS = ImmutableCallSite(C)) { + else if (const auto *Call = dyn_cast<CallBase>(C)) { // If we can't lift this before P, it's game over. - if (isModOrRefSet(AA.getModRefInfo(P, CS))) + if (isModOrRefSet(AA.getModRefInfo(P, Call))) return false; - CallSites.push_back(CS); + Calls.push_back(Call); } else if (isa<LoadInst>(C) || isa<StoreInst>(C) || isa<VAArgInst>(C)) { // If we can't lift this before P, it's game over. auto ML = MemoryLocation::get(C); @@ -672,13 +674,11 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (UseMemMove) M = Builder.CreateMemMove( SI->getPointerOperand(), findStoreAlignment(DL, SI), - LI->getPointerOperand(), findLoadAlignment(DL, LI), Size, - SI->isVolatile()); + LI->getPointerOperand(), findLoadAlignment(DL, LI), Size); else M = Builder.CreateMemCpy( SI->getPointerOperand(), findStoreAlignment(DL, SI), - LI->getPointerOperand(), findLoadAlignment(DL, LI), Size, - SI->isVolatile()); + LI->getPointerOperand(), findLoadAlignment(DL, LI), Size); LLVM_DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI << " => " << *M << "\n"); @@ -767,8 +767,8 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { if (!Align) Align = DL.getABITypeAlignment(T); IRBuilder<> Builder(SI); - auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, - Size, Align, SI->isVolatile()); + auto *M = + Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, Size, Align); LLVM_DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); @@ -916,8 +916,7 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, continue; } if (const IntrinsicInst *IT = dyn_cast<IntrinsicInst>(U)) - if (IT->getIntrinsicID() == Intrinsic::lifetime_start || - IT->getIntrinsicID() == Intrinsic::lifetime_end) + if (IT->isLifetimeStartOrEnd()) continue; if (U != C && U != cpy) @@ -942,10 +941,10 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, // the use analysis, we also need to know that it does not sneakily // access dest. We rely on AA to figure this out for us. AliasAnalysis &AA = LookupAliasAnalysis(); - ModRefInfo MR = AA.getModRefInfo(C, cpyDest, srcSize); + ModRefInfo MR = AA.getModRefInfo(C, cpyDest, LocationSize::precise(srcSize)); // If necessary, perform additional analysis. 
if (isModOrRefSet(MR)) - MR = AA.callCapturesBefore(C, cpyDest, srcSize, &DT); + MR = AA.callCapturesBefore(C, cpyDest, LocationSize::precise(srcSize), &DT); if (isModOrRefSet(MR)) return false; @@ -993,8 +992,9 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, // handled here, but combineMetadata doesn't support them yet unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, LLVMContext::MD_noalias, - LLVMContext::MD_invariant_group}; - combineMetadata(C, cpy, KnownIDs); + LLVMContext::MD_invariant_group, + LLVMContext::MD_access_group}; + combineMetadata(C, cpy, KnownIDs, true); // Remove the memcpy. MD->removeInstruction(cpy); @@ -1056,6 +1056,8 @@ bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M, UseMemMove = true; // If all checks passed, then we can transform M. + LLVM_DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy->memcpy src:\n" + << *MDep << '\n' << *M << '\n'); // TODO: Is this worth it if we're creating a less aligned memcpy? For // example we could be moving from movaps -> movq on x86. @@ -1141,6 +1143,21 @@ bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, return true; } +/// Determine whether the instruction has undefined content for the given Size, +/// either because it was freshly alloca'd or started its lifetime. +static bool hasUndefContents(Instruction *I, ConstantInt *Size) { + if (isa<AllocaInst>(I)) + return true; + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (II->getIntrinsicID() == Intrinsic::lifetime_start) + if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0))) + if (LTSize->getZExtValue() >= Size->getZExtValue()) + return true; + + return false; +} + /// Transform memcpy to memset when its source was just memset. /// In other words, turn: /// \code @@ -1164,12 +1181,27 @@ bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, if (!AA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource())) return false; - ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength()); + // A known memset size is required. ConstantInt *MemSetSize = dyn_cast<ConstantInt>(MemSet->getLength()); + if (!MemSetSize) + return false; + // Make sure the memcpy doesn't read any more than what the memset wrote. // Don't worry about sizes larger than i64. - if (!MemSetSize || CopySize->getZExtValue() > MemSetSize->getZExtValue()) - return false; + ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength()); + if (CopySize->getZExtValue() > MemSetSize->getZExtValue()) { + // If the memcpy is larger than the memset, but the memory was undef prior + // to the memset, we can just ignore the tail. Technically we're only + // interested in the bytes from MemSetSize..CopySize here, but as we can't + // easily represent this location, we use the full 0..CopySize range. 
+ MemoryLocation MemCpyLoc = MemoryLocation::getForSource(MemCpy); + MemDepResult DepInfo = MD->getPointerDependencyFrom( + MemCpyLoc, true, MemSet->getIterator(), MemSet->getParent()); + if (DepInfo.isDef() && hasUndefContents(DepInfo.getInst(), CopySize)) + CopySize = MemSetSize; + else + return false; + } IRBuilder<> Builder(MemCpy); Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), @@ -1249,19 +1281,7 @@ bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst())) return processMemCpyMemCpyDependence(M, MDep); } else if (SrcDepInfo.isDef()) { - Instruction *I = SrcDepInfo.getInst(); - bool hasUndefContents = false; - - if (isa<AllocaInst>(I)) { - hasUndefContents = true; - } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { - if (II->getIntrinsicID() == Intrinsic::lifetime_start) - if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0))) - if (LTSize->getZExtValue() >= CopySize->getZExtValue()) - hasUndefContents = true; - } - - if (hasUndefContents) { + if (hasUndefContents(SrcDepInfo.getInst(), CopySize)) { MD->removeInstruction(M); M->eraseFromParent(); ++NumMemCpyInstr; @@ -1320,7 +1340,7 @@ bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { Type *ByValTy = cast<PointerType>(ByValArg->getType())->getElementType(); uint64_t ByValSize = DL.getTypeAllocSize(ByValTy); MemDepResult DepInfo = MD->getPointerDependencyFrom( - MemoryLocation(ByValArg, ByValSize), true, + MemoryLocation(ByValArg, LocationSize::precise(ByValSize)), true, CS.getInstruction()->getIterator(), CS.getInstruction()->getParent()); if (!DepInfo.isClobber()) return false; diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp index ff0183a8ea2d..69fd8b163a07 100644 --- a/lib/Transforms/Scalar/MergeICmps.cpp +++ b/lib/Transforms/Scalar/MergeICmps.cpp @@ -41,6 +41,15 @@ namespace { #define DEBUG_TYPE "mergeicmps" +// Returns true if the instruction is a simple load or a simple store +static bool isSimpleLoadOrStore(const Instruction *I) { + if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + return LI->isSimple(); + if (const StoreInst *SI = dyn_cast<StoreInst>(I)) + return SI->isSimple(); + return false; +} + // A BCE atom. struct BCEAtom { BCEAtom() : GEP(nullptr), LoadI(nullptr), Offset() {} @@ -81,14 +90,15 @@ BCEAtom visitICmpLoadOperand(Value *const Val) { LLVM_DEBUG(dbgs() << "used outside of block\n"); return {}; } - if (LoadI->isVolatile()) { - LLVM_DEBUG(dbgs() << "volatile\n"); + // Do not optimize atomic loads to non-atomic memcmp + if (!LoadI->isSimple()) { + LLVM_DEBUG(dbgs() << "volatile or atomic\n"); return {}; } Value *const Addr = LoadI->getOperand(0); if (auto *const GEP = dyn_cast<GetElementPtrInst>(Addr)) { LLVM_DEBUG(dbgs() << "GEP\n"); - if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { + if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { LLVM_DEBUG(dbgs() << "used outside of block\n"); return {}; } @@ -150,18 +160,19 @@ class BCECmpBlock { // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp // instructions in the block. - bool canSplit() const; + bool canSplit(AliasAnalysis *AA) const; // Return true if this all the relevant instructions in the BCE-cmp-block can // be sunk below this instruction. By doing this, we know we can separate the // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the // block. 
- bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &) const; + bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &, + AliasAnalysis *AA) const; // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block // instructions. Split the old block and move all non-BCE-cmp-insts into the // new parent block. - void split(BasicBlock *NewParent) const; + void split(BasicBlock *NewParent, AliasAnalysis *AA) const; // The basic block where this comparison happens. BasicBlock *BB = nullptr; @@ -179,12 +190,21 @@ private: }; bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, - DenseSet<Instruction *> &BlockInsts) const { + DenseSet<Instruction *> &BlockInsts, + AliasAnalysis *AA) const { // If this instruction has side effects and its in middle of the BCE cmp block // instructions, then bail for now. - // TODO: use alias analysis to tell whether there is real interference. - if (Inst->mayHaveSideEffects()) - return false; + if (Inst->mayHaveSideEffects()) { + // Bail if this is not a simple load or store + if (!isSimpleLoadOrStore(Inst)) + return false; + // Disallow stores that might alias the BCE operands + MemoryLocation LLoc = MemoryLocation::get(Lhs_.LoadI); + MemoryLocation RLoc = MemoryLocation::get(Rhs_.LoadI); + if (isModSet(AA->getModRefInfo(Inst, LLoc)) || + isModSet(AA->getModRefInfo(Inst, RLoc))) + return false; + } // Make sure this instruction does not use any of the BCE cmp block // instructions as operand. for (auto BI : BlockInsts) { @@ -194,14 +214,15 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, return true; } -void BCECmpBlock::split(BasicBlock *NewParent) const { +void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const { DenseSet<Instruction *> BlockInsts( {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); llvm::SmallVector<Instruction *, 4> OtherInsts; for (Instruction &Inst : *BB) { if (BlockInsts.count(&Inst)) continue; - assert(canSinkBCECmpInst(&Inst, BlockInsts) && "Split unsplittable block"); + assert(canSinkBCECmpInst(&Inst, BlockInsts, AA) && + "Split unsplittable block"); // This is a non-BCE-cmp-block instruction. And it can be separated // from the BCE-cmp-block instruction. OtherInsts.push_back(&Inst); @@ -213,12 +234,12 @@ void BCECmpBlock::split(BasicBlock *NewParent) const { } } -bool BCECmpBlock::canSplit() const { +bool BCECmpBlock::canSplit(AliasAnalysis *AA) const { DenseSet<Instruction *> BlockInsts( {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); for (Instruction &Inst : *BB) { if (!BlockInsts.count(&Inst)) { - if (!canSinkBCECmpInst(&Inst, BlockInsts)) + if (!canSinkBCECmpInst(&Inst, BlockInsts, AA)) return false; } } @@ -262,8 +283,9 @@ BCECmpBlock visitICmp(const ICmpInst *const CmpI, if (!Lhs.Base()) return {}; auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1)); if (!Rhs.Base()) return {}; + const auto &DL = CmpI->getModule()->getDataLayout(); return BCECmpBlock(std::move(Lhs), std::move(Rhs), - CmpI->getOperand(0)->getType()->getScalarSizeInBits()); + DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); } return {}; } @@ -324,7 +346,8 @@ static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, // A chain of comparisons. 
class BCECmpChain { public: - BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi); + BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, + AliasAnalysis *AA); int size() const { return Comparisons_.size(); } @@ -332,7 +355,7 @@ class BCECmpChain { void dump() const; #endif // MERGEICMPS_DOT_ON - bool simplify(const TargetLibraryInfo *const TLI); + bool simplify(const TargetLibraryInfo *const TLI, AliasAnalysis *AA); private: static bool IsContiguous(const BCECmpBlock &First, @@ -348,7 +371,7 @@ class BCECmpChain { // null, the merged block will link to the phi block. void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, BasicBlock *const NextBBInChain, PHINode &Phi, - const TargetLibraryInfo *const TLI); + const TargetLibraryInfo *const TLI, AliasAnalysis *AA); PHINode &Phi_; std::vector<BCECmpBlock> Comparisons_; @@ -356,7 +379,8 @@ class BCECmpChain { BasicBlock *EntryBlock_; }; -BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) +BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, + AliasAnalysis *AA) : Phi_(Phi) { assert(!Blocks.empty() && "a chain should have at least one block"); // Now look inside blocks to check for BCE comparisons. @@ -388,7 +412,7 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) // and start anew. // // NOTE: we only handle block with single predecessor for now. - if (Comparison.canSplit()) { + if (Comparison.canSplit(AA)) { LLVM_DEBUG(dbgs() << "Split initial block '" << Comparison.BB->getName() << "' that does extra work besides compare\n"); @@ -442,10 +466,9 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) #endif // MERGEICMPS_DOT_ON // Reorder blocks by LHS. We can do that without changing the // semantics because we are only accessing dereferencable memory. - llvm::sort(Comparisons_.begin(), Comparisons_.end(), - [](const BCECmpBlock &a, const BCECmpBlock &b) { - return a.Lhs() < b.Lhs(); - }); + llvm::sort(Comparisons_, [](const BCECmpBlock &a, const BCECmpBlock &b) { + return a.Lhs() < b.Lhs(); + }); #ifdef MERGEICMPS_DOT_ON errs() << "AFTER REORDERING:\n\n"; dump(); @@ -475,7 +498,8 @@ void BCECmpChain::dump() const { } #endif // MERGEICMPS_DOT_ON -bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI) { +bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI, + AliasAnalysis *AA) { // First pass to check if there is at least one merge. If not, we don't do // anything and we keep analysis passes intact. { @@ -523,13 +547,13 @@ bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI) { // Merge all previous comparisons and start a new merge block. 
mergeComparisons( makeArrayRef(Comparisons_).slice(I - NumMerged, NumMerged), - Comparisons_[I].BB, Phi_, TLI); + Comparisons_[I].BB, Phi_, TLI, AA); NumMerged = 1; } } mergeComparisons(makeArrayRef(Comparisons_) .slice(Comparisons_.size() - NumMerged, NumMerged), - nullptr, Phi_, TLI); + nullptr, Phi_, TLI, AA); return true; } @@ -537,7 +561,8 @@ bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI) { void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, BasicBlock *const NextBBInChain, PHINode &Phi, - const TargetLibraryInfo *const TLI) { + const TargetLibraryInfo *const TLI, + AliasAnalysis *AA) { assert(!Comparisons.empty()); const auto &FirstComparison = *Comparisons.begin(); BasicBlock *const BB = FirstComparison.BB; @@ -550,7 +575,7 @@ void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, auto C = std::find_if(Comparisons.begin(), Comparisons.end(), [](const BCECmpBlock &B) { return B.RequireSplit; }); if (C != Comparisons.end()) - C->split(EntryBlock_); + C->split(EntryBlock_, AA); LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); const auto TotalSize = @@ -666,7 +691,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, return Blocks; } -bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { +bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI, + AliasAnalysis *AA) { LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); @@ -724,14 +750,14 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { const auto Blocks = getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues()); if (Blocks.empty()) return false; - BCECmpChain CmpChain(Blocks, Phi); + BCECmpChain CmpChain(Blocks, Phi, AA); if (CmpChain.size() < 2) { LLVM_DEBUG(dbgs() << "skip: only one compare block\n"); return false; } - return CmpChain.simplify(TLI); + return CmpChain.simplify(TLI, AA); } class MergeICmps : public FunctionPass { @@ -746,7 +772,8 @@ class MergeICmps : public FunctionPass { if (skipFunction(F)) return false; const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - auto PA = runImpl(F, &TLI, &TTI); + AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto PA = runImpl(F, &TLI, &TTI, AA); return !PA.areAllPreserved(); } @@ -754,14 +781,16 @@ class MergeICmps : public FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); } PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI); + const TargetTransformInfo *TTI, AliasAnalysis *AA); }; PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI) { + const TargetTransformInfo *TTI, + AliasAnalysis *AA) { LLVM_DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n"); // We only try merging comparisons if the target wants to expand memcmp later. @@ -777,7 +806,7 @@ PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { // A Phi operation is always first in a basic block. 
if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) - MadeChange |= processPhi(*Phi, TLI); + MadeChange |= processPhi(*Phi, TLI, AA); } if (MadeChange) return PreservedAnalyses::none(); @@ -791,6 +820,7 @@ INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps", "Merge contiguous icmps into a memcmp", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_END(MergeICmps, "mergeicmps", "Merge contiguous icmps into a memcmp", false, false) diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 3464b759280f..ee21feca8d2c 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -211,6 +211,7 @@ PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink", &BB->front()); + NewPN->applyMergedLocation(S0->getDebugLoc(), S1->getDebugLoc()); NewPN->addIncoming(Opd1, S0->getParent()); NewPN->addIncoming(Opd2, S1->getParent()); return NewPN; diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp index 3e47e9441d15..7cbb0fe70f82 100644 --- a/lib/Transforms/Scalar/NewGVN.cpp +++ b/lib/Transforms/Scalar/NewGVN.cpp @@ -657,8 +657,8 @@ public: TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA, const DataLayout &DL) : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), - PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC) { - } + PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), + SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {} bool runGVN(); @@ -777,7 +777,7 @@ private: // Reachability handling. void updateReachableEdge(BasicBlock *, BasicBlock *); - void processOutgoingEdges(TerminatorInst *, BasicBlock *); + void processOutgoingEdges(Instruction *, BasicBlock *); Value *findConditionEquivalence(Value *) const; // Elimination. @@ -959,8 +959,7 @@ static bool isCopyOfAPHI(const Value *V) { // order. The BlockInstRange numbers are generated in an RPO walk of the basic // blocks. void NewGVN::sortPHIOps(MutableArrayRef<ValPair> Ops) const { - llvm::sort(Ops.begin(), Ops.end(), - [&](const ValPair &P1, const ValPair &P2) { + llvm::sort(Ops, [&](const ValPair &P1, const ValPair &P2) { return BlockInstRange.lookup(P1.second).first < BlockInstRange.lookup(P2.second).first; }); @@ -1087,9 +1086,13 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E, CongruenceClass *CC = ValueToClass.lookup(V); if (CC) { if (CC->getLeader() && CC->getLeader() != I) { - // Don't add temporary instructions to the user lists. - if (!AllTempInstructions.count(I)) - addAdditionalUsers(V, I); + // If we simplified to something else, we need to communicate + // that we're users of the value we simplified to. + if (I != V) { + // Don't add temporary instructions to the user lists. + if (!AllTempInstructions.count(I)) + addAdditionalUsers(V, I); + } return createVariableOrConstant(CC->getLeader()); } if (CC->getDefiningExpr()) { @@ -1752,7 +1755,7 @@ NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps, return true; }); // If we are left with no operands, it's dead. - if (Filtered.begin() == Filtered.end()) { + if (empty(Filtered)) { // If it has undef at this point, it means there are no-non-undef arguments, // and thus, the value of the phi node must be undef. 
if (HasUndef) { @@ -2484,7 +2487,7 @@ Value *NewGVN::findConditionEquivalence(Value *Cond) const { } // Process the outgoing edges of a block for reachability. -void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { +void NewGVN::processOutgoingEdges(Instruction *TI, BasicBlock *B) { // Evaluate reachability of terminator instruction. BranchInst *BR; if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) { @@ -2925,7 +2928,7 @@ void NewGVN::initializeCongruenceClasses(Function &F) { PHINodeUses.insert(UInst); // Don't insert void terminators into the class. We don't value number // them, and they just end up sitting in TOP. - if (isa<TerminatorInst>(I) && I.getType()->isVoidTy()) + if (I.isTerminator() && I.getType()->isVoidTy()) continue; TOPClass->insert(&I); ValueToClass[&I] = TOPClass; @@ -3134,7 +3137,7 @@ void NewGVN::valueNumberInstruction(Instruction *I) { auto *Symbolized = createUnknownExpression(I); performCongruenceFinding(I, Symbolized); } - processOutgoingEdges(dyn_cast<TerminatorInst>(I), I->getParent()); + processOutgoingEdges(I, I->getParent()); } } @@ -3172,12 +3175,8 @@ bool NewGVN::singleReachablePHIPath( auto FilteredPhiArgs = make_filter_range(MP->operands(), ReachableOperandPred); SmallVector<const Value *, 32> OperandList; - std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), - std::back_inserter(OperandList)); - bool Okay = OperandList.size() == 1; - if (!Okay) - Okay = - std::equal(OperandList.begin(), OperandList.end(), OperandList.begin()); + llvm::copy(FilteredPhiArgs, std::back_inserter(OperandList)); + bool Okay = is_splat(OperandList); if (Okay) return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]), Second); @@ -3272,8 +3271,7 @@ void NewGVN::verifyMemoryCongruency() const { const MemoryDef *MD = cast<MemoryDef>(U); return ValueToClass.lookup(MD->getMemoryInst()); }); - assert(std::equal(PhiOpClasses.begin(), PhiOpClasses.end(), - PhiOpClasses.begin()) && + assert(is_splat(PhiOpClasses) && "All MemoryPhi arguments should be in the same class"); } } @@ -3501,9 +3499,11 @@ bool NewGVN::runGVN() { if (!ToErase->use_empty()) ToErase->replaceAllUsesWith(UndefValue::get(ToErase->getType())); - if (ToErase->getParent()) - ToErase->eraseFromParent(); + assert(ToErase->getParent() && + "BB containing ToErase deleted unexpectedly!"); + ToErase->eraseFromParent(); } + Changed |= !InstructionsToErase.empty(); // Delete all unreachable blocks. auto UnreachableBlockPred = [&](const BasicBlock &BB) { @@ -3697,37 +3697,6 @@ void NewGVN::convertClassToLoadsAndStores( } } -static void patchReplacementInstruction(Instruction *I, Value *Repl) { - auto *ReplInst = dyn_cast<Instruction>(Repl); - if (!ReplInst) - return; - - // Patch the replacement so that it is not more restrictive than the value - // being replaced. - // Note that if 'I' is a load being replaced by some operation, - // for example, by an arithmetic operation, then andIRFlags() - // would just erase all math flags from the original arithmetic - // operation, which is clearly not wanted and not needed. - if (!isa<LoadInst>(I)) - ReplInst->andIRFlags(I); - - // FIXME: If both the original and replacement value are part of the - // same control-flow region (meaning that the execution of one - // guarantees the execution of the other), then we can combine the - // noalias scopes here and do better than the general conservative - // answer used in combineMetadata(). 
- - // In general, GVN unifies expressions over different control-flow - // regions, and so we need a conservative combination of the noalias - // scopes. - static const unsigned KnownIDs[] = { - LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_range, - LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, - LLVMContext::MD_invariant_group}; - combineMetadata(ReplInst, I, KnownIDs); -} - static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { patchReplacementInstruction(I, Repl); I->replaceAllUsesWith(Repl); @@ -3988,7 +3957,7 @@ bool NewGVN::eliminateInstructions(Function &F) { convertClassToDFSOrdered(*CC, DFSOrderedSet, UseCounts, ProbablyDead); // Sort the whole thing. - llvm::sort(DFSOrderedSet.begin(), DFSOrderedSet.end()); + llvm::sort(DFSOrderedSet); for (auto &VD : DFSOrderedSet) { int MemberDFSIn = VD.DFSIn; int MemberDFSOut = VD.DFSOut; @@ -4124,10 +4093,13 @@ bool NewGVN::eliminateInstructions(Function &F) { // It's about to be alive again. if (LeaderUseCount == 0 && isa<Instruction>(DominatingLeader)) ProbablyDead.erase(cast<Instruction>(DominatingLeader)); - // Copy instructions, however, are still dead because we use their - // operand as the leader. - if (LeaderUseCount == 0 && isSSACopy) - ProbablyDead.insert(II); + // For copy instructions, we use their operand as a leader, + // which means we remove a user of the copy and it may become dead. + if (isSSACopy) { + unsigned &IIUseCount = UseCounts[II]; + if (--IIUseCount == 0) + ProbablyDead.insert(II); + } ++LeaderUseCount; AnythingReplaced = true; } @@ -4151,7 +4123,7 @@ bool NewGVN::eliminateInstructions(Function &F) { // If we have possible dead stores to look at, try to eliminate them. if (CC->getStoreCount() > 0) { convertClassToLoadsAndStores(*CC, PossibleDeadStores); - llvm::sort(PossibleDeadStores.begin(), PossibleDeadStores.end()); + llvm::sort(PossibleDeadStores); ValueDFSStack EliminationStack; for (auto &VD : PossibleDeadStores) { int MemberDFSIn = VD.DFSIn; diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 1748815c5941..05ea9144f66c 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/Support/DebugCounter.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -24,6 +25,8 @@ using namespace llvm; #define DEBUG_TYPE "partially-inline-libcalls" +DEBUG_COUNTER(PILCounter, "partially-inline-libcalls-transform", + "Controls transformations in partially-inline-libcalls"); static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, BasicBlock &CurrBB, Function::iterator &BB, @@ -33,6 +36,9 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, if (Call->onlyReadsMemory()) return false; + if (!DebugCounter::shouldExecute(PILCounter)) + return false; + // Do the following transformation: // // (before) diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp index 8f30bccf48f1..fd2eb85fd7bf 100644 --- a/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -105,7 +105,7 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { /// The output of the pass - gives a list of each backedge (described by /// pointing at the branch) which need 
a poll inserted. - std::vector<TerminatorInst *> PollLocations; + std::vector<Instruction *> PollLocations; /// True unless we're running spp-no-calls in which case we need to disable /// the call-dependent placement opts. @@ -348,7 +348,7 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { // Safepoint insertion would involve creating a new basic block (as the // target of the current backedge) which does the safepoint (of all live // variables) and branches to the true header - TerminatorInst *Term = Pred->getTerminator(); + Instruction *Term = Pred->getTerminator(); LLVM_DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term); @@ -524,7 +524,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { }; // We need the order of list to be stable so that naming ends up stable // when we split edges. This makes test cases much easier to write. - llvm::sort(PollLocations.begin(), PollLocations.end(), OrderByBBName); + llvm::sort(PollLocations, OrderByBBName); // We can sometimes end up with duplicate poll locations. This happens if // a single loop is visited more than once. The fact this happens seems @@ -535,7 +535,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { // Insert a poll at each point the analysis pass identified // The poll location must be the terminator of a loop latch block. - for (TerminatorInst *Term : PollLocations) { + for (Instruction *Term : PollLocations) { // We are inserting a poll, the function is modified Modified = true; diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index 1df0a9c49fb1..cb893eab1654 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -63,6 +63,7 @@ using namespace llvm; using namespace reassociate; +using namespace PatternMatch; #define DEBUG_TYPE "reassociate" @@ -125,10 +126,10 @@ XorOpnd::XorOpnd(Value *V) { Value *V0 = I->getOperand(0); Value *V1 = I->getOperand(1); const APInt *C; - if (match(V0, PatternMatch::m_APInt(C))) + if (match(V0, m_APInt(C))) std::swap(V0, V1); - if (match(V1, PatternMatch::m_APInt(C))) { + if (match(V1, m_APInt(C))) { ConstPart = *C; SymbolicPart = V0; isOr = (I->getOpcode() == Instruction::Or); @@ -204,10 +205,10 @@ unsigned ReassociatePass::getRank(Value *V) { for (unsigned i = 0, e = I->getNumOperands(); i != e && Rank != MaxRank; ++i) Rank = std::max(Rank, getRank(I->getOperand(i))); - // If this is a not or neg instruction, do not count it for rank. This + // If this is a 'not' or 'neg' instruction, do not count it for rank. This // assures us that X and ~X will have the same rank. - if (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I) && - !BinaryOperator::isFNeg(I)) + if (!match(I, m_Not(m_Value())) && !match(I, m_Neg(m_Value())) && + !match(I, m_FNeg(m_Value()))) ++Rank; LLVM_DEBUG(dbgs() << "Calculated Rank[" << V->getName() << "] = " << Rank @@ -573,8 +574,8 @@ static bool LinearizeExprTree(BinaryOperator *I, // If this is a multiply expression, turn any internal negations into // multiplies by -1 so they can be reassociated. 
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) - if ((Opcode == Instruction::Mul && BinaryOperator::isNeg(BO)) || - (Opcode == Instruction::FMul && BinaryOperator::isFNeg(BO))) { + if ((Opcode == Instruction::Mul && match(BO, m_Neg(m_Value()))) || + (Opcode == Instruction::FMul && match(BO, m_FNeg(m_Value())))) { LLVM_DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO "); BO = LowerNegateToMultiply(BO); @@ -788,13 +789,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, // Discard any debug info related to the expressions that has changed (we // can leave debug infor related to the root, since the result of the // expression tree should be the same even after reassociation). - SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; - findDbgUsers(DbgUsers, ExpressionChanged); - for (auto *DII : DbgUsers) { - Value *Undef = UndefValue::get(ExpressionChanged->getType()); - DII->setOperand(0, MetadataAsValue::get(DII->getContext(), - ValueAsMetadata::get(Undef))); - } + replaceDbgUsesWithUndef(ExpressionChanged); ExpressionChanged->moveBefore(I); ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin()); @@ -854,7 +849,7 @@ static Value *NegateValue(Value *V, Instruction *BI, // Okay, we need to materialize a negated version of V with an instruction. // Scan the use lists of V to see if we have one already. for (User *U : V->users()) { - if (!BinaryOperator::isNeg(U) && !BinaryOperator::isFNeg(U)) + if (!match(U, m_Neg(m_Value())) && !match(U, m_FNeg(m_Value()))) continue; // We found one! Now we have to make sure that the definition dominates @@ -899,7 +894,7 @@ static Value *NegateValue(Value *V, Instruction *BI, /// Return true if we should break up this subtract of X-Y into (X + -Y). static bool ShouldBreakUpSubtract(Instruction *Sub) { // If this is a negation, we can't split it up! - if (BinaryOperator::isNeg(Sub) || BinaryOperator::isFNeg(Sub)) + if (match(Sub, m_Neg(m_Value())) || match(Sub, m_FNeg(m_Value()))) return false; // Don't breakup X - undef. @@ -1113,8 +1108,8 @@ static Value *OptimizeAndOrXor(unsigned Opcode, for (unsigned i = 0, e = Ops.size(); i != e; ++i) { // First, check for X and ~X in the operand list. assert(i < Ops.size()); - if (BinaryOperator::isNot(Ops[i].Op)) { // Cannot occur for ^. - Value *X = BinaryOperator::getNotArgument(Ops[i].Op); + Value *X; + if (match(Ops[i].Op, m_Not(m_Value(X)))) { // Cannot occur for ^. unsigned FoundX = FindInOperandList(Ops, i, X); if (FoundX != i) { if (Opcode == Instruction::And) // ...&X&~X = 0 @@ -1304,7 +1299,7 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, Value *V = Ops[i].Op; const APInt *C; // TODO: Support non-splat vectors. - if (match(V, PatternMatch::m_APInt(C))) { + if (match(V, m_APInt(C))) { ConstOpnd ^= *C; } else { XorOpnd O(V); @@ -1460,27 +1455,22 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, } // Check for X and -X or X and ~X in the operand list. - if (!BinaryOperator::isNeg(TheOp) && !BinaryOperator::isFNeg(TheOp) && - !BinaryOperator::isNot(TheOp)) + Value *X; + if (!match(TheOp, m_Neg(m_Value(X))) && !match(TheOp, m_Not(m_Value(X))) && + !match(TheOp, m_FNeg(m_Value(X)))) continue; - Value *X = nullptr; - if (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp)) - X = BinaryOperator::getNegArgument(TheOp); - else if (BinaryOperator::isNot(TheOp)) - X = BinaryOperator::getNotArgument(TheOp); - unsigned FoundX = FindInOperandList(Ops, i, X); if (FoundX == i) continue; // Remove X and -X from the operand list. 
if (Ops.size() == 2 && - (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp))) + (match(TheOp, m_Neg(m_Value())) || match(TheOp, m_FNeg(m_Value())))) return Constant::getNullValue(X->getType()); // Remove X and ~X from the operand list. - if (Ops.size() == 2 && BinaryOperator::isNot(TheOp)) + if (Ops.size() == 2 && match(TheOp, m_Not(m_Value()))) return Constant::getAllOnesValue(X->getType()); Ops.erase(Ops.begin()+i); @@ -1494,7 +1484,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, e -= 2; // Removed two elements. // if X and ~X we append -1 to the operand list. - if (BinaryOperator::isNot(TheOp)) { + if (match(TheOp, m_Not(m_Value()))) { Value *V = Constant::getAllOnesValue(X->getType()); Ops.insert(Ops.end(), ValueEntry(getRank(V), V)); e += 1; @@ -2058,7 +2048,7 @@ void ReassociatePass::OptimizeInst(Instruction *I) { RedoInsts.insert(I); MadeChange = true; I = NI; - } else if (BinaryOperator::isNeg(I)) { + } else if (match(I, m_Neg(m_Value()))) { // Otherwise, this is a negation. See if the operand is a multiply tree // and if this is not an inner node of a multiply tree. if (isReassociableOp(I->getOperand(1), Instruction::Mul) && @@ -2082,7 +2072,7 @@ void ReassociatePass::OptimizeInst(Instruction *I) { RedoInsts.insert(I); MadeChange = true; I = NI; - } else if (BinaryOperator::isFNeg(I)) { + } else if (match(I, m_FNeg(m_Value()))) { // Otherwise, this is a negation. See if the operand is a multiply tree // and if this is not an inner node of a multiply tree. if (isReassociableOp(I->getOperand(1), Instruction::FMul) && diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index 0de2bc72b522..42d7ed5bc534 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -28,7 +28,6 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -38,6 +37,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -65,6 +65,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> #include <cassert> @@ -346,7 +347,7 @@ static bool containsGCPtrType(Type *Ty) { if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) return containsGCPtrType(AT->getElementType()); if (StructType *ST = dyn_cast<StructType>(Ty)) - return llvm::any_of(ST->subtypes(), containsGCPtrType); + return llvm::any_of(ST->elements(), containsGCPtrType); return false; } @@ -1824,7 +1825,7 @@ static void relocationViaAlloca( } } - llvm::sort(Uses.begin(), Uses.end()); + llvm::sort(Uses); auto Last = std::unique(Uses.begin(), Uses.end()); Uses.erase(Last, Uses.end()); @@ -1850,13 +1851,13 @@ static void relocationViaAlloca( StoreInst *Store = new StoreInst(Def, Alloca); if (Instruction *Inst = dyn_cast<Instruction>(Def)) { if (InvokeInst *Invoke = dyn_cast<InvokeInst>(Inst)) { - // InvokeInst is a TerminatorInst so the store need to be inserted - // into its normal destination block. 
+ // InvokeInst is a terminator so the store need to be inserted into its + // normal destination block. BasicBlock *NormalDest = Invoke->getNormalDest(); Store->insertBefore(NormalDest->getFirstNonPHI()); } else { assert(!Inst->isTerminator() && - "The only TerminatorInst that can produce a value is " + "The only terminator that can produce a value is " "InvokeInst which is handled above."); Store->insertAfter(Inst); } @@ -2534,9 +2535,10 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // Delete any unreachable statepoints so that we don't have unrewritten // statepoints surviving this pass. This makes testing easier and the // resulting IR less confusing to human readers. - DeferredDominance DD(DT); - bool MadeChange = removeUnreachableBlocks(F, nullptr, &DD); - DD.flush(); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); + bool MadeChange = removeUnreachableBlocks(F, nullptr, &DTU); + // Flush the Dominator Tree. + DTU.getDomTree(); // Gather all the statepoints which need rewritten. Be careful to only // consider those in reachable code since we need to ask dominance queries @@ -2582,7 +2584,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, // increase the liveset of any statepoint we move over. This is profitable // as long as all statepoints are in rare blocks. If we had in-register // lowering for live values this would be a much safer transform. - auto getConditionInst = [](TerminatorInst *TI) -> Instruction* { + auto getConditionInst = [](Instruction *TI) -> Instruction * { if (auto *BI = dyn_cast<BranchInst>(TI)) if (BI->isConditional()) return dyn_cast<Instruction>(BI->getCondition()); @@ -2590,7 +2592,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, return nullptr; }; for (BasicBlock &BB : F) { - TerminatorInst *TI = BB.getTerminator(); + Instruction *TI = BB.getTerminator(); if (auto *Cond = getConditionInst(TI)) // TODO: Handle more than just ICmps here. We should be able to move // most instructions without side effects or memory access. @@ -2673,7 +2675,7 @@ static SetVector<Value *> computeKillSet(BasicBlock *BB) { /// Check that the items in 'Live' dominate 'TI'. This is used as a basic /// sanity check for the liveness computation. static void checkBasicSSA(DominatorTree &DT, SetVector<Value *> &Live, - TerminatorInst *TI, bool TermOkay = false) { + Instruction *TI, bool TermOkay = false) { for (Value *V : Live) { if (auto *I = dyn_cast<Instruction>(V)) { // The terminator can be a member of the LiveOut set. 
LLVM's definition diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 5e3ddeda2d49..2f6ed05c023b 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -55,6 +55,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/PredicateInfo.h" #include <cassert> #include <utility> #include <vector> @@ -246,7 +247,27 @@ class SCCPSolver : public InstVisitor<SCCPSolver> { using Edge = std::pair<BasicBlock *, BasicBlock *>; DenseSet<Edge> KnownFeasibleEdges; + DenseMap<Function *, AnalysisResultsForFn> AnalysisResults; + DenseMap<Value *, SmallPtrSet<User *, 2>> AdditionalUsers; + public: + void addAnalysis(Function &F, AnalysisResultsForFn A) { + AnalysisResults.insert({&F, std::move(A)}); + } + + const PredicateBase *getPredicateInfoFor(Instruction *I) { + auto A = AnalysisResults.find(I->getParent()->getParent()); + if (A == AnalysisResults.end()) + return nullptr; + return A->second.PredInfo->getPredicateInfoFor(I); + } + + DomTreeUpdater getDTU(Function &F) { + auto A = AnalysisResults.find(&F); + assert(A != AnalysisResults.end() && "Need analysis results for function."); + return {A->second.DT, A->second.PDT, DomTreeUpdater::UpdateStrategy::Lazy}; + } + SCCPSolver(const DataLayout &DL, const TargetLibraryInfo *tli) : DL(DL), TLI(tli) {} @@ -548,7 +569,7 @@ private: // getFeasibleSuccessors - Return a vector of booleans to indicate which // successors are reachable from a given terminator instruction. - void getFeasibleSuccessors(TerminatorInst &TI, SmallVectorImpl<bool> &Succs); + void getFeasibleSuccessors(Instruction &TI, SmallVectorImpl<bool> &Succs); // OperandChangedState - This method is invoked on all of the users of an // instruction that was just changed state somehow. Based on this @@ -558,6 +579,26 @@ private: visit(*I); } + // Add U as additional user of V. + void addAdditionalUser(Value *V, User *U) { + auto Iter = AdditionalUsers.insert({V, {}}); + Iter.first->second.insert(U); + } + + // Mark I's users as changed, including AdditionalUsers. + void markUsersAsChanged(Value *I) { + for (User *U : I->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + OperandChangedState(UI); + + auto Iter = AdditionalUsers.find(I); + if (Iter != AdditionalUsers.end()) { + for (User *U : Iter->second) + if (auto *UI = dyn_cast<Instruction>(U)) + OperandChangedState(UI); + } + } + private: friend class InstVisitor<SCCPSolver>; @@ -569,7 +610,7 @@ private: // Terminators void visitReturnInst(ReturnInst &I); - void visitTerminatorInst(TerminatorInst &TI); + void visitTerminator(Instruction &TI); void visitCastInst(CastInst &I); void visitSelectInst(SelectInst &I); @@ -580,7 +621,7 @@ private: void visitCatchSwitchInst(CatchSwitchInst &CPI) { markOverdefined(&CPI); - visitTerminatorInst(CPI); + visitTerminator(CPI); } // Instructions that cannot be folded away. 
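The signature changes in the SCCP hunk above (getFeasibleSuccessors and visitTerminator now take a plain Instruction) follow the same TerminatorInst-to-Instruction migration seen in the NewGVN, PlaceSafepoints and RewriteStatepointsForGC hunks earlier in this patch. A small illustration of the replacement idioms, using only Instruction members the patch itself relies on (isTerminator, isExceptionalTerminator, getNumSuccessors); getSuccessor-style access is assumed to be available alongside them:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include <cassert>

static bool endsInExceptionalTerminator(llvm::BasicBlock &BB) {
  // BasicBlock::getTerminator() now returns a plain Instruction *.
  llvm::Instruction *TI = BB.getTerminator();
  assert(TI && TI->isTerminator() && "well-formed blocks end in a terminator");
  // isa<TerminatorInst>(I) becomes I->isTerminator(), and the old
  // TI->isExceptional() becomes TI->isExceptionalTerminator().
  return TI->isExceptionalTerminator();
}

static unsigned countSuccessors(llvm::BasicBlock &BB) {
  // Successor queries also live on Instruction now.
  llvm::Instruction *TI = BB.getTerminator();
  return TI ? TI->getNumSuccessors() : 0;
}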
@@ -595,12 +636,12 @@ private: void visitInvokeInst (InvokeInst &II) { visitCallSite(&II); - visitTerminatorInst(II); + visitTerminator(II); } void visitCallSite (CallSite CS); - void visitResumeInst (TerminatorInst &I) { /*returns void*/ } - void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ } + void visitResumeInst (ResumeInst &I) { /*returns void*/ } + void visitUnreachableInst(UnreachableInst &I) { /*returns void*/ } void visitFenceInst (FenceInst &I) { /*returns void*/ } void visitInstruction(Instruction &I) { @@ -615,7 +656,7 @@ private: // getFeasibleSuccessors - Return a vector of booleans to indicate which // successors are reachable from a given terminator instruction. -void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, +void SCCPSolver::getFeasibleSuccessors(Instruction &TI, SmallVectorImpl<bool> &Succs) { Succs.resize(TI.getNumSuccessors()); if (auto *BI = dyn_cast<BranchInst>(&TI)) { @@ -640,7 +681,7 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, } // Unwinding instructions successors are always executable. - if (TI.isExceptional()) { + if (TI.isExceptionalTerminator()) { Succs.assign(TI.getNumSuccessors(), true); return; } @@ -802,7 +843,7 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) { } } -void SCCPSolver::visitTerminatorInst(TerminatorInst &TI) { +void SCCPSolver::visitTerminator(Instruction &TI) { SmallVector<bool, 16> SuccFeasible; getFeasibleSuccessors(TI, SuccFeasible); @@ -982,8 +1023,9 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { // Handle ICmpInst instruction. void SCCPSolver::visitCmpInst(CmpInst &I) { - LatticeVal &IV = ValueState[&I]; - if (IV.isOverdefined()) return; + // Do not cache this lookup, getValueState calls later in the function might + // invalidate the reference. + if (ValueState[&I].isOverdefined()) return; Value *Op1 = I.getOperand(0); Value *Op2 = I.getOperand(1); @@ -1011,7 +1053,8 @@ void SCCPSolver::visitCmpInst(CmpInst &I) { } // If operands are still unknown, wait for it to resolve. - if (!V1State.isOverdefined() && !V2State.isOverdefined() && !IV.isConstant()) + if (!V1State.isOverdefined() && !V2State.isOverdefined() && + !ValueState[&I].isConstant()) return; markOverdefined(&I); @@ -1119,6 +1162,65 @@ void SCCPSolver::visitCallSite(CallSite CS) { Function *F = CS.getCalledFunction(); Instruction *I = CS.getInstruction(); + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + if (ValueState[I].isOverdefined()) + return; + + auto *PI = getPredicateInfoFor(I); + if (!PI) + return; + + Value *CopyOf = I->getOperand(0); + auto *PBranch = dyn_cast<PredicateBranch>(PI); + if (!PBranch) { + mergeInValue(ValueState[I], I, getValueState(CopyOf)); + return; + } + + Value *Cond = PBranch->Condition; + + // Everything below relies on the condition being a comparison. 
+ auto *Cmp = dyn_cast<CmpInst>(Cond); + if (!Cmp) { + mergeInValue(ValueState[I], I, getValueState(CopyOf)); + return; + } + + Value *CmpOp0 = Cmp->getOperand(0); + Value *CmpOp1 = Cmp->getOperand(1); + if (CopyOf != CmpOp0 && CopyOf != CmpOp1) { + mergeInValue(ValueState[I], I, getValueState(CopyOf)); + return; + } + + if (CmpOp0 != CopyOf) + std::swap(CmpOp0, CmpOp1); + + LatticeVal OriginalVal = getValueState(CopyOf); + LatticeVal EqVal = getValueState(CmpOp1); + LatticeVal &IV = ValueState[I]; + if (PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_EQ) { + addAdditionalUser(CmpOp1, I); + if (OriginalVal.isConstant()) + mergeInValue(IV, I, OriginalVal); + else + mergeInValue(IV, I, EqVal); + return; + } + if (!PBranch->TrueEdge && Cmp->getPredicate() == CmpInst::ICMP_NE) { + addAdditionalUser(CmpOp1, I); + if (OriginalVal.isConstant()) + mergeInValue(IV, I, OriginalVal); + else + mergeInValue(IV, I, EqVal); + return; + } + + return (void)mergeInValue(IV, I, getValueState(CopyOf)); + } + } + // The common case is that we aren't tracking the callee, either because we // are not doing interprocedural analysis or the callee is indirect, or is // external. Handle these cases first. @@ -1134,6 +1236,8 @@ CallOverdefined: SmallVector<Constant*, 8> Operands; for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); AI != E; ++AI) { + if (AI->get()->getType()->isStructTy()) + return markOverdefined(I); // Can't handle struct args. LatticeVal State = getValueState(*AI); if (State.isUnknown()) @@ -1238,9 +1342,7 @@ void SCCPSolver::Solve() { // since all of its users will have already been marked as overdefined // Update all of the users of this instruction's value. // - for (User *U : I->users()) - if (auto *UI = dyn_cast<Instruction>(U)) - OperandChangedState(UI); + markUsersAsChanged(I); } // Process the instruction work list. @@ -1257,9 +1359,7 @@ void SCCPSolver::Solve() { // Update all of the users of this instruction's value. // if (I->getType()->isStructTy() || !getValueState(I).isOverdefined()) - for (User *U : I->users()) - if (auto *UI = dyn_cast<Instruction>(U)) - OperandChangedState(UI); + markUsersAsChanged(I); } // Process the basic block work list. @@ -1522,7 +1622,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // Check to see if we have a branch or switch on an undefined value. If so // we force the branch to go one way or the other to make the successor // values live. It doesn't really matter which way we force it. - TerminatorInst *TI = BB.getTerminator(); + Instruction *TI = BB.getTerminator(); if (auto *BI = dyn_cast<BranchInst>(TI)) { if (!BI->isConditional()) continue; if (!getValueState(BI->getCondition()).isUnknown()) @@ -1694,7 +1794,7 @@ static bool runSCCP(Function &F, const DataLayout &DL, // constants if we have found them to be of constant values. for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { Instruction *Inst = &*BI++; - if (Inst->getType()->isVoidTy() || isa<TerminatorInst>(Inst)) + if (Inst->getType()->isVoidTy() || Inst->isTerminator()) continue; if (tryToReplaceWithConstant(Solver, Inst)) { @@ -1798,8 +1898,44 @@ static void findReturnsToZap(Function &F, } } -bool llvm::runIPSCCP(Module &M, const DataLayout &DL, - const TargetLibraryInfo *TLI) { +// Update the condition for terminators that are branching on indeterminate +// values, forcing them to use a specific edge. 
+static void forceIndeterminateEdge(Instruction* I, SCCPSolver &Solver) { + BasicBlock *Dest = nullptr; + Constant *C = nullptr; + if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { + if (!isa<ConstantInt>(SI->getCondition())) { + // Indeterminate switch; use first case value. + Dest = SI->case_begin()->getCaseSuccessor(); + C = SI->case_begin()->getCaseValue(); + } + } else if (BranchInst *BI = dyn_cast<BranchInst>(I)) { + if (!isa<ConstantInt>(BI->getCondition())) { + // Indeterminate branch; use false. + Dest = BI->getSuccessor(1); + C = ConstantInt::getFalse(BI->getContext()); + } + } else if (IndirectBrInst *IBR = dyn_cast<IndirectBrInst>(I)) { + if (!isa<BlockAddress>(IBR->getAddress()->stripPointerCasts())) { + // Indeterminate indirectbr; use successor 0. + Dest = IBR->getSuccessor(0); + C = BlockAddress::get(IBR->getSuccessor(0)); + } + } else { + llvm_unreachable("Unexpected terminator instruction"); + } + if (C) { + assert(Solver.isEdgeFeasible(I->getParent(), Dest) && + "Didn't find feasible edge?"); + (void)Dest; + + I->setOperand(0, C); + } +} + +bool llvm::runIPSCCP( + Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI, + function_ref<AnalysisResultsForFn(Function &)> getAnalysis) { SCCPSolver Solver(DL, TLI); // Loop over all functions, marking arguments to those with their addresses @@ -1808,6 +1944,8 @@ bool llvm::runIPSCCP(Module &M, const DataLayout &DL, if (F.isDeclaration()) continue; + Solver.addAnalysis(F, getAnalysis(F)); + // Determine if we can track the function's return values. If so, add the // function to the solver's set of return-tracked functions. if (canTrackReturnsInterprocedurally(&F)) @@ -1856,12 +1994,13 @@ bool llvm::runIPSCCP(Module &M, const DataLayout &DL, // Iterate over all of the instructions in the module, replacing them with // constants if we have found them to be of constant values. - SmallVector<BasicBlock*, 512> BlocksToErase; for (Function &F : M) { if (F.isDeclaration()) continue; + SmallVector<BasicBlock *, 512> BlocksToErase; + if (Solver.isBlockExecutable(&F.front())) for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; ++AI) { @@ -1897,23 +2036,26 @@ bool llvm::runIPSCCP(Module &M, const DataLayout &DL, } } - // Change dead blocks to unreachable. We do it after replacing constants in - // all executable blocks, because changeToUnreachable may remove PHI nodes - // in executable blocks we found values for. The function's entry block is - // not part of BlocksToErase, so we have to handle it separately. - for (BasicBlock *BB : BlocksToErase) + DomTreeUpdater DTU = Solver.getDTU(F); + // Change dead blocks to unreachable. We do it after replacing constants + // in all executable blocks, because changeToUnreachable may remove PHI + // nodes in executable blocks we found values for. The function's entry + // block is not part of BlocksToErase, so we have to handle it separately. + for (BasicBlock *BB : BlocksToErase) { NumInstRemoved += - changeToUnreachable(BB->getFirstNonPHI(), /*UseLLVMTrap=*/false); + changeToUnreachable(BB->getFirstNonPHI(), /*UseLLVMTrap=*/false, + /*PreserveLCSSA=*/false, &DTU); + } if (!Solver.isBlockExecutable(&F.front())) NumInstRemoved += changeToUnreachable(F.front().getFirstNonPHI(), - /*UseLLVMTrap=*/false); + /*UseLLVMTrap=*/false, + /*PreserveLCSSA=*/false, &DTU); - // Now that all instructions in the function are constant folded, erase dead - // blocks, because we can now use ConstantFoldTerminator to get rid of - // in-edges. 
- for (unsigned i = 0, e = BlocksToErase.size(); i != e; ++i) { + // Now that all instructions in the function are constant folded, + // use ConstantFoldTerminator to get rid of in-edges, record DT updates and + // delete dead BBs. + for (BasicBlock *DeadBB : BlocksToErase) { // If there are any PHI nodes in this successor, drop entries for BB now. - BasicBlock *DeadBB = BlocksToErase[i]; for (Value::user_iterator UI = DeadBB->user_begin(), UE = DeadBB->user_end(); UI != UE;) { @@ -1925,41 +2067,34 @@ bool llvm::runIPSCCP(Module &M, const DataLayout &DL, // Ignore blockaddress users; BasicBlock's dtor will handle them. if (!I) continue; - bool Folded = ConstantFoldTerminator(I->getParent()); - if (!Folded) { - // If the branch can't be folded, we must have forced an edge - // for an indeterminate value. Force the terminator to fold - // to that edge. - Constant *C; - BasicBlock *Dest; - if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { - Dest = SI->case_begin()->getCaseSuccessor(); - C = SI->case_begin()->getCaseValue(); - } else if (BranchInst *BI = dyn_cast<BranchInst>(I)) { - Dest = BI->getSuccessor(1); - C = ConstantInt::getFalse(BI->getContext()); - } else if (IndirectBrInst *IBR = dyn_cast<IndirectBrInst>(I)) { - Dest = IBR->getSuccessor(0); - C = BlockAddress::get(IBR->getSuccessor(0)); - } else { - llvm_unreachable("Unexpected terminator instruction"); - } - assert(Solver.isEdgeFeasible(I->getParent(), Dest) && - "Didn't find feasible edge?"); - (void)Dest; - - I->setOperand(0, C); - Folded = ConstantFoldTerminator(I->getParent()); - } + // If we have forced an edge for an indeterminate value, then force the + // terminator to fold to that edge. + forceIndeterminateEdge(I, Solver); + bool Folded = ConstantFoldTerminator(I->getParent(), + /*DeleteDeadConditions=*/false, + /*TLI=*/nullptr, &DTU); assert(Folded && "Expect TermInst on constantint or blockaddress to be folded"); (void) Folded; } + // Mark dead BB for deletion. + DTU.deleteBB(DeadBB); + } - // Finally, delete the basic block. - F.getBasicBlockList().erase(DeadBB); + for (BasicBlock &BB : F) { + for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { + Instruction *Inst = &*BI++; + if (Solver.getPredicateInfoFor(Inst)) { + if (auto *II = dyn_cast<IntrinsicInst>(Inst)) { + if (II->getIntrinsicID() == Intrinsic::ssa_copy) { + Value *Op = II->getOperand(0); + Inst->replaceAllUsesWith(Op); + Inst->eraseFromParent(); + } + } + } + } } - BlocksToErase.clear(); } // If we inferred constant or undef return values for a function, we replaced diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index de16b608f752..eab77cf4cda9 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -913,8 +913,7 @@ private: if (!IsOffsetKnown) return PI.setAborted(&II); - if (II.getIntrinsicID() == Intrinsic::lifetime_start || - II.getIntrinsicID() == Intrinsic::lifetime_end) { + if (II.isLifetimeStartOrEnd()) { ConstantInt *Length = cast<ConstantInt>(II.getArgOperand(0)); uint64_t Size = std::min(AllocSize - Offset.getLimitedValue(), Length->getLimitedValue()); @@ -1060,7 +1059,7 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) // Sort the uses. This arranges for the offsets to be in ascending order, // and the sizes to be in descending order. - llvm::sort(Slices.begin(), Slices.end()); + llvm::sort(Slices); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -1211,7 +1210,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) { // predecessor blocks. 
The only thing to watch out for is that we can't put // a possibly trapping load in the predecessor if it is a critical edge. for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) { - TerminatorInst *TI = PN.getIncomingBlock(Idx)->getTerminator(); + Instruction *TI = PN.getIncomingBlock(Idx)->getTerminator(); Value *InVal = PN.getIncomingValue(Idx); // If the value is produced by the terminator of the predecessor (an @@ -1275,7 +1274,7 @@ static void speculatePHINodeLoads(PHINode &PN) { continue; } - TerminatorInst *TI = Pred->getTerminator(); + Instruction *TI = Pred->getTerminator(); IRBuilderTy PredBuilder(TI); LoadInst *Load = PredBuilder.CreateLoad( @@ -1400,8 +1399,8 @@ static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL, if (Ty == TargetTy) return buildGEP(IRB, BasePtr, Indices, NamePrefix); - // Pointer size to use for the indices. - unsigned PtrSize = DL.getPointerTypeSizeInBits(BasePtr->getType()); + // Offset size to use for the indices. + unsigned OffsetSize = DL.getIndexTypeSizeInBits(BasePtr->getType()); // See if we can descend into a struct and locate a field with the correct // type. @@ -1413,7 +1412,7 @@ static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL, if (ArrayType *ArrayTy = dyn_cast<ArrayType>(ElementTy)) { ElementTy = ArrayTy->getElementType(); - Indices.push_back(IRB.getIntN(PtrSize, 0)); + Indices.push_back(IRB.getIntN(OffsetSize, 0)); } else if (VectorType *VectorTy = dyn_cast<VectorType>(ElementTy)) { ElementTy = VectorTy->getElementType(); Indices.push_back(IRB.getInt32(0)); @@ -1807,8 +1806,7 @@ static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S, if (!S.isSplittable()) return false; // Skip any unsplittable intrinsics. } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { - if (II->getIntrinsicID() != Intrinsic::lifetime_start && - II->getIntrinsicID() != Intrinsic::lifetime_end) + if (!II->isLifetimeStartOrEnd()) return false; } else if (U->get()->getType()->getPointerElementType()->isStructTy()) { // Disable vector promotion when there are loads or stores of an FCA. @@ -1906,7 +1904,7 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { "All non-integer types eliminated!"); return RHSTy->getNumElements() < LHSTy->getNumElements(); }; - llvm::sort(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes); + llvm::sort(CandidateTys, RankVectorTypes); CandidateTys.erase( std::unique(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes), CandidateTys.end()); @@ -2029,8 +2027,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S, if (!S.isSplittable()) return false; // Skip any unsplittable intrinsics. } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { - if (II->getIntrinsicID() != Intrinsic::lifetime_start && - II->getIntrinsicID() != Intrinsic::lifetime_end) + if (!II->isLifetimeStartOrEnd()) return false; } else { return false; @@ -2377,7 +2374,7 @@ private: #endif return getAdjustedPtr(IRB, DL, &NewAI, - APInt(DL.getPointerTypeSizeInBits(PointerTy), Offset), + APInt(DL.getIndexTypeSizeInBits(PointerTy), Offset), PointerTy, #ifndef NDEBUG Twine(OldName) + "." 
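The switch from DL.getPointerTypeSizeInBits / getPointerSizeInBits to DL.getIndexTypeSizeInBits / getIndexSizeInBits in the SROA hunks above matters on targets whose GEP index (offset) width is narrower than the pointer width; offsets fed to GEPs, as in getAdjustedPtr, should use the index width. A self-contained sketch of the distinction; the data-layout string is hypothetical and chosen only so the two values differ:

#include "llvm/IR/DataLayout.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::LLVMContext Ctx;
  // 64-bit pointers whose index (offset) type is only 32 bits wide.
  llvm::DataLayout DL("e-p:64:64:64:32");
  llvm::Type *PtrTy = llvm::Type::getInt8PtrTy(Ctx);
  llvm::outs() << "pointer width: " << DL.getPointerTypeSizeInBits(PtrTy)
               << " bits\n"; // 64
  llvm::outs() << "index width:   " << DL.getIndexTypeSizeInBits(PtrTy)
               << " bits\n"; // 32
  return 0;
}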
@@ -2593,7 +2590,8 @@ private: } V = convertValue(DL, IRB, V, NewAllocaTy); StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); - Store->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); + Store->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, + LLVMContext::MD_access_group}); if (AATags) Store->setAAMetadata(AATags); Pass.DeadInsts.insert(&SI); @@ -2662,7 +2660,8 @@ private: NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()), SI.isVolatile()); } - NewSI->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); + NewSI->copyMetadata(SI, {LLVMContext::MD_mem_parallel_loop_access, + LLVMContext::MD_access_group}); if (AATags) NewSI->setAAMetadata(AATags); if (SI.isVolatile()) @@ -2899,8 +2898,8 @@ private: unsigned OtherAS = OtherPtrTy->getPointerAddressSpace(); // Compute the relative offset for the other pointer within the transfer. - unsigned IntPtrWidth = DL.getPointerSizeInBits(OtherAS); - APInt OtherOffset(IntPtrWidth, NewBeginOffset - BeginOffset); + unsigned OffsetWidth = DL.getIndexSizeInBits(OtherAS); + APInt OtherOffset(OffsetWidth, NewBeginOffset - BeginOffset); unsigned OtherAlign = IsDest ? II.getSourceAlignment() : II.getDestAlignment(); OtherAlign = MinAlign(OtherAlign ? OtherAlign : 1, @@ -3011,8 +3010,7 @@ private: } bool visitIntrinsicInst(IntrinsicInst &II) { - assert(II.getIntrinsicID() == Intrinsic::lifetime_start || - II.getIntrinsicID() == Intrinsic::lifetime_end); + assert(II.isLifetimeStartOrEnd()); LLVM_DEBUG(dbgs() << " original: " << II << "\n"); assert(II.getArgOperand(1) == OldPtr); @@ -3046,6 +3044,42 @@ private: return true; } + void fixLoadStoreAlign(Instruction &Root) { + // This algorithm implements the same visitor loop as + // hasUnsafePHIOrSelectUse, and fixes the alignment of each load + // or store found. + SmallPtrSet<Instruction *, 4> Visited; + SmallVector<Instruction *, 4> Uses; + Visited.insert(&Root); + Uses.push_back(&Root); + do { + Instruction *I = Uses.pop_back_val(); + + if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + unsigned LoadAlign = LI->getAlignment(); + if (!LoadAlign) + LoadAlign = DL.getABITypeAlignment(LI->getType()); + LI->setAlignment(std::min(LoadAlign, getSliceAlign())); + continue; + } + if (StoreInst *SI = dyn_cast<StoreInst>(I)) { + unsigned StoreAlign = SI->getAlignment(); + if (!StoreAlign) { + Value *Op = SI->getOperand(0); + StoreAlign = DL.getABITypeAlignment(Op->getType()); + } + SI->setAlignment(std::min(StoreAlign, getSliceAlign())); + continue; + } + + assert(isa<BitCastInst>(I) || isa<PHINode>(I) || + isa<SelectInst>(I) || isa<GetElementPtrInst>(I)); + for (User *U : I->users()) + if (Visited.insert(cast<Instruction>(U)).second) + Uses.push_back(cast<Instruction>(U)); + } while (!Uses.empty()); + } + bool visitPHINode(PHINode &PN) { LLVM_DEBUG(dbgs() << " original: " << PN << "\n"); assert(BeginOffset >= NewAllocaBeginOffset && "PHIs are unsplittable"); @@ -3069,6 +3103,9 @@ private: LLVM_DEBUG(dbgs() << " to: " << PN << "\n"); deleteIfTriviallyDead(OldPtr); + // Fix the alignment of any loads or stores using this PHI node. + fixLoadStoreAlign(PN); + // PHIs can't be promoted on their own, but often can be speculated. We // check the speculation outside of the rewriter so that we see the // fully-rewritten alloca. @@ -3093,6 +3130,9 @@ private: LLVM_DEBUG(dbgs() << " to: " << SI << "\n"); deleteIfTriviallyDead(OldPtr); + // Fix the alignment of any loads or stores using this select. 
+ fixLoadStoreAlign(SI); + // Selects can't be promoted on their own, but often can be speculated. We // check the speculation outside of the rewriter so that we see the // fully-rewritten alloca. @@ -3122,7 +3162,12 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> { /// value (as opposed to the user). Use *U; + /// Used to calculate offsets, and hence alignment, of subobjects. + const DataLayout &DL; + public: + AggLoadStoreRewriter(const DataLayout &DL) : DL(DL) {} + /// Rewrite loads and stores through a pointer and all pointers derived from /// it. bool rewrite(Instruction &I) { @@ -3166,10 +3211,22 @@ private: /// split operations. Value *Ptr; + /// The base pointee type being GEPed into. + Type *BaseTy; + + /// Known alignment of the base pointer. + unsigned BaseAlign; + + /// To calculate offset of each component so we can correctly deduce + /// alignments. + const DataLayout &DL; + /// Initialize the splitter with an insertion point, Ptr and start with a /// single zero GEP index. - OpSplitter(Instruction *InsertionPoint, Value *Ptr) - : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr) {} + OpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, + unsigned BaseAlign, const DataLayout &DL) + : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr), + BaseTy(BaseTy), BaseAlign(BaseAlign), DL(DL) {} public: /// Generic recursive split emission routine. @@ -3186,8 +3243,11 @@ private: /// \param Agg The aggregate value being built up or stored, depending on /// whether this is splitting a load or a store respectively. void emitSplitOps(Type *Ty, Value *&Agg, const Twine &Name) { - if (Ty->isSingleValueType()) - return static_cast<Derived *>(this)->emitFunc(Ty, Agg, Name); + if (Ty->isSingleValueType()) { + unsigned Offset = DL.getIndexedOffsetInType(BaseTy, GEPIndices); + return static_cast<Derived *>(this)->emitFunc( + Ty, Agg, MinAlign(BaseAlign, Offset), Name); + } if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) { unsigned OldSize = Indices.size(); @@ -3226,17 +3286,19 @@ private: struct LoadOpSplitter : public OpSplitter<LoadOpSplitter> { AAMDNodes AATags; - LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr, AAMDNodes AATags) - : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr), AATags(AATags) {} + LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, + AAMDNodes AATags, unsigned BaseAlign, const DataLayout &DL) + : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr, BaseTy, BaseAlign, + DL), AATags(AATags) {} /// Emit a leaf load of a single value. This is called at the leaves of the /// recursive emission to actually load values. - void emitFunc(Type *Ty, Value *&Agg, const Twine &Name) { + void emitFunc(Type *Ty, Value *&Agg, unsigned Align, const Twine &Name) { assert(Ty->isSingleValueType()); // Load the single value and insert it using the indices. 
Value *GEP = IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); - LoadInst *Load = IRB.CreateLoad(GEP, Name + ".load"); + LoadInst *Load = IRB.CreateAlignedLoad(GEP, Align, Name + ".load"); if (AATags) Load->setAAMetadata(AATags); Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); @@ -3253,7 +3315,8 @@ private: LLVM_DEBUG(dbgs() << " original: " << LI << "\n"); AAMDNodes AATags; LI.getAAMetadata(AATags); - LoadOpSplitter Splitter(&LI, *U, AATags); + LoadOpSplitter Splitter(&LI, *U, LI.getType(), AATags, + getAdjustedAlignment(&LI, 0, DL), DL); Value *V = UndefValue::get(LI.getType()); Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); LI.replaceAllUsesWith(V); @@ -3262,13 +3325,15 @@ private: } struct StoreOpSplitter : public OpSplitter<StoreOpSplitter> { - StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr, AAMDNodes AATags) - : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr), AATags(AATags) {} + StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr, Type *BaseTy, + AAMDNodes AATags, unsigned BaseAlign, const DataLayout &DL) + : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr, BaseTy, BaseAlign, + DL), + AATags(AATags) {} AAMDNodes AATags; - /// Emit a leaf store of a single value. This is called at the leaves of the /// recursive emission to actually produce stores. - void emitFunc(Type *Ty, Value *&Agg, const Twine &Name) { + void emitFunc(Type *Ty, Value *&Agg, unsigned Align, const Twine &Name) { assert(Ty->isSingleValueType()); // Extract the single value and store it using the indices. // @@ -3278,7 +3343,8 @@ private: IRB.CreateExtractValue(Agg, Indices, Name + ".extract"); Value *InBoundsGEP = IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); - StoreInst *Store = IRB.CreateStore(ExtractValue, InBoundsGEP); + StoreInst *Store = + IRB.CreateAlignedStore(ExtractValue, InBoundsGEP, Align); if (AATags) Store->setAAMetadata(AATags); LLVM_DEBUG(dbgs() << " to: " << *Store << "\n"); @@ -3296,7 +3362,8 @@ private: LLVM_DEBUG(dbgs() << " original: " << SI << "\n"); AAMDNodes AATags; SI.getAAMetadata(AATags); - StoreOpSplitter Splitter(&SI, *U, AATags); + StoreOpSplitter Splitter(&SI, *U, V->getType(), AATags, + getAdjustedAlignment(&SI, 0, DL), DL); Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); SI.eraseFromParent(); return true; @@ -3730,7 +3797,8 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { PartPtrTy, BasePtr->getName() + "."), getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false, LI->getName()); - PLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); + PLoad->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access, + LLVMContext::MD_access_group}); // Append this load onto the list of split loads so we can find it later // to rewrite the stores. 
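The aggregate splitter above now derives the alignment of each emitted element access from that element's offset within the original load or store, via MinAlign(BaseAlign, Offset). A short worked example of that rule (the numbers are illustrative only):

#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  // If the whole aggregate access is 16-byte aligned, an element at byte
  // offset 8 is only guaranteed 8-byte alignment, and elements at offsets 4
  // or 12 only 4-byte alignment: MinAlign keeps the largest power of two
  // dividing both the base alignment and the offset.
  const unsigned BaseAlign = 16;
  for (unsigned Offset : {0u, 4u, 8u, 12u})
    llvm::outs() << "offset " << Offset << " -> align "
                 << llvm::MinAlign(BaseAlign, Offset) << "\n";
  return 0;
}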
@@ -3786,7 +3854,8 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { APInt(DL.getIndexSizeInBits(AS), PartOffset), PartPtrTy, StoreBasePtr->getName() + "."), getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); - PStore->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); + PStore->copyMetadata(*LI, {LLVMContext::MD_mem_parallel_loop_access, + LLVMContext::MD_access_group}); LLVM_DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); } @@ -4179,7 +4248,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } if (!IsSorted) - llvm::sort(AS.begin(), AS.end()); + llvm::sort(AS); /// Describes the allocas introduced by rewritePartition in order to migrate /// the debug info. @@ -4212,7 +4281,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. - TinyPtrVector<DbgInfoIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI); + TinyPtrVector<DbgVariableIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI); if (!DbgDeclares.empty()) { auto *Var = DbgDeclares.front()->getVariable(); auto *Expr = DbgDeclares.front()->getExpression(); @@ -4264,7 +4333,7 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { } // Remove any existing intrinsics describing the same alloca. - for (DbgInfoIntrinsic *OldDII : FindDbgAddrUses(Fragment.Alloca)) + for (DbgVariableIntrinsic *OldDII : FindDbgAddrUses(Fragment.Alloca)) OldDII->eraseFromParent(); DIB.insertDeclare(Fragment.Alloca, Var, FragmentExpr, @@ -4314,7 +4383,7 @@ bool SROA::runOnAlloca(AllocaInst &AI) { // First, split any FCA loads and stores touching this alloca to promote // better splitting and promotion opportunities. - AggLoadStoreRewriter AggRewriter; + AggLoadStoreRewriter AggRewriter(DL); Changed |= AggRewriter.rewrite(AI); // Build the slices using a recursive instruction-visiting builder. @@ -4379,7 +4448,7 @@ bool SROA::deleteDeadInstructions( // not be able to find it. 
if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { DeletedAllocas.insert(AI); - for (DbgInfoIntrinsic *OldDII : FindDbgAddrUses(AI)) + for (DbgVariableIntrinsic *OldDII : FindDbgAddrUses(AI)) OldDII->eraseFromParent(); } diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index 526487d3477e..976daf4c78c2 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -25,7 +25,9 @@ #include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" #include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/Scalarizer.h" #include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" +#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" using namespace llvm; @@ -42,7 +44,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeDCELegacyPassPass(Registry); initializeDeadInstEliminationPass(Registry); initializeDivRemPairsLegacyPassPass(Registry); - initializeScalarizerPass(Registry); + initializeScalarizerLegacyPassPass(Registry); initializeDSELegacyPassPass(Registry); initializeGuardWideningLegacyPassPass(Registry); initializeLoopGuardWideningLegacyPassPass(Registry); @@ -50,6 +52,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeNewGVNLegacyPassPass(Registry); initializeEarlyCSELegacyPassPass(Registry); initializeEarlyCSEMemSSALegacyPassPass(Registry); + initializeMakeGuardsExplicitLegacyPassPass(Registry); initializeGVNHoistLegacyPassPass(Registry); initializeGVNSinkLegacyPassPass(Registry); initializeFlattenCFGPassPass(Registry); @@ -72,6 +75,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopUnrollPass(Registry); initializeLoopUnrollAndJamPass(Registry); initializeLoopUnswitchPass(Registry); + initializeWarnMissedTransformationsLegacyPass(Registry); initializeLoopVersioningLICMPass(Registry); initializeLoopIdiomRecognizeLegacyPassPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); @@ -194,6 +198,10 @@ void LLVMAddLoopUnswitchPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopUnswitchPass()); } +void LLVMAddLowerAtomicPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLowerAtomicPass()); +} + void LLVMAddMemCpyOptPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createMemCpyOptPass()); } @@ -274,3 +282,7 @@ void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) { void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLowerExpectIntrinsicPass()); } + +void LLVMAddUnifyFunctionExitNodesPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createUnifyFunctionExitNodesPass()); +} diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp index 34ed126155be..5eb3fdab6d5c 100644 --- a/lib/Transforms/Scalar/Scalarizer.cpp +++ b/lib/Transforms/Scalar/Scalarizer.cpp @@ -14,6 +14,7 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/VectorUtils.h" @@ -38,6 +39,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/Options.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/Scalarizer.h" #include <cassert> #include <cstdint> #include <iterator> @@ -48,6 +50,13 @@ using namespace llvm; #define DEBUG_TYPE "scalarizer" +// This is disabled by default because having separate loads and stores +// makes it more likely that the -combiner-alias-analysis limits will be +// reached. 
+static cl::opt<bool> + ScalarizeLoadStore("scalarize-load-store", cl::init(false), cl::Hidden, + cl::desc("Allow the scalarizer pass to scalarize loads and store")); + namespace { // Used to store the scattered form of a vector. @@ -151,17 +160,13 @@ struct VectorLayout { uint64_t ElemSize = 0; }; -class Scalarizer : public FunctionPass, - public InstVisitor<Scalarizer, bool> { +class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> { public: - static char ID; - - Scalarizer() : FunctionPass(ID) { - initializeScalarizerPass(*PassRegistry::getPassRegistry()); + ScalarizerVisitor(unsigned ParallelLoopAccessMDKind) + : ParallelLoopAccessMDKind(ParallelLoopAccessMDKind) { } - bool doInitialization(Module &M) override; - bool runOnFunction(Function &F) override; + bool visit(Function &F); // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. @@ -179,16 +184,6 @@ public: bool visitStoreInst(StoreInst &SI); bool visitCallInst(CallInst &ICI); - static void registerOptions() { - // This is disabled by default because having separate loads and stores - // makes it more likely that the -combiner-alias-analysis limits will be - // reached. - OptionRegistry::registerOption<bool, Scalarizer, - &Scalarizer::ScalarizeLoadStore>( - "scalarize-load-store", - "Allow the scalarizer pass to scalarize loads and store", false); - } - private: Scatterer scatter(Instruction *Point, Value *V); void gather(Instruction *Op, const ValueVector &CV); @@ -204,16 +199,28 @@ private: ScatterMap Scattered; GatherList Gathered; + unsigned ParallelLoopAccessMDKind; - bool ScalarizeLoadStore; }; -} // end anonymous namespace +class ScalarizerLegacyPass : public FunctionPass { +public: + static char ID; -char Scalarizer::ID = 0; + ScalarizerLegacyPass() : FunctionPass(ID) { + initializeScalarizerLegacyPassPass(*PassRegistry::getPassRegistry()); + } -INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer", - "Scalarize vector operations", false, false) + bool runOnFunction(Function &F) override; +}; + +} // end anonymous namespace + +char ScalarizerLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer", + "Scalarize vector operations", false, false) +INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer", + "Scalarize vector operations", false, false) Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v, ValueVector *cachePtr) @@ -277,22 +284,31 @@ Value *Scatterer::operator[](unsigned I) { return CV[I]; } -bool Scalarizer::doInitialization(Module &M) { - ParallelLoopAccessMDKind = +bool ScalarizerLegacyPass::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + Module &M = *F.getParent(); + unsigned ParallelLoopAccessMDKind = M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); - ScalarizeLoadStore = - M.getContext().getOption<bool, Scalarizer, &Scalarizer::ScalarizeLoadStore>(); - return false; + ScalarizerVisitor Impl(ParallelLoopAccessMDKind); + return Impl.visit(F); } -bool Scalarizer::runOnFunction(Function &F) { - if (skipFunction(F)) - return false; +FunctionPass *llvm::createScalarizerPass() { + return new ScalarizerLegacyPass(); +} + +bool ScalarizerVisitor::visit(Function &F) { assert(Gathered.empty() && Scattered.empty()); - for (BasicBlock &BB : F) { - for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { + + // To ensure we replace gathered components correctly we need to do an ordered + // traversal of the basic blocks in the function. 
+ ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock()); + for (BasicBlock *BB : RPOT) { + for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { Instruction *I = &*II; - bool Done = visit(I); + bool Done = InstVisitor::visit(I); ++II; if (Done && I->getType()->isVoidTy()) I->eraseFromParent(); @@ -303,7 +319,7 @@ bool Scalarizer::runOnFunction(Function &F) { // Return a scattered form of V that can be accessed by Point. V must be a // vector or a pointer to a vector. -Scatterer Scalarizer::scatter(Instruction *Point, Value *V) { +Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V) { if (Argument *VArg = dyn_cast<Argument>(V)) { // Put the scattered form of arguments in the entry block, // so that it can be used everywhere. @@ -327,7 +343,7 @@ Scatterer Scalarizer::scatter(Instruction *Point, Value *V) { // deletion of Op and creation of the gathered form to the end of the pass, // so that we can avoid creating the gathered form if all uses of Op are // replaced with uses of CV. -void Scalarizer::gather(Instruction *Op, const ValueVector &CV) { +void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV) { // Since we're not deleting Op yet, stub out its operands, so that it // doesn't make anything live unnecessarily. for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) @@ -356,19 +372,20 @@ void Scalarizer::gather(Instruction *Op, const ValueVector &CV) { // Return true if it is safe to transfer the given metadata tag from // vector to scalar instructions. -bool Scalarizer::canTransferMetadata(unsigned Tag) { +bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) { return (Tag == LLVMContext::MD_tbaa || Tag == LLVMContext::MD_fpmath || Tag == LLVMContext::MD_tbaa_struct || Tag == LLVMContext::MD_invariant_load || Tag == LLVMContext::MD_alias_scope || Tag == LLVMContext::MD_noalias - || Tag == ParallelLoopAccessMDKind); + || Tag == ParallelLoopAccessMDKind + || Tag == LLVMContext::MD_access_group); } // Transfer metadata from Op to the instructions in CV if it is known // to be safe to do so. -void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) { +void ScalarizerVisitor::transferMetadata(Instruction *Op, const ValueVector &CV) { SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; Op->getAllMetadataOtherThanDebugLoc(MDs); for (unsigned I = 0, E = CV.size(); I != E; ++I) { @@ -384,7 +401,7 @@ void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) { // Try to fill in Layout from Ty, returning true on success. Alignment is // the alignment of the vector, or 0 if the ABI default should be used. -bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment, +bool ScalarizerVisitor::getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout, const DataLayout &DL) { // Make sure we're dealing with a vector. Layout.VecTy = dyn_cast<VectorType>(Ty); @@ -408,7 +425,7 @@ bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment, // Scalarize two-operand instruction I, using Split(Builder, X, Y, Name) // to create an instruction like I with operands X and Y and name Name. 
template<typename Splitter> -bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) { +bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { VectorType *VT = dyn_cast<VectorType>(I.getType()); if (!VT) return false; @@ -441,7 +458,7 @@ static Function *getScalarIntrinsicDeclaration(Module *M, /// If a call to a vector typed intrinsic function, split into a scalar call per /// element if possible for the intrinsic. -bool Scalarizer::splitCall(CallInst &CI) { +bool ScalarizerVisitor::splitCall(CallInst &CI) { VectorType *VT = dyn_cast<VectorType>(CI.getType()); if (!VT) return false; @@ -499,7 +516,7 @@ bool Scalarizer::splitCall(CallInst &CI) { return true; } -bool Scalarizer::visitSelectInst(SelectInst &SI) { +bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) { VectorType *VT = dyn_cast<VectorType>(SI.getType()); if (!VT) return false; @@ -529,19 +546,19 @@ bool Scalarizer::visitSelectInst(SelectInst &SI) { return true; } -bool Scalarizer::visitICmpInst(ICmpInst &ICI) { +bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) { return splitBinary(ICI, ICmpSplitter(ICI)); } -bool Scalarizer::visitFCmpInst(FCmpInst &FCI) { +bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) { return splitBinary(FCI, FCmpSplitter(FCI)); } -bool Scalarizer::visitBinaryOperator(BinaryOperator &BO) { +bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) { return splitBinary(BO, BinarySplitter(BO)); } -bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) { +bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) { VectorType *VT = dyn_cast<VectorType>(GEPI.getType()); if (!VT) return false; @@ -587,7 +604,7 @@ bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) { return true; } -bool Scalarizer::visitCastInst(CastInst &CI) { +bool ScalarizerVisitor::visitCastInst(CastInst &CI) { VectorType *VT = dyn_cast<VectorType>(CI.getDestTy()); if (!VT) return false; @@ -605,7 +622,7 @@ bool Scalarizer::visitCastInst(CastInst &CI) { return true; } -bool Scalarizer::visitBitCastInst(BitCastInst &BCI) { +bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) { VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy()); VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy()); if (!DstVT || !SrcVT) @@ -660,7 +677,7 @@ bool Scalarizer::visitBitCastInst(BitCastInst &BCI) { return true; } -bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) { +bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) { VectorType *VT = dyn_cast<VectorType>(SVI.getType()); if (!VT) return false; @@ -684,7 +701,7 @@ bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return true; } -bool Scalarizer::visitPHINode(PHINode &PHI) { +bool ScalarizerVisitor::visitPHINode(PHINode &PHI) { VectorType *VT = dyn_cast<VectorType>(PHI.getType()); if (!VT) return false; @@ -709,7 +726,7 @@ bool Scalarizer::visitPHINode(PHINode &PHI) { return true; } -bool Scalarizer::visitLoadInst(LoadInst &LI) { +bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) { if (!ScalarizeLoadStore) return false; if (!LI.isSimple()) @@ -733,7 +750,7 @@ bool Scalarizer::visitLoadInst(LoadInst &LI) { return true; } -bool Scalarizer::visitStoreInst(StoreInst &SI) { +bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) { if (!ScalarizeLoadStore) return false; if (!SI.isSimple()) @@ -760,13 +777,13 @@ bool Scalarizer::visitStoreInst(StoreInst &SI) { return true; } -bool Scalarizer::visitCallInst(CallInst &CI) { +bool ScalarizerVisitor::visitCallInst(CallInst 
&CI) { return splitCall(CI); } // Delete the instructions that we scalarized. If a full vector result // is still needed, recreate it using InsertElements. -bool Scalarizer::finish() { +bool ScalarizerVisitor::finish() { // The presence of data in Gathered or Scattered indicates changes // made to the Function. if (Gathered.empty() && Scattered.empty()) @@ -797,6 +814,11 @@ bool Scalarizer::finish() { return true; } -FunctionPass *llvm::createScalarizerPass() { - return new Scalarizer(); +PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) { + Module &M = *F.getParent(); + unsigned ParallelLoopAccessMDKind = + M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); + ScalarizerVisitor Impl(ParallelLoopAccessMDKind); + bool Changed = Impl.visit(F); + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } diff --git a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 5834b619046b..5a67178cef37 100644 --- a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -19,11 +19,14 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/Utils/Local.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" @@ -59,7 +62,11 @@ using namespace llvm; STATISTIC(NumBranches, "Number of branches unswitched"); STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumGuards, "Number of guards turned into branches for unswitching"); STATISTIC(NumTrivial, "Number of unswitches that are trivial"); +STATISTIC( + NumCostMultiplierSkipped, + "Number of unswitch candidates that had their cost multiplier skipped"); static cl::opt<bool> EnableNonTrivialUnswitch( "enable-nontrivial-unswitch", cl::init(false), cl::Hidden, @@ -70,6 +77,22 @@ static cl::opt<int> UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, cl::desc("The cost threshold for unswitching a loop.")); +static cl::opt<bool> EnableUnswitchCostMultiplier( + "enable-unswitch-cost-multiplier", cl::init(true), cl::Hidden, + cl::desc("Enable unswitch cost multiplier that prohibits exponential " + "explosion in nontrivial unswitch.")); +static cl::opt<int> UnswitchSiblingsToplevelDiv( + "unswitch-siblings-toplevel-div", cl::init(2), cl::Hidden, + cl::desc("Toplevel siblings divisor for cost multiplier.")); +static cl::opt<int> UnswitchNumInitialUnscaledCandidates( + "unswitch-num-initial-unscaled-candidates", cl::init(8), cl::Hidden, + cl::desc("Number of unswitch candidates that are ignored when calculating " + "cost multiplier.")); +static cl::opt<bool> UnswitchGuards( + "simple-loop-unswitch-guards", cl::init(true), cl::Hidden, + cl::desc("If enabled, simple loop unswitching will also consider " + "llvm.experimental.guard intrinsics as unswitch candidates.")); + /// Collect all of the loop invariant input values transitively used by the /// homogeneous instruction graph from a given root. 
/// @@ -302,10 +325,11 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, formLCSSA(*OldContainingL, DT, &LI, nullptr); // We shouldn't need to form dedicated exits because the exit introduced - // here is the (just split by unswitching) preheader. As such, it is - // necessarily dedicated. - assert(OldContainingL->hasDedicatedExits() && - "Unexpected predecessor of hoisted loop preheader!"); + // here is the (just split by unswitching) preheader. However, after trivial + // unswitching it is possible to get new non-dedicated exits out of parent + // loop so let's conservatively form dedicated exit blocks and figure out + // if we can optimize later. + formDedicatedExitBlocks(OldContainingL, &DT, &LI, /*PreserveLCSSA*/ true); } } @@ -327,7 +351,8 @@ static void hoistLoopToNewParent(Loop &L, BasicBlock &Preheader, /// If `SE` is not null, it will be updated based on the potential loop SCEVs /// invalidated by this. static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, - LoopInfo &LI, ScalarEvolution *SE) { + LoopInfo &LI, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU) { assert(BI.isConditional() && "Can only unswitch a conditional branch!"); LLVM_DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); @@ -401,11 +426,14 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, SE->forgetTopmostLoop(&L); } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // Split the preheader, so that we know that there is a safe place to insert // the conditional branch. We will change the preheader to have a conditional // branch on LoopCond. BasicBlock *OldPH = L.getLoopPreheader(); - BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI); + BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI, MSSAU); // Now that we have a place to insert the conditional branch, create a place // to branch to: this is the exit block out of the loop that we are @@ -417,9 +445,13 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, "A branch's parent isn't a predecessor!"); UnswitchedBB = LoopExitBB; } else { - UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); + UnswitchedBB = + SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI, MSSAU); } + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // Actually move the invariant uses into the unswitched position. If possible, // we do this by moving the instructions, but when doing partial unswitching // we do it by building a new merge of the values in the unswitched position. @@ -430,12 +462,17 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // its successors. OldPH->getInstList().splice(OldPH->end(), BI.getParent()->getInstList(), BI); + if (MSSAU) { + // Temporarily clone the terminator, to make MSSA update cheaper by + // separating "insert edge" updates from "remove edge" ones. + ParentBB->getInstList().push_back(BI.clone()); + } else { + // Create a new unconditional branch that will continue the loop as a new + // terminator. + BranchInst::Create(ContinueBB, ParentBB); + } BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); - - // Create a new unconditional branch that will continue the loop as a new - // terminator. - BranchInst::Create(ContinueBB, ParentBB); } else { // Only unswitching a subset of inputs to the condition, so we will need to // build a new branch that merges the invariant inputs. 
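The hoistLoopToNewParent hunk above now calls formDedicatedExitBlocks instead of asserting, since trivial unswitching can leave the parent loop with non-dedicated exits. For reference, an exit block is "dedicated" when all of its predecessors lie inside the loop; a rough sketch of that property, where isDedicatedExit is a hypothetical helper and not something this patch adds:

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/CFG.h"
using namespace llvm;

// Hypothetical check, for illustration only: formDedicatedExitBlocks splits any
// loop exit for which this predicate is false, so that afterwards every exit's
// predecessors are all loop blocks.
static bool isDedicatedExit(const Loop &L, const BasicBlock *ExitBB) {
  for (const BasicBlock *Pred : predecessors(ExitBB))
    if (!L.contains(Pred))
      return false;
  return true;
}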
@@ -451,6 +488,32 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, *UnswitchedBB, *NewPH); } + // Update the dominator tree with the added edge. + DT.insertEdge(OldPH, UnswitchedBB); + + // After the dominator tree was updated with the added edge, update MemorySSA + // if available. + if (MSSAU) { + SmallVector<CFGUpdate, 1> Updates; + Updates.push_back({cfg::UpdateKind::Insert, OldPH, UnswitchedBB}); + MSSAU->applyInsertUpdates(Updates, DT); + } + + // Finish updating dominator tree and memory ssa for full unswitch. + if (FullUnswitch) { + if (MSSAU) { + // Remove the cloned branch instruction. + ParentBB->getTerminator()->eraseFromParent(); + // Create unconditional branch now. + BranchInst::Create(ContinueBB, ParentBB); + MSSAU->removeEdge(ParentBB, LoopExitBB); + } + DT.deleteEdge(ParentBB, LoopExitBB); + } + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // Rewrite the relevant PHI nodes. if (UnswitchedBB == LoopExitBB) rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH); @@ -458,13 +521,6 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB, *ParentBB, *OldPH, FullUnswitch); - // Now we need to update the dominator tree. - SmallVector<DominatorTree::UpdateType, 2> DTUpdates; - DTUpdates.push_back({DT.Insert, OldPH, UnswitchedBB}); - if (FullUnswitch) - DTUpdates.push_back({DT.Delete, ParentBB, LoopExitBB}); - DT.applyUpdates(DTUpdates); - // The constant we can replace all of our invariants with inside the loop // body. If any of the invariants have a value other than this the loop won't // be entered. @@ -482,6 +538,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, if (FullUnswitch) hoistLoopToNewParent(L, *NewPH, DT, LI); + LLVM_DEBUG(dbgs() << " done: unswitching trivial branch...\n"); ++NumTrivial; ++NumBranches; return true; @@ -514,7 +571,8 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, /// If `SE` is not null, it will be updated based on the potential loop SCEVs /// invalidated by this. static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, - LoopInfo &LI, ScalarEvolution *SE) { + LoopInfo &LI, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU) { LLVM_DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); Value *LoopCond = SI.getCondition(); @@ -539,7 +597,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, else if (ExitCaseIndices.empty()) return false; - LLVM_DEBUG(dbgs() << " unswitching trivial cases...\n"); + LLVM_DEBUG(dbgs() << " unswitching trivial switch...\n"); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); // We may need to invalidate SCEVs for the outermost loop reached by any of // the exits. @@ -603,7 +664,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, // Split the preheader, so that we know that there is a safe place to insert // the switch. BasicBlock *OldPH = L.getLoopPreheader(); - BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI); + BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI, MSSAU); OldPH->getTerminator()->eraseFromParent(); // Now add the unswitched switch. 
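These unswitching helpers now all take an optional MemorySSAUpdater and fall back to their old behavior when it is null. A condensed usage sketch of the expected setup, mirroring the pass entry points near the end of this patch; runTrivialUnswitch is a hypothetical stand-in that would live in the same translation unit, with the remaining types coming from the file's existing includes:

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
using namespace llvm;

// Hypothetical driver: MSSA may be null (MemorySSA preservation disabled),
// in which case every helper takes its non-MSSA path.
static bool runTrivialUnswitch(Loop &L, DominatorTree &DT, LoopInfo &LI,
                               ScalarEvolution *SE, MemorySSA *MSSA) {
  Optional<MemorySSAUpdater> MSSAU;
  if (MSSA) {
    MSSAU = MemorySSAUpdater(MSSA);
    if (VerifyMemorySSA)
      MSSA->verifyMemorySSA();
  }
  return unswitchAllTrivialConditions(
      L, DT, LI, SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr);
}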
@@ -626,9 +687,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH); } else { auto *SplitBB = - SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); - rewritePHINodesForExitAndUnswitchedBlocks( - *DefaultExitBB, *SplitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); + SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI, MSSAU); + rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, + *ParentBB, *OldPH, + /*FullUnswitch*/ true); DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; } } @@ -652,9 +714,10 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, BasicBlock *&SplitExitBB = SplitExitBBMap[ExitBB]; if (!SplitExitBB) { // If this is the first time we see this, do the split and remember it. - SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); - rewritePHINodesForExitAndUnswitchedBlocks( - *ExitBB, *SplitExitBB, *ParentBB, *OldPH, /*FullUnswitch*/ true); + SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSAU); + rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, + *ParentBB, *OldPH, + /*FullUnswitch*/ true); } // Update the case pair to point to the split block. CasePair.second = SplitExitBB; @@ -731,6 +794,13 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, DTUpdates.push_back({DT.Insert, OldPH, UnswitchedBB}); } DT.applyUpdates(DTUpdates); + + if (MSSAU) { + MSSAU->applyUpdates(DTUpdates, DT); + if (VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + } + assert(DT.verify(DominatorTree::VerificationLevel::Fast)); // We may have changed the nesting relationship for this loop so hoist it to @@ -739,6 +809,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, ++NumTrivial; ++NumSwitches; + LLVM_DEBUG(dbgs() << " done: unswitching trivial switch...\n"); return true; } @@ -755,7 +826,8 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, /// If `SE` is not null, it will be updated based on the potential loop SCEVs /// invalidated by this. static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, - LoopInfo &LI, ScalarEvolution *SE) { + LoopInfo &LI, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU) { bool Changed = false; // If loop header has only one reachable successor we should keep looking for @@ -780,7 +852,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, [](Instruction &I) { return I.mayHaveSideEffects(); })) return Changed; - TerminatorInst *CurrentTerm = CurrentBB->getTerminator(); + Instruction *CurrentTerm = CurrentBB->getTerminator(); if (auto *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // Don't bother trying to unswitch past a switch with a constant @@ -789,7 +861,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, if (isa<Constant>(SI->getCondition())) return Changed; - if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE)) + if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE, MSSAU)) // Couldn't unswitch this one so we're done. return Changed; @@ -821,7 +893,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, // Found a trivial condition candidate: non-foldable conditional branch. If // we fail to unswitch this, we can't do anything else that is trivial. 
- if (!unswitchTrivialBranch(L, *BI, DT, LI, SE)) + if (!unswitchTrivialBranch(L, *BI, DT, LI, SE, MSSAU)) return Changed; // Mark that we managed to unswitch something. @@ -874,7 +946,7 @@ static BasicBlock *buildClonedLoopBlocks( const SmallDenseMap<BasicBlock *, BasicBlock *, 16> &DominatingSucc, ValueToValueMapTy &VMap, SmallVectorImpl<DominatorTree::UpdateType> &DTUpdates, AssumptionCache &AC, - DominatorTree &DT, LoopInfo &LI) { + DominatorTree &DT, LoopInfo &LI, MemorySSAUpdater *MSSAU) { SmallVector<BasicBlock *, 4> NewBlocks; NewBlocks.reserve(L.getNumBlocks() + ExitBlocks.size()); @@ -919,7 +991,7 @@ static BasicBlock *buildClonedLoopBlocks( // place to merge the CFG, so split the exit first. This is always safe to // do because there cannot be any non-loop predecessors of a loop exit in // loop simplified form. - auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); + auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSAU); // Rearrange the names to make it easier to write test cases by having the // exit block carry the suffix rather than the merge block carrying the @@ -1262,11 +1334,10 @@ static void buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, // matter as we're just trying to build up the map from inside-out; we use // the map in a more stably ordered way below. auto OrderedClonedExitsInLoops = ClonedExitsInLoops; - llvm::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(), - [&](BasicBlock *LHS, BasicBlock *RHS) { - return ExitLoopMap.lookup(LHS)->getLoopDepth() < - ExitLoopMap.lookup(RHS)->getLoopDepth(); - }); + llvm::sort(OrderedClonedExitsInLoops, [&](BasicBlock *LHS, BasicBlock *RHS) { + return ExitLoopMap.lookup(LHS)->getLoopDepth() < + ExitLoopMap.lookup(RHS)->getLoopDepth(); + }); // Populate the existing ExitLoopMap with everything reachable from each // exit, starting from the inner most exit. @@ -1351,7 +1422,7 @@ static void buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, static void deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, ArrayRef<std::unique_ptr<ValueToValueMapTy>> VMaps, - DominatorTree &DT) { + DominatorTree &DT, MemorySSAUpdater *MSSAU) { // Find all the dead clones, and remove them from their successors. SmallVector<BasicBlock *, 16> DeadBlocks; for (BasicBlock *BB : llvm::concat<BasicBlock *const>(L.blocks(), ExitBlocks)) @@ -1363,6 +1434,13 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, DeadBlocks.push_back(ClonedBB); } + // Remove all MemorySSA in the dead blocks + if (MSSAU) { + SmallPtrSet<BasicBlock *, 16> DeadBlockSet(DeadBlocks.begin(), + DeadBlocks.end()); + MSSAU->removeBlocks(DeadBlockSet); + } + // Drop any remaining references to break cycles. for (BasicBlock *BB : DeadBlocks) BB->dropAllReferences(); @@ -1371,21 +1449,33 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, BB->eraseFromParent(); } -static void -deleteDeadBlocksFromLoop(Loop &L, - SmallVectorImpl<BasicBlock *> &ExitBlocks, - DominatorTree &DT, LoopInfo &LI) { - // Find all the dead blocks, and remove them from their successors. 
- SmallVector<BasicBlock *, 16> DeadBlocks; - for (BasicBlock *BB : llvm::concat<BasicBlock *const>(L.blocks(), ExitBlocks)) - if (!DT.isReachableFromEntry(BB)) { - for (BasicBlock *SuccBB : successors(BB)) +static void deleteDeadBlocksFromLoop(Loop &L, + SmallVectorImpl<BasicBlock *> &ExitBlocks, + DominatorTree &DT, LoopInfo &LI, + MemorySSAUpdater *MSSAU) { + // Find all the dead blocks tied to this loop, and remove them from their + // successors. + SmallPtrSet<BasicBlock *, 16> DeadBlockSet; + + // Start with loop/exit blocks and get a transitive closure of reachable dead + // blocks. + SmallVector<BasicBlock *, 16> DeathCandidates(ExitBlocks.begin(), + ExitBlocks.end()); + DeathCandidates.append(L.blocks().begin(), L.blocks().end()); + while (!DeathCandidates.empty()) { + auto *BB = DeathCandidates.pop_back_val(); + if (!DeadBlockSet.count(BB) && !DT.isReachableFromEntry(BB)) { + for (BasicBlock *SuccBB : successors(BB)) { SuccBB->removePredecessor(BB); - DeadBlocks.push_back(BB); + DeathCandidates.push_back(SuccBB); + } + DeadBlockSet.insert(BB); } + } - SmallPtrSet<BasicBlock *, 16> DeadBlockSet(DeadBlocks.begin(), - DeadBlocks.end()); + // Remove all MemorySSA in the dead blocks + if (MSSAU) + MSSAU->removeBlocks(DeadBlockSet); // Filter out the dead blocks from the exit blocks list so that it can be // used in the caller. @@ -1394,7 +1484,7 @@ deleteDeadBlocksFromLoop(Loop &L, // Walk from this loop up through its parents removing all of the dead blocks. for (Loop *ParentL = &L; ParentL; ParentL = ParentL->getParentLoop()) { - for (auto *BB : DeadBlocks) + for (auto *BB : DeadBlockSet) ParentL->getBlocksSet().erase(BB); llvm::erase_if(ParentL->getBlocksVector(), [&](BasicBlock *BB) { return DeadBlockSet.count(BB); }); @@ -1419,7 +1509,7 @@ deleteDeadBlocksFromLoop(Loop &L, // Remove the loop mappings for the dead blocks and drop all the references // from these blocks to others to handle cyclic references as we start // deleting the blocks themselves. - for (auto *BB : DeadBlocks) { + for (auto *BB : DeadBlockSet) { // Check that the dominator tree has already been updated. assert(!DT.getNode(BB) && "Should already have cleared domtree!"); LI.changeLoopFor(BB, nullptr); @@ -1428,7 +1518,7 @@ deleteDeadBlocksFromLoop(Loop &L, // Actually delete the blocks now that they've been fully unhooked from the // IR. - for (auto *BB : DeadBlocks) + for (auto *BB : DeadBlockSet) BB->eraseFromParent(); } @@ -1782,11 +1872,11 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { } while (!DomWorklist.empty()); } -static bool unswitchNontrivialInvariants( - Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants, - DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE) { +static void unswitchNontrivialInvariants( + Loop &L, Instruction &TI, ArrayRef<Value *> Invariants, + SmallVectorImpl<BasicBlock *> &ExitBlocks, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast<BranchInst>(&TI); SwitchInst *SI = BI ? 
nullptr : cast<SwitchInst>(&TI); @@ -1803,6 +1893,9 @@ static bool unswitchNontrivialInvariants( assert(isa<Instruction>(BI->getCondition()) && "Partial unswitching requires an instruction as the condition!"); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // Constant and BBs tracking the cloned and continuing successor. When we are // unswitching the entire condition, this can just be trivially chosen to // unswitch towards `true`. However, when we are unswitching a set of @@ -1841,19 +1934,12 @@ static bool unswitchNontrivialInvariants( // whatever reason). assert(LI.getLoopFor(ParentBB) == &L && "Branch in an inner loop!"); - SmallVector<BasicBlock *, 4> ExitBlocks; - L.getUniqueExitBlocks(ExitBlocks); - - // We cannot unswitch if exit blocks contain a cleanuppad instruction as we - // don't know how to split those exit blocks. - // FIXME: We should teach SplitBlock to handle this and remove this - // restriction. - for (auto *ExitBB : ExitBlocks) - if (isa<CleanupPadInst>(ExitBB->getFirstNonPHI())) - return false; - // Compute the parent loop now before we start hacking on things. Loop *ParentL = L.getParentLoop(); + // Get blocks in RPO order for MSSA update, before changing the CFG. + LoopBlocksRPO LBRPO(&L); + if (MSSAU) + LBRPO.perform(&LI); // Compute the outer-most loop containing one of our exit blocks. This is the // furthest up our loopnest which can be mutated, which we will use below to @@ -1903,7 +1989,7 @@ static bool unswitchNontrivialInvariants( // between the unswitched versions, and we will have a new preheader for the // original loop. BasicBlock *SplitBB = L.getLoopPreheader(); - BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI); + BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI, MSSAU); // Keep track of the dominator tree updates needed. SmallVector<DominatorTree::UpdateType, 4> DTUpdates; @@ -1916,7 +2002,7 @@ static bool unswitchNontrivialInvariants( VMaps.emplace_back(new ValueToValueMapTy()); ClonedPHs[SuccBB] = buildClonedLoopBlocks( L, LoopPH, SplitBB, ExitBlocks, ParentBB, SuccBB, RetainedSuccBB, - DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI); + DominatingSucc, *VMaps.back(), DTUpdates, AC, DT, LI, MSSAU); } // The stitching of the branched code back together depends on whether we're @@ -1924,7 +2010,63 @@ static bool unswitchNontrivialInvariants( // nuke the initial terminator placed in the split block. SplitBB->getTerminator()->eraseFromParent(); if (FullUnswitch) { - // First we need to unhook the successor relationship as we'll be replacing + // Splice the terminator from the original loop and rewrite its + // successors. + SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), TI); + + // Keep a clone of the terminator for MSSA updates. + Instruction *NewTI = TI.clone(); + ParentBB->getInstList().push_back(NewTI); + + // First wire up the moved terminator to the preheaders. + if (BI) { + BasicBlock *ClonedPH = ClonedPHs.begin()->second; + BI->setSuccessor(ClonedSucc, ClonedPH); + BI->setSuccessor(1 - ClonedSucc, LoopPH); + DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); + } else { + assert(SI && "Must either be a branch or switch!"); + + // Walk the cases and directly update their successors. 
+ assert(SI->getDefaultDest() == RetainedSuccBB && + "Not retaining default successor!"); + SI->setDefaultDest(LoopPH); + for (auto &Case : SI->cases()) + if (Case.getCaseSuccessor() == RetainedSuccBB) + Case.setSuccessor(LoopPH); + else + Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); + + // We need to use the set to populate domtree updates as even when there + // are multiple cases pointing at the same successor we only want to + // remove and insert one edge in the domtree. + for (BasicBlock *SuccBB : UnswitchedSuccBBs) + DTUpdates.push_back( + {DominatorTree::Insert, SplitBB, ClonedPHs.find(SuccBB)->second}); + } + + if (MSSAU) { + DT.applyUpdates(DTUpdates); + DTUpdates.clear(); + + // Remove all but one edge to the retained block and all unswitched + // blocks. This is to avoid having duplicate entries in the cloned Phis, + // when we know we only keep a single edge for each case. + MSSAU->removeDuplicatePhiEdgesBetween(ParentBB, RetainedSuccBB); + for (BasicBlock *SuccBB : UnswitchedSuccBBs) + MSSAU->removeDuplicatePhiEdgesBetween(ParentBB, SuccBB); + + for (auto &VMap : VMaps) + MSSAU->updateForClonedLoop(LBRPO, ExitBlocks, *VMap, + /*IgnoreIncomingWithNoClones=*/true); + MSSAU->updateExitBlocksForClonedLoop(ExitBlocks, VMaps, DT); + + // Remove all edges to unswitched blocks. + for (BasicBlock *SuccBB : UnswitchedSuccBBs) + MSSAU->removeEdge(ParentBB, SuccBB); + } + + // Now unhook the successor relationship as we'll be replacing // the terminator with a direct branch. This is much simpler for branches // than switches so we handle those first. if (BI) { @@ -1942,9 +2084,10 @@ static bool unswitchNontrivialInvariants( // is a duplicate edge to the retained successor as the retained successor // is always the default successor and as we'll replace this with a direct // branch we no longer need the duplicate entries in the PHI nodes. - assert(SI->getDefaultDest() == RetainedSuccBB && + SwitchInst *NewSI = cast<SwitchInst>(NewTI); + assert(NewSI->getDefaultDest() == RetainedSuccBB && "Not retaining default successor!"); - for (auto &Case : SI->cases()) + for (auto &Case : NewSI->cases()) Case.getCaseSuccessor()->removePredecessor( ParentBB, /*DontDeleteUselessPHIs*/ true); @@ -1956,34 +2099,8 @@ static bool unswitchNontrivialInvariants( DTUpdates.push_back({DominatorTree::Delete, ParentBB, SuccBB}); } - // Now that we've unhooked the successor relationship, splice the terminator - // from the original loop to the split. - SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), TI); - - // Now wire up the terminator to the preheaders. - if (BI) { - BasicBlock *ClonedPH = ClonedPHs.begin()->second; - BI->setSuccessor(ClonedSucc, ClonedPH); - BI->setSuccessor(1 - ClonedSucc, LoopPH); - DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); - } else { - assert(SI && "Must either be a branch or switch!"); - - // Walk the cases and directly update their successors. - SI->setDefaultDest(LoopPH); - for (auto &Case : SI->cases()) - if (Case.getCaseSuccessor() == RetainedSuccBB) - Case.setSuccessor(LoopPH); - else - Case.setSuccessor(ClonedPHs.find(Case.getCaseSuccessor())->second); - - // We need to use the set to populate domtree updates as even when there - // are multiple cases pointing at the same successor we only want to - // remove and insert one edge in the domtree. 
- for (BasicBlock *SuccBB : UnswitchedSuccBBs) - DTUpdates.push_back( - {DominatorTree::Insert, SplitBB, ClonedPHs.find(SuccBB)->second}); - } + // After MSSAU update, remove the cloned terminator instruction NewTI. + ParentBB->getTerminator()->eraseFromParent(); // Create a new unconditional branch to the continuing block (as opposed to // the one cloned). @@ -2002,12 +2119,19 @@ // Apply the updates accumulated above to get an up-to-date dominator tree. DT.applyUpdates(DTUpdates); + if (!FullUnswitch && MSSAU) { + // Update MSSA for partial unswitch, after DT update. + SmallVector<CFGUpdate, 1> Updates; + Updates.push_back( + {cfg::UpdateKind::Insert, SplitBB, ClonedPHs.begin()->second}); + MSSAU->applyInsertUpdates(Updates, DT); + } // Now that we have an accurate dominator tree, first delete the dead cloned // blocks so that we can accurately build any cloned loops. It is important to // not delete the blocks from the original loop yet because we still want to // reference the original loop to understand the cloned loop's structure. - deleteDeadClonedBlocks(L, ExitBlocks, VMaps, DT); + deleteDeadClonedBlocks(L, ExitBlocks, VMaps, DT, MSSAU); // Build the cloned loop structure itself. This may be substantially // different from the original structure due to the simplified CFG. This also @@ -2019,10 +2143,17 @@ // Now that our cloned loops have been built, we can update the original loop. // First we delete the dead blocks from it and then we rebuild the loop // structure taking these deletions into account. - deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI); + deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + SmallVector<Loop *, 4> HoistedLoops; bool IsStillLoop = rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // This transformation has a high risk of corrupting the dominator tree, and // the below steps to rebuild loop structures will result in hard to debug // errors in that case so verify that the dominator tree is sane first. @@ -2038,6 +2169,18 @@ assert(UnswitchedSuccBBs.size() == 1 && "Only one possible unswitched block for a branch!"); BasicBlock *ClonedPH = ClonedPHs.begin()->second; + + // When considering multiple partially-unswitched invariants + // we can't just go replace them with constants in both branches. + // + // For 'AND' we infer that true branch ("continue") means true + // for each invariant operand. + // For 'OR' we can infer that false branch ("continue") means false + // for each invariant operand. + // So it happens that for the multiple-partial case we don't replace + // in the unswitched branch. + bool ReplaceUnswitched = FullUnswitch || (Invariants.size() == 1); + ConstantInt *UnswitchedReplacement = Direction ? ConstantInt::getTrue(BI->getContext()) : ConstantInt::getFalse(BI->getContext()); @@ -2057,7 +2200,8 @@ // unswitched if in the cloned blocks. 
if (DT.dominates(LoopPH, UserI->getParent())) U->set(ContinueReplacement); - else if (DT.dominates(ClonedPH, UserI->getParent())) + else if (ReplaceUnswitched && + DT.dominates(ClonedPH, UserI->getParent())) U->set(UnswitchedReplacement); } } @@ -2134,8 +2278,13 @@ static bool unswitchNontrivialInvariants( SibLoops.push_back(UpdatedL); UnswitchCB(IsStillLoop, SibLoops); - ++NumBranches; - return true; + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + + if (BI) + ++NumBranches; + else + ++NumSwitches; } /// Recursively compute the cost of a dominator subtree based on the per-block @@ -2171,19 +2320,208 @@ computeDomSubtreeCost(DomTreeNode &N, return Cost; } +/// Turns a llvm.experimental.guard intrinsic into implicit control flow branch, +/// making the following replacement: +/// +/// --code before guard-- +/// call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ] +/// --code after guard-- +/// +/// into +/// +/// --code before guard-- +/// br i1 %cond, label %guarded, label %deopt +/// +/// guarded: +/// --code after guard-- +/// +/// deopt: +/// call void (i1, ...) @llvm.experimental.guard(i1 false) [ "deopt"() ] +/// unreachable +/// +/// It also makes all relevant DT and LI updates, so that all structures are in +/// valid state after this transform. +static BranchInst * +turnGuardIntoBranch(IntrinsicInst *GI, Loop &L, + SmallVectorImpl<BasicBlock *> &ExitBlocks, + DominatorTree &DT, LoopInfo &LI, MemorySSAUpdater *MSSAU) { + SmallVector<DominatorTree::UpdateType, 4> DTUpdates; + LLVM_DEBUG(dbgs() << "Turning " << *GI << " into a branch.\n"); + BasicBlock *CheckBB = GI->getParent(); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + + // Remove all CheckBB's successors from DomTree. A block can be seen among + // successors more than once, but for DomTree it should be added only once. + SmallPtrSet<BasicBlock *, 4> Successors; + for (auto *Succ : successors(CheckBB)) + if (Successors.insert(Succ).second) + DTUpdates.push_back({DominatorTree::Delete, CheckBB, Succ}); + + Instruction *DeoptBlockTerm = + SplitBlockAndInsertIfThen(GI->getArgOperand(0), GI, true); + BranchInst *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); + // SplitBlockAndInsertIfThen inserts control flow that branches to + // DeoptBlockTerm if the condition is true. We want the opposite. + CheckBI->swapSuccessors(); + + BasicBlock *GuardedBlock = CheckBI->getSuccessor(0); + GuardedBlock->setName("guarded"); + CheckBI->getSuccessor(1)->setName("deopt"); + BasicBlock *DeoptBlock = CheckBI->getSuccessor(1); + + // We now have a new exit block. + ExitBlocks.push_back(CheckBI->getSuccessor(1)); + + if (MSSAU) + MSSAU->moveAllAfterSpliceBlocks(CheckBB, GuardedBlock, GI); + + GI->moveBefore(DeoptBlockTerm); + GI->setArgOperand(0, ConstantInt::getFalse(GI->getContext())); + + // Add new successors of CheckBB into DomTree. + for (auto *Succ : successors(CheckBB)) + DTUpdates.push_back({DominatorTree::Insert, CheckBB, Succ}); + + // Now the blocks that used to be CheckBB's successors are GuardedBlock's + // successors. + for (auto *Succ : Successors) + DTUpdates.push_back({DominatorTree::Insert, GuardedBlock, Succ}); + + // Make proper changes to DT. + DT.applyUpdates(DTUpdates); + // Inform LI of a new loop block. 
+ L.addBasicBlockToLoop(GuardedBlock, LI); + + if (MSSAU) { + MemoryDef *MD = cast<MemoryDef>(MSSAU->getMemorySSA()->getMemoryAccess(GI)); + MSSAU->moveToPlace(MD, DeoptBlock, MemorySSA::End); + if (VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + } + + ++NumGuards; + return CheckBI; +} + +/// Cost multiplier is a way to limit potentially exponential behavior +/// of loop-unswitch. Cost is multiplied in proportion to 2^number of unswitch +/// candidates available. It also accounts for the number of "sibling" loops, with +/// the idea of accounting for previous unswitches that already happened on this +/// cluster of loops. There was an attempt to keep this formula simple, +/// just enough to limit the worst case behavior. Even if it is not that simple +/// now, it is still not an attempt to provide a detailed heuristic size +/// prediction. +/// +/// TODO: Make a proper accounting of "explosion" effect for all kinds of +/// unswitch candidates, making adequate predictions instead of wild guesses. +/// That requires knowing not just the number of "remaining" candidates but +/// also costs of unswitching for each of these candidates. +static int calculateUnswitchCostMultiplier( + Instruction &TI, Loop &L, LoopInfo &LI, DominatorTree &DT, + ArrayRef<std::pair<Instruction *, TinyPtrVector<Value *>>> + UnswitchCandidates) { + + // Guards and other exiting conditions do not contribute to exponential + // explosion as long as they dominate the latch (otherwise there might be + // another path to the latch remaining that does not allow eliminating the + // loop copy on unswitch). + BasicBlock *Latch = L.getLoopLatch(); + BasicBlock *CondBlock = TI.getParent(); + if (DT.dominates(CondBlock, Latch) && + (isGuard(&TI) || + llvm::count_if(successors(&TI), [&L](BasicBlock *SuccBB) { + return L.contains(SuccBB); + }) <= 1)) { + NumCostMultiplierSkipped++; + return 1; + } + + auto *ParentL = L.getParentLoop(); + int SiblingsCount = (ParentL ? ParentL->getSubLoopsVector().size() + : std::distance(LI.begin(), LI.end())); + // Count the number of clones that all the candidates might cause during + // unswitching. Branch/guard counts as 1, switch counts as log2 of its cases. + int UnswitchedClones = 0; + for (auto Candidate : UnswitchCandidates) { + Instruction *CI = Candidate.first; + BasicBlock *CondBlock = CI->getParent(); + bool SkipExitingSuccessors = DT.dominates(CondBlock, Latch); + if (isGuard(CI)) { + if (!SkipExitingSuccessors) + UnswitchedClones++; + continue; + } + int NonExitingSuccessors = llvm::count_if( + successors(CondBlock), [SkipExitingSuccessors, &L](BasicBlock *SuccBB) { + return !SkipExitingSuccessors || L.contains(SuccBB); + }); + UnswitchedClones += Log2_32(NonExitingSuccessors); + } + + // Ignore up to the "unscaled candidates" number of unswitch candidates + // when calculating the power-of-two scaling of the cost. The main idea + // with this control is to allow a small number of unswitches to happen + // and rely more on siblings multiplier (see below) when the number + // of candidates is small. + unsigned ClonesPower = + std::max(UnswitchedClones - (int)UnswitchNumInitialUnscaledCandidates, 0); + + // Allowing top-level loops to spread a bit more than nested ones. + int SiblingsMultiplier = + std::max((ParentL ? SiblingsCount + : SiblingsCount / (int)UnswitchSiblingsToplevelDiv), + 1); + // Compute the cost multiplier in a way that won't overflow by saturating + // at an upper bound. 
+ int CostMultiplier; + if (ClonesPower > Log2_32(UnswitchThreshold) || + SiblingsMultiplier > UnswitchThreshold) + CostMultiplier = UnswitchThreshold; + else + CostMultiplier = std::min(SiblingsMultiplier * (1 << ClonesPower), + (int)UnswitchThreshold); + + LLVM_DEBUG(dbgs() << " Computed multiplier " << CostMultiplier + << " (siblings " << SiblingsMultiplier << " * clones " + << (1 << ClonesPower) << ")" + << " for unswitch candidate: " << TI << "\n"); + return CostMultiplier; +} + static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, TargetTransformInfo &TTI, function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE) { + ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { // Collect all invariant conditions within this loop (as opposed to an inner // loop which would be handled when visiting that inner loop). - SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4> + SmallVector<std::pair<Instruction *, TinyPtrVector<Value *>>, 4> UnswitchCandidates; + + // Whether or not we should also collect guards in the loop. + bool CollectGuards = false; + if (UnswitchGuards) { + auto *GuardDecl = L.getHeader()->getParent()->getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (GuardDecl && !GuardDecl->use_empty()) + CollectGuards = true; + } + for (auto *BB : L.blocks()) { if (LI.getLoopFor(BB) != &L) continue; + if (CollectGuards) + for (auto &I : *BB) + if (isGuard(&I)) { + auto *Cond = cast<IntrinsicInst>(&I)->getArgOperand(0); + // TODO: Support AND, OR conditions and partial unswitching. + if (!isa<Constant>(Cond) && L.isLoopInvariant(Cond)) + UnswitchCandidates.push_back({&I, {Cond}}); + } + if (auto *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { // We can only consider fully loop-invariant switch conditions as we need // to completely eliminate the switch after unswitching. @@ -2231,6 +2569,19 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, if (containsIrreducibleCFG<const BasicBlock *>(RPOT, LI)) return false; + SmallVector<BasicBlock *, 4> ExitBlocks; + L.getUniqueExitBlocks(ExitBlocks); + + // We cannot unswitch if exit blocks contain a cleanuppad instruction as we + // don't know how to split those exit blocks. + // FIXME: We should teach SplitBlock to handle this and remove this + // restriction. + for (auto *ExitBB : ExitBlocks) + if (isa<CleanupPadInst>(ExitBB->getFirstNonPHI())) { + dbgs() << "Cannot unswitch because of cleanuppad in exit block\n"; + return false; + } + LLVM_DEBUG( dbgs() << "Considering " << UnswitchCandidates.size() << " non-trivial loop invariant conditions for unswitching.\n"); @@ -2288,7 +2639,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, SmallDenseMap<DomTreeNode *, int, 4> DTCostMap; // Given a terminator which might be unswitched, computes the non-duplicated // cost for that terminator. - auto ComputeUnswitchedCost = [&](TerminatorInst &TI, bool FullUnswitch) { + auto ComputeUnswitchedCost = [&](Instruction &TI, bool FullUnswitch) { BasicBlock &BB = *TI.getParent(); SmallPtrSet<BasicBlock *, 4> Visited; @@ -2335,22 +2686,40 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, // Now scale the cost by the number of unique successors minus one. We // subtract one because there is already at least one copy of the entire // loop. This is computing the new cost of unswitching a condition. 
- assert(Visited.size() > 1 && + // Note that guards always have 2 unique successors that are implicit and + // will be materialized if we decide to unswitch it. + int SuccessorsCount = isGuard(&TI) ? 2 : Visited.size(); + assert(SuccessorsCount > 1 && "Cannot unswitch a condition without multiple distinct successors!"); - return Cost * (Visited.size() - 1); + return Cost * (SuccessorsCount - 1); }; - TerminatorInst *BestUnswitchTI = nullptr; + Instruction *BestUnswitchTI = nullptr; int BestUnswitchCost; ArrayRef<Value *> BestUnswitchInvariants; for (auto &TerminatorAndInvariants : UnswitchCandidates) { - TerminatorInst &TI = *TerminatorAndInvariants.first; + Instruction &TI = *TerminatorAndInvariants.first; ArrayRef<Value *> Invariants = TerminatorAndInvariants.second; BranchInst *BI = dyn_cast<BranchInst>(&TI); int CandidateCost = ComputeUnswitchedCost( TI, /*FullUnswitch*/ !BI || (Invariants.size() == 1 && Invariants[0] == BI->getCondition())); - LLVM_DEBUG(dbgs() << " Computed cost of " << CandidateCost - << " for unswitch candidate: " << TI << "\n"); + // Calculate cost multiplier which is a tool to limit potentially + // exponential behavior of loop-unswitch. + if (EnableUnswitchCostMultiplier) { + int CostMultiplier = + calculateUnswitchCostMultiplier(TI, L, LI, DT, UnswitchCandidates); + assert( + (CostMultiplier > 0 && CostMultiplier <= UnswitchThreshold) && + "cost multiplier needs to be in the range of 1..UnswitchThreshold"); + CandidateCost *= CostMultiplier; + LLVM_DEBUG(dbgs() << " Computed cost of " << CandidateCost + << " (multiplier: " << CostMultiplier << ")" + << " for unswitch candidate: " << TI << "\n"); + } else { + LLVM_DEBUG(dbgs() << " Computed cost of " << CandidateCost + << " for unswitch candidate: " << TI << "\n"); + } + if (!BestUnswitchTI || CandidateCost < BestUnswitchCost) { BestUnswitchTI = &TI; BestUnswitchCost = CandidateCost; @@ -2364,11 +2733,17 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, return false; } - LLVM_DEBUG(dbgs() << " Trying to unswitch non-trivial (cost = " + // If the best candidate is a guard, turn it into a branch. + if (isGuard(BestUnswitchTI)) + BestUnswitchTI = turnGuardIntoBranch(cast<IntrinsicInst>(BestUnswitchTI), L, + ExitBlocks, DT, LI, MSSAU); + + LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << BestUnswitchCost << ") terminator: " << *BestUnswitchTI << "\n"); - return unswitchNontrivialInvariants( - L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE); + unswitchNontrivialInvariants(L, *BestUnswitchTI, BestUnswitchInvariants, + ExitBlocks, DT, LI, AC, UnswitchCB, SE, MSSAU); + return true; } /// Unswitch control flow predicated on loop invariant conditions. @@ -2380,6 +2755,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, /// /// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also /// updated based on the unswitch. +/// The `MSSA` analysis is also updated if valid (i.e. its use is enabled). 
/// /// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is /// true, we will attempt to do non-trivial unswitching as well as trivial @@ -2395,7 +2771,7 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, TargetTransformInfo &TTI, bool NonTrivial, function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE) { + ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); bool Changed = false; @@ -2405,7 +2781,7 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, return false; // Try trivial unswitch first before loop over other basic blocks in the loop. - if (unswitchAllTrivialConditions(L, DT, LI, SE)) { + if (unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. UnswitchCB(/*CurrentLoopValid*/ true, {}); @@ -2426,7 +2802,7 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. - if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE)) + if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE, MSSAU)) return true; // No other opportunities to unswitch. @@ -2460,10 +2836,19 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, U.markLoopAsDeleted(L, LoopName); }; + Optional<MemorySSAUpdater> MSSAU; + if (AR.MSSA) { + MSSAU = MemorySSAUpdater(AR.MSSA); + if (VerifyMemorySSA) + AR.MSSA->verifyMemorySSA(); + } if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB, - &AR.SE)) + &AR.SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) return PreservedAnalyses::all(); + if (AR.MSSA && VerifyMemorySSA) + AR.MSSA->verifyMemorySSA(); + // Historically this pass has had issues with the dominator tree so verify it // in asserts builds. assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast)); @@ -2489,6 +2874,10 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + if (EnableMSSALoopDependency) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } getLoopAnalysisUsage(AU); } }; @@ -2508,6 +2897,12 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + MemorySSA *MSSA = nullptr; + Optional<MemorySSAUpdater> MSSAU; + if (EnableMSSALoopDependency) { + MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MSSAU = MemorySSAUpdater(MSSA); + } auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; @@ -2527,7 +2922,14 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { LPM.markLoopAsDeleted(*L); }; - bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE); + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); + + bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE, + MSSAU.hasValue() ? 
MSSAU.getPointer() : nullptr); + + if (MSSA && VerifyMemorySSA) + MSSA->verifyMemorySSA(); // If anything was unswitched, also clear any cached information about this // loop. @@ -2547,6 +2949,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", "Simple unswitch loops", false, false) diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index ca6b93e0b4a9..c99da8f0737a 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -72,18 +72,18 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis &AA, return false; } - if (isa<TerminatorInst>(Inst) || isa<PHINode>(Inst) || Inst->isEHPad() || + if (Inst->isTerminator() || isa<PHINode>(Inst) || Inst->isEHPad() || Inst->mayThrow()) return false; - if (auto CS = CallSite(Inst)) { + if (auto *Call = dyn_cast<CallBase>(Inst)) { // Convergent operations cannot be made control-dependent on additional // values. - if (CS.hasFnAttr(Attribute::Convergent)) + if (Call->hasFnAttr(Attribute::Convergent)) return false; for (Instruction *S : Stores) - if (isModSet(AA.getModRefInfo(S, CS))) + if (isModSet(AA.getModRefInfo(S, Call))) return false; } @@ -104,7 +104,7 @@ static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo, // It's never legal to sink an instruction into a block which terminates in an // EH-pad. - if (SuccToSinkTo->getTerminator()->isExceptional()) + if (SuccToSinkTo->getTerminator()->isExceptionalTerminator()) return false; // If the block has multiple predecessors, this would introduce computation diff --git a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp index 6743e19a7c92..c0f75ddddbe0 100644 --- a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp +++ b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -33,7 +33,7 @@ STATISTIC(NumSpeculatedInstructions, STATISTIC(NumNewRedundantInstructions, "Number of new, redundant instructions inserted"); -/// Check wether speculating the users of a PHI node around the PHI +/// Check whether speculating the users of a PHI node around the PHI /// will be safe. /// /// This checks both that all of the users are safe and also that all of their diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 2061db13639a..b5089b006bdd 100644 --- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -640,12 +640,12 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( Value *Reduced = nullptr; // equivalent to but weaker than C.Ins switch (C.CandidateKind) { case Candidate::Add: - case Candidate::Mul: + case Candidate::Mul: { // C = Basis + Bump - if (BinaryOperator::isNeg(Bump)) { + Value *NegBump; + if (match(Bump, m_Neg(m_Value(NegBump)))) { // If Bump is a neg instruction, emit C = Basis - (-Bump). - Reduced = - Builder.CreateSub(Basis.Ins, BinaryOperator::getNegArgument(Bump)); + Reduced = Builder.CreateSub(Basis.Ins, NegBump); // We only use the negative argument of Bump, and Bump itself may be // trivially dead. 
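The straight-line strength reduction hunk just above switches from the BinaryOperator::isNeg/getNegArgument pair to a PatternMatch-based capture. A minimal, self-contained sketch of that idiom (the helper name is invented):

    #include "llvm/IR/PatternMatch.h"
    #include "llvm/IR/Value.h"

    // Returns X when V is an integer negation ("sub 0, X"), otherwise nullptr.
    static llvm::Value *matchNegatedOperand(llvm::Value *V) {
      using namespace llvm::PatternMatch;
      llvm::Value *X;
      if (match(V, m_Neg(m_Value(X))))
        return X;
      return nullptr;
    }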
RecursivelyDeleteTriviallyDeadInstructions(Bump); @@ -662,6 +662,7 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( Reduced = Builder.CreateAdd(Basis.Ins, Bump); } break; + } case Candidate::GEP: { Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType()); diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index d650264176aa..0db762d846f2 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -13,7 +13,8 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LegacyDivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" @@ -183,7 +184,7 @@ class StructurizeCFG : public RegionPass { Function *Func; Region *ParentRegion; - DivergenceAnalysis *DA; + LegacyDivergenceAnalysis *DA; DominatorTree *DT; LoopInfo *LI; @@ -269,7 +270,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { if (SkipUniformRegions) - AU.addRequired<DivergenceAnalysis>(); + AU.addRequired<LegacyDivergenceAnalysis>(); AU.addRequiredID(LowerSwitchID); AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); @@ -285,7 +286,7 @@ char StructurizeCFG::ID = 0; INITIALIZE_PASS_BEGIN(StructurizeCFG, "structurizecfg", "Structurize the CFG", false, false) -INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) +INITIALIZE_PASS_DEPENDENCY(LegacyDivergenceAnalysis) INITIALIZE_PASS_DEPENDENCY(LowerSwitch) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) @@ -596,7 +597,8 @@ void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { /// Add the real PHI value as soon as everything is set up void StructurizeCFG::setPhiValues() { - SSAUpdater Updater; + SmallVector<PHINode *, 8> InsertedPhis; + SSAUpdater Updater(&InsertedPhis); for (const auto &AddedPhi : AddedPhis) { BasicBlock *To = AddedPhi.first; const BBVector &From = AddedPhi.second; @@ -632,11 +634,31 @@ void StructurizeCFG::setPhiValues() { DeletedPhis.erase(To); } assert(DeletedPhis.empty()); + + // Simplify any phis inserted by the SSAUpdater if possible + bool Changed; + do { + Changed = false; + + SimplifyQuery Q(Func->getParent()->getDataLayout()); + Q.DT = DT; + for (size_t i = 0; i < InsertedPhis.size(); ++i) { + PHINode *Phi = InsertedPhis[i]; + if (Value *V = SimplifyInstruction(Phi, Q)) { + Phi->replaceAllUsesWith(V); + Phi->eraseFromParent(); + InsertedPhis[i] = InsertedPhis.back(); + InsertedPhis.pop_back(); + i--; + Changed = true; + } + } + } while (Changed); } /// Remove phi values from all successors and then remove the terminator. void StructurizeCFG::killTerminator(BasicBlock *BB) { - TerminatorInst *Term = BB->getTerminator(); + Instruction *Term = BB->getTerminator(); if (!Term) return; @@ -914,7 +936,7 @@ void StructurizeCFG::rebuildSSA() { } static bool hasOnlyUniformBranches(Region *R, unsigned UniformMDKindID, - const DivergenceAnalysis &DA) { + const LegacyDivergenceAnalysis &DA) { for (auto E : R->elements()) { if (!E->isSubRegion()) { auto Br = dyn_cast<BranchInst>(E->getEntry()->getTerminator()); @@ -962,7 +984,7 @@ bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { // but we shouldn't rely on metadata for correctness! 
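The new cleanup loop in StructurizeCFG::setPhiValues above re-runs instruction simplification over the PHIs the SSAUpdater inserted until nothing more folds, using swap-and-pop so erasing an entry does not invalidate the scan. A compact sketch of that idiom (standalone, names assumed):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Analysis/InstructionSimplify.h"
    #include "llvm/IR/Instructions.h"
    using namespace llvm;

    // Simplify SSAUpdater-inserted PHIs to a fixpoint.
    static void simplifyInsertedPhis(SmallVectorImpl<PHINode *> &Phis,
                                     const SimplifyQuery &Q) {
      bool Changed;
      do {
        Changed = false;
        for (size_t I = 0; I < Phis.size(); /* advanced below */) {
          if (Value *V = SimplifyInstruction(Phis[I], Q)) {
            Phis[I]->replaceAllUsesWith(V);
            Phis[I]->eraseFromParent();
            Phis[I] = Phis.back(); // swap-and-pop; revisit the swapped-in PHI
            Phis.pop_back();
            Changed = true;
          } else {
            ++I;
          }
        }
      } while (Changed);
    }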
unsigned UniformMDKindID = R->getEntry()->getContext().getMDKindID("structurizecfg.uniform"); - DA = &getAnalysis<DivergenceAnalysis>(); + DA = &getAnalysis<LegacyDivergenceAnalysis>(); if (hasOnlyUniformBranches(R, UniformMDKindID, *DA)) { LLVM_DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index f8cd6c17a5a6..0f6db21f73b6 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -61,6 +61,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" @@ -68,6 +69,8 @@ #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/DomTreeUpdater.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" @@ -124,6 +127,12 @@ struct AllocaDerivedValueTracker { case Instruction::Call: case Instruction::Invoke: { CallSite CS(I); + // If the alloca-derived argument is passed byval it is not an escape + // point, or a use of an alloca. Calling with byval copies the contents + // of the alloca into argument registers or stack slots, which exist + // beyond the lifetime of the current frame. + if (CS.isArgOperand(U) && CS.isByValArgument(CS.getArgumentNo(U))) + continue; bool IsNocapture = CS.isDataOperand(U) && CS.doesNotCapture(CS.getDataOperandNo(U)); callUsesLocalStack(CS, IsNocapture); @@ -488,12 +497,10 @@ static CallInst *findTRECandidate(Instruction *TI, return CI; } -static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, - BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - AliasAnalysis *AA, - OptimizationRemarkEmitter *ORE) { +static bool eliminateRecursiveTailCall( + CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { // If we are introducing accumulator recursion to eliminate operations after // the call instruction that are both associative and commutative, the initial // value for the accumulator is placed in this variable. If this value is set @@ -566,7 +573,8 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry); NewEntry->takeName(OldEntry); OldEntry->setName("tailrecurse"); - BranchInst::Create(OldEntry, NewEntry); + BranchInst *BI = BranchInst::Create(OldEntry, NewEntry); + BI->setDebugLoc(CI->getDebugLoc()); // If this tail call is marked 'tail' and if there are any allocas in the // entry block, move them up to the new entry block. @@ -592,6 +600,10 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, PN->addIncoming(&*I, NewEntry); ArgumentPHIs.push_back(PN); } + // The entry block was changed from OldEntry to NewEntry. + // The forward DominatorTree needs to be recalculated when the EntryBB is + // changed. In this corner-case we recalculate the entire tree. 
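For orientation only (this example is not from the patch): the shape of code this file rewrites. A self-recursive tail call becomes a branch back to the renamed "tailrecurse" entry block, with the argument PHIs created above standing in for the parameters:

    // Before: the recursive call is in tail position.
    static int sumTo(int N, int Acc) {
      if (N == 0)
        return Acc;
      return sumTo(N - 1, Acc + N); // tail call
    }

    // Conceptually after TRE, expressed as C++: the call is now a back edge and
    // the parameters are updated in place (argument PHIs in the IR).
    static int sumToAfterTRE(int N, int Acc) {
      for (;;) { // "tailrecurse"
        if (N == 0)
          return Acc;
        Acc += N; // new accumulator value
        N -= 1;   // new counter value
      }
    }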
+ DTU.recalculate(*NewEntry->getParent()); } // If this function has self recursive calls in the tail position where some @@ -667,6 +679,7 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BB->getInstList().erase(Ret); // Remove return. BB->getInstList().erase(CI); // Remove call. + DTU.insertEdge(BB, OldEntry); ++NumEliminated; return true; } @@ -675,7 +688,7 @@ static bool foldReturnAndProcessPred( BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA, OptimizationRemarkEmitter *ORE) { + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { bool Change = false; // Make sure this block is a trivial return block. @@ -689,7 +702,7 @@ static bool foldReturnAndProcessPred( SmallVector<BranchInst*, 8> UncondBranchPreds; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *Pred = *PI; - TerminatorInst *PTI = Pred->getTerminator(); + Instruction *PTI = Pred->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(PTI)) if (BI->isUnconditional()) UncondBranchPreds.push_back(BI); @@ -701,17 +714,17 @@ static bool foldReturnAndProcessPred( if (CallInst *CI = findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)){ LLVM_DEBUG(dbgs() << "FOLDING: " << *BB << "INTO UNCOND BRANCH PRED: " << *Pred); - ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred); + ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred, &DTU); // Cleanup: if all predecessors of BB have been eliminated by // FoldReturnIntoUncondBranch, delete it. It is important to empty it, // because the ret instruction in there is still using a value which // eliminateRecursiveTailCall will attempt to remove. 
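The TRE hunks above route all CFG bookkeeping through the new DomTreeUpdater parameter. Each helper below shows one of the maintenance patterns used here in isolation (illustrative wrappers, not from the patch); a rewrite calls whichever matches the CFG change it just made:

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/DomTreeUpdater.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/IR/Function.h"
    using namespace llvm;

    static void afterNewBranch(DomTreeUpdater &DTU, BasicBlock *From, BasicBlock *To) {
      DTU.insertEdge(From, To); // a branch From -> To was just created
    }

    static void afterRemovedBranch(DomTreeUpdater &DTU, BasicBlock *From, BasicBlock *To) {
      DTU.deleteEdge(From, To); // the branch From -> To was just removed
    }

    static void afterEntryBlockReplaced(DomTreeUpdater &DTU, Function &F) {
      DTU.recalculate(F); // incremental updates cannot express a new entry block
    }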
if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB)) - BB->eraseFromParent(); + DTU.deleteBB(BB); eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE); + ArgumentPHIs, AA, ORE, DTU); ++NumRetDuped; Change = true; } @@ -720,24 +733,23 @@ static bool foldReturnAndProcessPred( return Change; } -static bool processReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI, - AliasAnalysis *AA, - OptimizationRemarkEmitter *ORE) { +static bool processReturningBlock( + ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, + SmallVectorImpl<PHINode *> &ArgumentPHIs, + bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE, DomTreeUpdater &DTU) { CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); if (!CI) return false; return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA, ORE); + ArgumentPHIs, AA, ORE, DTU); } static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, AliasAnalysis *AA, - OptimizationRemarkEmitter *ORE) { + OptimizationRemarkEmitter *ORE, + DomTreeUpdater &DTU) { if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") return false; @@ -772,11 +784,11 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, !CanTRETailMarkedCall, - TTI, AA, ORE); + TTI, AA, ORE, DTU); if (!Change && BB->getFirstNonPHIOrDbg() == Ret) - Change = foldReturnAndProcessPred(BB, Ret, OldEntry, - TailCallsAreMarkedTail, ArgumentPHIs, - !CanTRETailMarkedCall, TTI, AA, ORE); + Change = foldReturnAndProcessPred( + BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, + !CanTRETailMarkedCall, TTI, AA, ORE, DTU); MadeChange |= Change; } } @@ -809,16 +821,27 @@ struct TailCallElim : public FunctionPass { AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<PostDominatorTreeWrapperPass>(); } bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + auto *PDTWP = getAnalysisIfAvailable<PostDominatorTreeWrapperPass>(); + auto *PDT = PDTWP ? &PDTWP->getPostDomTree() : nullptr; + // There is no noticable performance difference here between Lazy and Eager + // UpdateStrategy based on some test results. It is feasible to switch the + // UpdateStrategy to Lazy if we find it profitable later. 
+ DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); + return eliminateTailRecursion( F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F), &getAnalysis<AAResultsWrapperPass>().getAAResults(), - &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(), DTU); } }; } @@ -842,12 +865,19 @@ PreservedAnalyses TailCallElimPass::run(Function &F, TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); AliasAnalysis &AA = AM.getResult<AAManager>(F); auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - - bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE); + auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + auto *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F); + // There is no noticable performance difference here between Lazy and Eager + // UpdateStrategy based on some test results. It is feasible to switch the + // UpdateStrategy to Lazy if we find it profitable later. + DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Eager); + bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE, DTU); if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); return PA; } diff --git a/lib/Transforms/Scalar/WarnMissedTransforms.cpp b/lib/Transforms/Scalar/WarnMissedTransforms.cpp new file mode 100644 index 000000000000..80f761e53774 --- /dev/null +++ b/lib/Transforms/Scalar/WarnMissedTransforms.cpp @@ -0,0 +1,149 @@ +//===- LoopTransformWarning.cpp - ----------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Emit warnings if forced code transformations have not been performed. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/WarnMissedTransforms.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "transform-warning" + +/// Emit warnings for forced (i.e. user-defined) loop transformations which have +/// still not been performed. 
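As context for the new pass (example not part of the patch): the "forced" transformations it reports are the ones a user requests explicitly, for instance through clang's loop pragmas, which become !llvm.loop metadata that the helpers below read:

    // If the vectorizer cannot honor this pragma, the transform-warning pass
    // emits the "loop not vectorized: the optimizer was unable to perform the
    // requested transformation" diagnostic for this loop.
    void scale(float *__restrict A, const float *__restrict B, int N) {
    #pragma clang loop vectorize(enable) interleave_count(4)
      for (int i = 0; i < N; ++i)
        A[i] = 2.0f * B[i];
    }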
+static void warnAboutLeftoverTransformations(Loop *L, + OptimizationRemarkEmitter *ORE) { + if (hasUnrollTransformation(L) == TM_ForcedByUser) { + LLVM_DEBUG(dbgs() << "Leftover unroll transformation\n"); + ORE->emit( + DiagnosticInfoOptimizationFailure(DEBUG_TYPE, + "FailedRequestedUnrolling", + L->getStartLoc(), L->getHeader()) + << "loop not unrolled: the optimizer was unable to perform the " + "requested transformation; the transformation might be disabled or " + "specified as part of an unsupported transformation ordering"); + } + + if (hasUnrollAndJamTransformation(L) == TM_ForcedByUser) { + LLVM_DEBUG(dbgs() << "Leftover unroll-and-jam transformation\n"); + ORE->emit( + DiagnosticInfoOptimizationFailure(DEBUG_TYPE, + "FailedRequestedUnrollAndJamming", + L->getStartLoc(), L->getHeader()) + << "loop not unroll-and-jammed: the optimizer was unable to perform " + "the requested transformation; the transformation might be disabled " + "or specified as part of an unsupported transformation ordering"); + } + + if (hasVectorizeTransformation(L) == TM_ForcedByUser) { + LLVM_DEBUG(dbgs() << "Leftover vectorization transformation\n"); + Optional<int> VectorizeWidth = + getOptionalIntLoopAttribute(L, "llvm.loop.vectorize.width"); + Optional<int> InterleaveCount = + getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count"); + + if (VectorizeWidth.getValueOr(0) != 1) + ORE->emit( + DiagnosticInfoOptimizationFailure(DEBUG_TYPE, + "FailedRequestedVectorization", + L->getStartLoc(), L->getHeader()) + << "loop not vectorized: the optimizer was unable to perform the " + "requested transformation; the transformation might be disabled " + "or specified as part of an unsupported transformation ordering"); + else if (InterleaveCount.getValueOr(0) != 1) + ORE->emit( + DiagnosticInfoOptimizationFailure(DEBUG_TYPE, + "FailedRequestedInterleaving", + L->getStartLoc(), L->getHeader()) + << "loop not interleaved: the optimizer was unable to perform the " + "requested transformation; the transformation might be disabled " + "or specified as part of an unsupported transformation ordering"); + } + + if (hasDistributeTransformation(L) == TM_ForcedByUser) { + LLVM_DEBUG(dbgs() << "Leftover distribute transformation\n"); + ORE->emit( + DiagnosticInfoOptimizationFailure(DEBUG_TYPE, + "FailedRequestedDistribution", + L->getStartLoc(), L->getHeader()) + << "loop not distributed: the optimizer was unable to perform the " + "requested transformation; the transformation might be disabled or " + "specified as part of an unsupported transformation ordering"); + } +} + +static void warnAboutLeftoverTransformations(Function *F, LoopInfo *LI, + OptimizationRemarkEmitter *ORE) { + for (auto *L : LI->getLoopsInPreorder()) + warnAboutLeftoverTransformations(L, ORE); +} + +// New pass manager boilerplate +PreservedAnalyses +WarnMissedTransformationsPass::run(Function &F, FunctionAnalysisManager &AM) { + // Do not warn about not applied transformations if optimizations are + // disabled. 
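One non-obvious convention in the vectorization branch above: an explicit width of 1 means the user turned vectorization off, so any other value (including an absent width on a loop whose vectorization was forced) is treated as an unfulfilled request. A hedged sketch of reading that attribute with the same LoopUtils helper the file already uses (the wrapper name is invented):

    #include "llvm/ADT/Optional.h"
    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/Transforms/Utils/LoopUtils.h"
    using namespace llvm;

    // Mirrors the check above: anything except an explicit width of 1 counts as
    // a vectorization request that the remark should complain about.
    static bool widthStillRequestsVectorization(Loop *L) {
      Optional<int> Width =
          getOptionalIntLoopAttribute(L, "llvm.loop.vectorize.width");
      return Width.getValueOr(0) != 1;
    }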
+ if (F.hasFnAttribute(Attribute::OptimizeNone)) + return PreservedAnalyses::all(); + + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + + warnAboutLeftoverTransformations(&F, &LI, &ORE); + + return PreservedAnalyses::all(); +} + +// Legacy pass manager boilerplate +namespace { +class WarnMissedTransformationsLegacy : public FunctionPass { +public: + static char ID; + + explicit WarnMissedTransformationsLegacy() : FunctionPass(ID) { + initializeWarnMissedTransformationsLegacyPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + + warnAboutLeftoverTransformations(&F, &LI, &ORE); + return false; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + + AU.setPreservesAll(); + } +}; +} // end anonymous namespace + +char WarnMissedTransformationsLegacy::ID = 0; + +INITIALIZE_PASS_BEGIN(WarnMissedTransformationsLegacy, "transform-warning", + "Warn about non-applied transformations", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(WarnMissedTransformationsLegacy, "transform-warning", + "Warn about non-applied transformations", false, false) + +Pass *llvm::createWarnMissedTransformationsPass() { + return new WarnMissedTransformationsLegacy(); +} diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index e3ef42362223..564537af0c2a 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -209,10 +209,18 @@ static bool addDiscriminators(Function &F) { // Only the lowest 7 bits are used to represent a discriminator to fit // it in 1 byte ULEB128 representation. unsigned Discriminator = R.second ? ++LDM[L] : LDM[L]; - I.setDebugLoc(DIL->setBaseDiscriminator(Discriminator)); - LLVM_DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" - << DIL->getColumn() << ":" << Discriminator << " " << I - << "\n"); + auto NewDIL = DIL->setBaseDiscriminator(Discriminator); + if (!NewDIL) { + LLVM_DEBUG(dbgs() << "Could not encode discriminator: " + << DIL->getFilename() << ":" << DIL->getLine() << ":" + << DIL->getColumn() << ":" << Discriminator << " " + << I << "\n"); + } else { + I.setDebugLoc(NewDIL.getValue()); + LLVM_DEBUG(dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" + << DIL->getColumn() << ":" << Discriminator << " " << I + << "\n"); + } Changed = true; } } @@ -224,23 +232,31 @@ static bool addDiscriminators(Function &F) { for (BasicBlock &B : F) { LocationSet CallLocations; for (auto &I : B.getInstList()) { - CallInst *Current = dyn_cast<CallInst>(&I); // We bypass intrinsic calls for the following two reasons: // 1) We want to avoid a non-deterministic assigment of // discriminators. // 2) We want to minimize the number of base discriminators used. 
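For readers new to discriminators (example not from the patch): they let sample-based profilers tell apart distinct calls that share a file and line. The loop above deliberately skips intrinsics and, per the comment, keeps the number of base discriminators small; a case it does handle:

    int helper(int); // hypothetical callee

    // Both calls carry the same file:line in the debug info; AddDiscriminators
    // assigns the second call a different base discriminator so its profile
    // samples are attributed to the right call site.
    int callTwice(int X) { return helper(X) + helper(X + 1); }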
- if (!Current || isa<IntrinsicInst>(&I)) + if (!isa<InvokeInst>(I) && (!isa<CallInst>(I) || isa<IntrinsicInst>(I))) continue; - DILocation *CurrentDIL = Current->getDebugLoc(); + DILocation *CurrentDIL = I.getDebugLoc(); if (!CurrentDIL) continue; Location L = std::make_pair(CurrentDIL->getFilename(), CurrentDIL->getLine()); if (!CallLocations.insert(L).second) { unsigned Discriminator = ++LDM[L]; - Current->setDebugLoc(CurrentDIL->setBaseDiscriminator(Discriminator)); - Changed = true; + auto NewDIL = CurrentDIL->setBaseDiscriminator(Discriminator); + if (!NewDIL) { + LLVM_DEBUG(dbgs() + << "Could not encode discriminator: " + << CurrentDIL->getFilename() << ":" + << CurrentDIL->getLine() << ":" << CurrentDIL->getColumn() + << ":" << Discriminator << " " << I << "\n"); + } else { + I.setDebugLoc(NewDIL.getValue()); + Changed = true; + } } } } diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 516a785dce1e..7da768252fc1 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -20,11 +20,13 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" @@ -37,6 +39,7 @@ #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/Casting.h" +#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <string> @@ -45,42 +48,58 @@ using namespace llvm; -void llvm::DeleteDeadBlock(BasicBlock *BB, DeferredDominance *DDT) { - assert((pred_begin(BB) == pred_end(BB) || - // Can delete self loop. - BB->getSinglePredecessor() == BB) && "Block is not dead!"); - TerminatorInst *BBTerm = BB->getTerminator(); - std::vector<DominatorTree::UpdateType> Updates; +void llvm::DeleteDeadBlock(BasicBlock *BB, DomTreeUpdater *DTU) { + SmallVector<BasicBlock *, 1> BBs = {BB}; + DeleteDeadBlocks(BBs, DTU); +} - // Loop through all of our successors and make sure they know that one - // of their predecessors is going away. - if (DDT) - Updates.reserve(BBTerm->getNumSuccessors()); - for (BasicBlock *Succ : BBTerm->successors()) { - Succ->removePredecessor(BB); - if (DDT) - Updates.push_back({DominatorTree::Delete, BB, Succ}); - } +void llvm::DeleteDeadBlocks(SmallVectorImpl <BasicBlock *> &BBs, + DomTreeUpdater *DTU) { +#ifndef NDEBUG + // Make sure that all predecessors of each dead block is also dead. + SmallPtrSet<BasicBlock *, 4> Dead(BBs.begin(), BBs.end()); + assert(Dead.size() == BBs.size() && "Duplicating blocks?"); + for (auto *BB : Dead) + for (BasicBlock *Pred : predecessors(BB)) + assert(Dead.count(Pred) && "All predecessors must be dead!"); +#endif + + SmallVector<DominatorTree::UpdateType, 4> Updates; + for (auto *BB : BBs) { + // Loop through all of our successors and make sure they know that one + // of their predecessors is going away. + for (BasicBlock *Succ : successors(BB)) { + Succ->removePredecessor(BB); + if (DTU) + Updates.push_back({DominatorTree::Delete, BB, Succ}); + } - // Zap all the instructions in the block. 
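A hedged caller-side sketch of the batched DeleteDeadBlocks API introduced above (the wrapper and its names are invented); the precondition asserted in debug builds is that every predecessor of a block in the list is itself in the list:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/DomTreeUpdater.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"
    using namespace llvm;

    // Erase a whole unreachable region at once, batching the dominator updates.
    // Precondition: every predecessor of each block in DeadBBs is itself
    // contained in DeadBBs.
    static void eraseUnreachableRegion(SmallVectorImpl<BasicBlock *> &DeadBBs,
                                       DomTreeUpdater &DTU) {
      DeleteDeadBlocks(DeadBBs, &DTU);
    }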
- while (!BB->empty()) { - Instruction &I = BB->back(); - // If this instruction is used, replace uses with an arbitrary value. - // Because control flow can't get here, we don't care what we replace the - // value with. Note that since this block is unreachable, and all values - // contained within it must dominate their uses, that all uses will - // eventually be removed (they are themselves dead). - if (!I.use_empty()) - I.replaceAllUsesWith(UndefValue::get(I.getType())); - BB->getInstList().pop_back(); + // Zap all the instructions in the block. + while (!BB->empty()) { + Instruction &I = BB->back(); + // If this instruction is used, replace uses with an arbitrary value. + // Because control flow can't get here, we don't care what we replace the + // value with. Note that since this block is unreachable, and all values + // contained within it must dominate their uses, that all uses will + // eventually be removed (they are themselves dead). + if (!I.use_empty()) + I.replaceAllUsesWith(UndefValue::get(I.getType())); + BB->getInstList().pop_back(); + } + new UnreachableInst(BB->getContext(), BB); + assert(BB->getInstList().size() == 1 && + isa<UnreachableInst>(BB->getTerminator()) && + "The successor list of BB isn't empty before " + "applying corresponding DTU updates."); } + if (DTU) + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); - if (DDT) { - DDT->applyUpdates(Updates); - DDT->deleteBB(BB); // Deferred deletion of BB. - } else { - BB->eraseFromParent(); // Zap the block! - } + for (BasicBlock *BB : BBs) + if (DTU) + DTU->deleteBB(BB); + else + BB->eraseFromParent(); } void llvm::FoldSingleEntryPHINodes(BasicBlock *BB, @@ -115,12 +134,9 @@ bool llvm::DeleteDeadPHIs(BasicBlock *BB, const TargetLibraryInfo *TLI) { return Changed; } -bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, - LoopInfo *LI, - MemoryDependenceResults *MemDep, - DeferredDominance *DDT) { - assert(!(DT && DDT) && "Cannot call with both DT and DDT."); - +bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU, + LoopInfo *LI, MemorySSAUpdater *MSSAU, + MemoryDependenceResults *MemDep) { if (BB->hasAddressTaken()) return false; @@ -131,7 +147,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, // Don't break self-loops. if (PredBB == BB) return false; // Don't break unwinding instructions. - if (PredBB->getTerminator()->isExceptional()) + if (PredBB->getTerminator()->isExceptionalTerminator()) return false; // Can't merge if there are multiple distinct successors. @@ -154,10 +170,10 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, FoldSingleEntryPHINodes(BB, MemDep); } - // Deferred DT update: Collect all the edges that exit BB. These - // dominator edges will be redirected from Pred. + // DTU update: Collect all the edges that exit BB. + // These dominator edges will be redirected from Pred. std::vector<DominatorTree::UpdateType> Updates; - if (DDT) { + if (DTU) { Updates.reserve(1 + (2 * succ_size(BB))); Updates.push_back({DominatorTree::Delete, PredBB, BB}); for (auto I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { @@ -166,6 +182,9 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, } } + if (MSSAU) + MSSAU->moveAllAfterMergeBlocks(BB, PredBB, &*(BB->begin())); + // Delete the unconditional branch from the predecessor... 
PredBB->getInstList().pop_back(); @@ -175,6 +194,7 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, // Move all definitions in the successor to the predecessor... PredBB->getInstList().splice(PredBB->end(), BB->getInstList()); + new UnreachableInst(BB->getContext(), BB); // Eliminate duplicate dbg.values describing the entry PHI node post-splice. for (auto Incoming : IncomingValues) { @@ -195,28 +215,24 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, if (!PredBB->hasName()) PredBB->takeName(BB); - // Finally, erase the old block and update dominator info. - if (DT) - if (DomTreeNode *DTN = DT->getNode(BB)) { - DomTreeNode *PredDTN = DT->getNode(PredBB); - SmallVector<DomTreeNode *, 8> Children(DTN->begin(), DTN->end()); - for (DomTreeNode *DI : Children) - DT->changeImmediateDominator(DI, PredDTN); - - DT->eraseNode(BB); - } - if (LI) LI->removeBlock(BB); if (MemDep) MemDep->invalidateCachedPredecessors(); - if (DDT) { - DDT->deleteBB(BB); // Deferred deletion of BB. - DDT->applyUpdates(Updates); - } else { - BB->eraseFromParent(); // Nuke BB. + // Finally, erase the old block and update dominator info. + if (DTU) { + assert(BB->getInstList().size() == 1 && + isa<UnreachableInst>(BB->getTerminator()) && + "The successor list of BB isn't empty before " + "applying corresponding DTU updates."); + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->deleteBB(BB); + } + + else { + BB->eraseFromParent(); // Nuke BB if DTU is nullptr. } return true; } @@ -261,13 +277,14 @@ void llvm::ReplaceInstWithInst(Instruction *From, Instruction *To) { } BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, DominatorTree *DT, - LoopInfo *LI) { + LoopInfo *LI, MemorySSAUpdater *MSSAU) { unsigned SuccNum = GetSuccessorNumber(BB, Succ); // If this is a critical edge, let SplitCriticalEdge do it. - TerminatorInst *LatchTerm = BB->getTerminator(); - if (SplitCriticalEdge(LatchTerm, SuccNum, CriticalEdgeSplittingOptions(DT, LI) - .setPreserveLCSSA())) + Instruction *LatchTerm = BB->getTerminator(); + if (SplitCriticalEdge( + LatchTerm, SuccNum, + CriticalEdgeSplittingOptions(DT, LI, MSSAU).setPreserveLCSSA())) return LatchTerm->getSuccessor(SuccNum); // If the edge isn't critical, then BB has a single successor or Succ has a @@ -277,14 +294,14 @@ BasicBlock *llvm::SplitEdge(BasicBlock *BB, BasicBlock *Succ, DominatorTree *DT, // block. assert(SP == BB && "CFG broken"); SP = nullptr; - return SplitBlock(Succ, &Succ->front(), DT, LI); + return SplitBlock(Succ, &Succ->front(), DT, LI, MSSAU); } // Otherwise, if BB has a single successor, split it at the bottom of the // block. 
assert(BB->getTerminator()->getNumSuccessors() == 1 && "Should have a single succ!"); - return SplitBlock(BB, BB->getTerminator(), DT, LI); + return SplitBlock(BB, BB->getTerminator(), DT, LI, MSSAU); } unsigned @@ -292,7 +309,7 @@ llvm::SplitAllCriticalEdges(Function &F, const CriticalEdgeSplittingOptions &Options) { unsigned NumBroken = 0; for (BasicBlock &BB : F) { - TerminatorInst *TI = BB.getTerminator(); + Instruction *TI = BB.getTerminator(); if (TI->getNumSuccessors() > 1 && !isa<IndirectBrInst>(TI)) for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) if (SplitCriticalEdge(TI, i, Options)) @@ -302,7 +319,8 @@ llvm::SplitAllCriticalEdges(Function &F, } BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, - DominatorTree *DT, LoopInfo *LI) { + DominatorTree *DT, LoopInfo *LI, + MemorySSAUpdater *MSSAU) { BasicBlock::iterator SplitIt = SplitPt->getIterator(); while (isa<PHINode>(SplitIt) || SplitIt->isEHPad()) ++SplitIt; @@ -324,6 +342,11 @@ BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, DT->changeImmediateDominator(I, NewNode); } + // Move MemoryAccesses still tracked in Old, but part of New now. + // Update accesses in successor blocks accordingly. + if (MSSAU) + MSSAU->moveAllAfterSpliceBlocks(Old, New, &*(New->begin())); + return New; } @@ -331,6 +354,7 @@ BasicBlock *llvm::SplitBlock(BasicBlock *Old, Instruction *SplitPt, static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, ArrayRef<BasicBlock *> Preds, DominatorTree *DT, LoopInfo *LI, + MemorySSAUpdater *MSSAU, bool PreserveLCSSA, bool &HasLoopExit) { // Update dominator tree if available. if (DT) { @@ -343,6 +367,10 @@ static void UpdateAnalysisInformation(BasicBlock *OldBB, BasicBlock *NewBB, } } + // Update MemoryPhis after split if MemorySSA is available + if (MSSAU) + MSSAU->wireOldPredecessorsToNewImmediatePredecessor(OldBB, NewBB, Preds); + // The rest of the logic is only relevant for updating the loop structures. if (!LI) return; @@ -483,7 +511,8 @@ static void UpdatePHINodes(BasicBlock *OrigBB, BasicBlock *NewBB, BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, ArrayRef<BasicBlock *> Preds, const char *Suffix, DominatorTree *DT, - LoopInfo *LI, bool PreserveLCSSA) { + LoopInfo *LI, MemorySSAUpdater *MSSAU, + bool PreserveLCSSA) { // Do not attempt to split that which cannot be split. if (!BB->canSplitPredecessors()) return nullptr; @@ -495,7 +524,7 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, std::string NewName = std::string(Suffix) + ".split-lp"; SplitLandingPadPredecessors(BB, Preds, Suffix, NewName.c_str(), NewBBs, DT, - LI, PreserveLCSSA); + LI, MSSAU, PreserveLCSSA); return NewBBs[0]; } @@ -529,7 +558,7 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, // Update DominatorTree, LoopInfo, and LCCSA analysis information. 
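A hedged caller-side sketch of the widened SplitBlock signature above (the wrapper is invented): a pass that maintains MemorySSA wraps it in a MemorySSAUpdater and passes that through, and the utility then moves the affected MemoryAccesses itself via moveAllAfterSpliceBlocks:

    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/Analysis/MemorySSA.h"
    #include "llvm/Analysis/MemorySSAUpdater.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/Transforms/Utils/BasicBlockUtils.h"
    using namespace llvm;

    static BasicBlock *splitPreservingMemorySSA(BasicBlock *BB, Instruction *SplitPt,
                                                DominatorTree *DT, LoopInfo *LI,
                                                MemorySSA *MSSA) {
      MemorySSAUpdater MSSAU(MSSA);
      // The new parameter lets SplitBlock update MemorySSA for the accesses that
      // end up in the freshly created block, instead of leaving that to callers.
      return SplitBlock(BB, SplitPt, DT, LI, &MSSAU);
    }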
bool HasLoopExit = false; - UpdateAnalysisInformation(BB, NewBB, Preds, DT, LI, PreserveLCSSA, + UpdateAnalysisInformation(BB, NewBB, Preds, DT, LI, MSSAU, PreserveLCSSA, HasLoopExit); if (!Preds.empty()) { @@ -545,6 +574,7 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, const char *Suffix1, const char *Suffix2, SmallVectorImpl<BasicBlock *> &NewBBs, DominatorTree *DT, LoopInfo *LI, + MemorySSAUpdater *MSSAU, bool PreserveLCSSA) { assert(OrigBB->isLandingPad() && "Trying to split a non-landing pad!"); @@ -570,7 +600,7 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, } bool HasLoopExit = false; - UpdateAnalysisInformation(OrigBB, NewBB1, Preds, DT, LI, PreserveLCSSA, + UpdateAnalysisInformation(OrigBB, NewBB1, Preds, DT, LI, MSSAU, PreserveLCSSA, HasLoopExit); // Update the PHI nodes in OrigBB with the values coming from NewBB1. @@ -606,7 +636,7 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, // Update DominatorTree, LoopInfo, and LCCSA analysis information. HasLoopExit = false; - UpdateAnalysisInformation(OrigBB, NewBB2, NewBB2Preds, DT, LI, + UpdateAnalysisInformation(OrigBB, NewBB2, NewBB2Preds, DT, LI, MSSAU, PreserveLCSSA, HasLoopExit); // Update the PHI nodes in OrigBB with the values coming from NewBB2. @@ -644,7 +674,8 @@ void llvm::SplitLandingPadPredecessors(BasicBlock *OrigBB, } ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, - BasicBlock *Pred) { + BasicBlock *Pred, + DomTreeUpdater *DTU) { Instruction *UncondBranch = Pred->getTerminator(); // Clone the return and add it to the end of the predecessor. Instruction *NewRet = RI->clone(); @@ -678,19 +709,24 @@ ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB, // longer branch to them. BB->removePredecessor(Pred); UncondBranch->eraseFromParent(); + + if (DTU) + DTU->deleteEdge(Pred, BB); + return cast<ReturnInst>(NewRet); } -TerminatorInst * -llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, - bool Unreachable, MDNode *BranchWeights, - DominatorTree *DT, LoopInfo *LI) { +Instruction *llvm::SplitBlockAndInsertIfThen(Value *Cond, + Instruction *SplitBefore, + bool Unreachable, + MDNode *BranchWeights, + DominatorTree *DT, LoopInfo *LI) { BasicBlock *Head = SplitBefore->getParent(); BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); - TerminatorInst *HeadOldTerm = Head->getTerminator(); + Instruction *HeadOldTerm = Head->getTerminator(); LLVMContext &C = Head->getContext(); BasicBlock *ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); - TerminatorInst *CheckTerm; + Instruction *CheckTerm; if (Unreachable) CheckTerm = new UnreachableInst(C, ThenBlock); else @@ -725,12 +761,12 @@ llvm::SplitBlockAndInsertIfThen(Value *Cond, Instruction *SplitBefore, } void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, - TerminatorInst **ThenTerm, - TerminatorInst **ElseTerm, + Instruction **ThenTerm, + Instruction **ElseTerm, MDNode *BranchWeights) { BasicBlock *Head = SplitBefore->getParent(); BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator()); - TerminatorInst *HeadOldTerm = Head->getTerminator(); + Instruction *HeadOldTerm = Head->getTerminator(); LLVMContext &C = Head->getContext(); BasicBlock *ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); BasicBlock *ElseBlock = BasicBlock::Create(C, "", Head->getParent(), Tail); diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 
3e30c27a9f33..fafc9aaba5c9 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -129,7 +130,7 @@ static void createPHIsForSplitLoopExit(ArrayRef<BasicBlock *> Preds, } BasicBlock * -llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, +llvm::SplitCriticalEdge(Instruction *TI, unsigned SuccNum, const CriticalEdgeSplittingOptions &Options) { if (!isCriticalEdge(TI, SuccNum, Options.MergeIdenticalEdges)) return nullptr; @@ -198,6 +199,11 @@ llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, // If we have nothing to update, just return. auto *DT = Options.DT; auto *LI = Options.LI; + auto *MSSAU = Options.MSSAU; + if (MSSAU) + MSSAU->wireOldPredecessorsToNewImmediatePredecessor( + DestBB, NewBB, {TIBB}, Options.MergeIdenticalEdges); + if (!DT && !LI) return NewBB; @@ -283,7 +289,7 @@ llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, if (!LoopPreds.empty()) { assert(!DestBB->isEHPad() && "We don't split edges to EH pads!"); BasicBlock *NewExitBB = SplitBlockPredecessors( - DestBB, LoopPreds, "split", DT, LI, Options.PreserveLCSSA); + DestBB, LoopPreds, "split", DT, LI, MSSAU, Options.PreserveLCSSA); if (Options.PreserveLCSSA) createPHIsForSplitLoopExit(LoopPreds, NewExitBB, DestBB); } @@ -312,7 +318,7 @@ findIBRPredecessor(BasicBlock *BB, SmallVectorImpl<BasicBlock *> &OtherPreds) { BasicBlock *IBB = nullptr; for (unsigned Pred = 0, E = PN->getNumIncomingValues(); Pred != E; ++Pred) { BasicBlock *PredBB = PN->getIncomingBlock(Pred); - TerminatorInst *PredTerm = PredBB->getTerminator(); + Instruction *PredTerm = PredBB->getTerminator(); switch (PredTerm->getOpcode()) { case Instruction::IndirectBr: if (IBB) diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp index d0396e6ce47d..3466dedd3236 100644 --- a/lib/Transforms/Utils/BuildLibCalls.cpp +++ b/lib/Transforms/Utils/BuildLibCalls.cpp @@ -38,6 +38,7 @@ STATISTIC(NumNoCapture, "Number of arguments inferred as nocapture"); STATISTIC(NumReadOnlyArg, "Number of arguments inferred as readonly"); STATISTIC(NumNoAlias, "Number of function returns inferred as noalias"); STATISTIC(NumNonNull, "Number of function returns inferred as nonnull returns"); +STATISTIC(NumReturnedArg, "Number of arguments inferred as returned"); static bool setDoesNotAccessMemory(Function &F) { if (F.doesNotAccessMemory()) @@ -105,6 +106,14 @@ static bool setRetNonNull(Function &F) { return true; } +static bool setReturnedArg(Function &F, unsigned ArgNo) { + if (F.hasParamAttribute(ArgNo, Attribute::Returned)) + return false; + F.addParamAttr(ArgNo, Attribute::Returned); + ++NumReturnedArg; + return true; +} + static bool setNonLazyBind(Function &F) { if (F.hasFnAttribute(Attribute::NonLazyBind)) return false; @@ -112,6 +121,14 @@ static bool setNonLazyBind(Function &F) { return true; } +bool llvm::inferLibFuncAttributes(Module *M, StringRef Name, + const TargetLibraryInfo &TLI) { + Function *F = M->getFunction(Name); + if (!F) + return false; + return inferLibFuncAttributes(*F, TLI); +} + bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { LibFunc TheLibFunc; if (!(TLI.getLibFunc(F, TheLibFunc) && TLI.has(TheLibFunc))) @@ -147,10 +164,12 @@ bool 
llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setOnlyReadsMemory(F, 0); return Changed; case LibFunc_strcpy: - case LibFunc_stpcpy: + case LibFunc_strncpy: case LibFunc_strcat: case LibFunc_strncat: - case LibFunc_strncpy: + Changed |= setReturnedArg(F, 0); + LLVM_FALLTHROUGH; + case LibFunc_stpcpy: case LibFunc_stpncpy: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); @@ -262,9 +281,11 @@ bool llvm::inferLibFuncAttributes(Function &F, const TargetLibraryInfo &TLI) { Changed |= setDoesNotCapture(F, 1); return Changed; case LibFunc_memcpy: + case LibFunc_memmove: + Changed |= setReturnedArg(F, 0); + LLVM_FALLTHROUGH; case LibFunc_mempcpy: case LibFunc_memccpy: - case LibFunc_memmove: Changed |= setDoesNotThrow(F); Changed |= setDoesNotCapture(F, 1); Changed |= setOnlyReadsMemory(F, 1); @@ -733,6 +754,8 @@ bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, LibFunc DoubleFn, LibFunc FloatFn, LibFunc LongDoubleFn) { switch (Ty->getTypeID()) { + case Type::HalfTyID: + return false; case Type::FloatTyID: return TLI->has(FloatFn); case Type::DoubleTyID: @@ -742,6 +765,24 @@ bool llvm::hasUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, } } +StringRef llvm::getUnaryFloatFn(const TargetLibraryInfo *TLI, Type *Ty, + LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn) { + assert(hasUnaryFloatFn(TLI, Ty, DoubleFn, FloatFn, LongDoubleFn) && + "Cannot get name for unavailable function!"); + + switch (Ty->getTypeID()) { + case Type::HalfTyID: + llvm_unreachable("No name for HalfTy!"); + case Type::FloatTyID: + return TLI->getName(FloatFn); + case Type::DoubleTyID: + return TLI->getName(DoubleFn); + default: + return TLI->getName(LongDoubleFn); + } +} + //- Emit LibCalls ------------------------------------------------------------// Value *llvm::castToCStr(Value *V, IRBuilder<> &B) { @@ -755,11 +796,12 @@ Value *llvm::emitStrLen(Value *Ptr, IRBuilder<> &B, const DataLayout &DL, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef StrlenName = TLI->getName(LibFunc_strlen); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Constant *StrLen = M->getOrInsertFunction("strlen", DL.getIntPtrType(Context), + Constant *StrLen = M->getOrInsertFunction(StrlenName, DL.getIntPtrType(Context), B.getInt8PtrTy()); - inferLibFuncAttributes(*M->getFunction("strlen"), *TLI); - CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), "strlen"); + inferLibFuncAttributes(M, StrlenName, *TLI); + CallInst *CI = B.CreateCall(StrLen, castToCStr(Ptr, B), StrlenName); if (const Function *F = dyn_cast<Function>(StrLen->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -772,13 +814,14 @@ Value *llvm::emitStrChr(Value *Ptr, char C, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef StrChrName = TLI->getName(LibFunc_strchr); Type *I8Ptr = B.getInt8PtrTy(); Type *I32Ty = B.getInt32Ty(); Constant *StrChr = - M->getOrInsertFunction("strchr", I8Ptr, I8Ptr, I32Ty); - inferLibFuncAttributes(*M->getFunction("strchr"), *TLI); + M->getOrInsertFunction(StrChrName, I8Ptr, I8Ptr, I32Ty); + inferLibFuncAttributes(M, StrChrName, *TLI); CallInst *CI = B.CreateCall( - StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, "strchr"); + StrChr, {castToCStr(Ptr, B), ConstantInt::get(I32Ty, C)}, StrChrName); if (const Function *F = dyn_cast<Function>(StrChr->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -790,13 +833,14 @@ Value *llvm::emitStrNCmp(Value *Ptr1, 
Value *Ptr2, Value *Len, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef StrNCmpName = TLI->getName(LibFunc_strncmp); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *StrNCmp = M->getOrInsertFunction("strncmp", B.getInt32Ty(), + Value *StrNCmp = M->getOrInsertFunction(StrNCmpName, B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)); - inferLibFuncAttributes(*M->getFunction("strncmp"), *TLI); + inferLibFuncAttributes(M, StrNCmpName, *TLI); CallInst *CI = B.CreateCall( - StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "strncmp"); + StrNCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, StrNCmpName); if (const Function *F = dyn_cast<Function>(StrNCmp->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -812,7 +856,7 @@ Value *llvm::emitStrCpy(Value *Dst, Value *Src, IRBuilder<> &B, Module *M = B.GetInsertBlock()->getModule(); Type *I8Ptr = B.getInt8PtrTy(); Value *StrCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr); - inferLibFuncAttributes(*M->getFunction(Name), *TLI); + inferLibFuncAttributes(M, Name, *TLI); CallInst *CI = B.CreateCall(StrCpy, {castToCStr(Dst, B), castToCStr(Src, B)}, Name); if (const Function *F = dyn_cast<Function>(StrCpy->stripPointerCasts())) @@ -829,9 +873,9 @@ Value *llvm::emitStrNCpy(Value *Dst, Value *Src, Value *Len, IRBuilder<> &B, Type *I8Ptr = B.getInt8PtrTy(); Value *StrNCpy = M->getOrInsertFunction(Name, I8Ptr, I8Ptr, I8Ptr, Len->getType()); - inferLibFuncAttributes(*M->getFunction(Name), *TLI); + inferLibFuncAttributes(M, Name, *TLI); CallInst *CI = B.CreateCall( - StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, "strncpy"); + StrNCpy, {castToCStr(Dst, B), castToCStr(Src, B), Len}, Name); if (const Function *F = dyn_cast<Function>(StrNCpy->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -866,12 +910,13 @@ Value *llvm::emitMemChr(Value *Ptr, Value *Val, Value *Len, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef MemChrName = TLI->getName(LibFunc_memchr); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemChr = M->getOrInsertFunction("memchr", B.getInt8PtrTy(), + Value *MemChr = M->getOrInsertFunction(MemChrName, B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt32Ty(), DL.getIntPtrType(Context)); - inferLibFuncAttributes(*M->getFunction("memchr"), *TLI); - CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, "memchr"); + inferLibFuncAttributes(M, MemChrName, *TLI); + CallInst *CI = B.CreateCall(MemChr, {castToCStr(Ptr, B), Val, Len}, MemChrName); if (const Function *F = dyn_cast<Function>(MemChr->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -885,13 +930,14 @@ Value *llvm::emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef MemCmpName = TLI->getName(LibFunc_memcmp); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *MemCmp = M->getOrInsertFunction("memcmp", B.getInt32Ty(), + Value *MemCmp = M->getOrInsertFunction(MemCmpName, B.getInt32Ty(), B.getInt8PtrTy(), B.getInt8PtrTy(), DL.getIntPtrType(Context)); - inferLibFuncAttributes(*M->getFunction("memcmp"), *TLI); + inferLibFuncAttributes(M, MemCmpName, *TLI); CallInst *CI = B.CreateCall( - MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, "memcmp"); + MemCmp, {castToCStr(Ptr1, B), castToCStr(Ptr2, B), Len}, MemCmpName); if (const Function *F = 
dyn_cast<Function>(MemCmp->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -914,10 +960,10 @@ static void appendTypeSuffix(Value *Op, StringRef &Name, } } -Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, - const AttributeList &Attrs) { - SmallString<20> NameBuffer; - appendTypeSuffix(Op, Name, NameBuffer); +static Value *emitUnaryFloatFnCallHelper(Value *Op, StringRef Name, + IRBuilder<> &B, + const AttributeList &Attrs) { + assert((Name != "") && "Must specify Name to emitUnaryFloatFnCall"); Module *M = B.GetInsertBlock()->getModule(); Value *Callee = M->getOrInsertFunction(Name, Op->getType(), @@ -936,8 +982,29 @@ Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, return CI; } +Value *llvm::emitUnaryFloatFnCall(Value *Op, StringRef Name, IRBuilder<> &B, + const AttributeList &Attrs) { + SmallString<20> NameBuffer; + appendTypeSuffix(Op, Name, NameBuffer); + + return emitUnaryFloatFnCallHelper(Op, Name, B, Attrs); +} + +Value *llvm::emitUnaryFloatFnCall(Value *Op, const TargetLibraryInfo *TLI, + LibFunc DoubleFn, LibFunc FloatFn, + LibFunc LongDoubleFn, IRBuilder<> &B, + const AttributeList &Attrs) { + // Get the name of the function according to TLI. + StringRef Name = getUnaryFloatFn(TLI, Op->getType(), + DoubleFn, FloatFn, LongDoubleFn); + + return emitUnaryFloatFnCallHelper(Op, Name, B, Attrs); +} + Value *llvm::emitBinaryFloatFnCall(Value *Op1, Value *Op2, StringRef Name, IRBuilder<> &B, const AttributeList &Attrs) { + assert((Name != "") && "Must specify Name to emitBinaryFloatFnCall"); + SmallString<20> NameBuffer; appendTypeSuffix(Op1, Name, NameBuffer); @@ -958,14 +1025,15 @@ Value *llvm::emitPutChar(Value *Char, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); - Value *PutChar = M->getOrInsertFunction("putchar", B.getInt32Ty(), B.getInt32Ty()); - inferLibFuncAttributes(*M->getFunction("putchar"), *TLI); + StringRef PutCharName = TLI->getName(LibFunc_putchar); + Value *PutChar = M->getOrInsertFunction(PutCharName, B.getInt32Ty(), B.getInt32Ty()); + inferLibFuncAttributes(M, PutCharName, *TLI); CallInst *CI = B.CreateCall(PutChar, B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"), - "putchar"); + PutCharName); if (const Function *F = dyn_cast<Function>(PutChar->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -978,10 +1046,11 @@ Value *llvm::emitPutS(Value *Str, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef PutsName = TLI->getName(LibFunc_puts); Value *PutS = - M->getOrInsertFunction("puts", B.getInt32Ty(), B.getInt8PtrTy()); - inferLibFuncAttributes(*M->getFunction("puts"), *TLI); - CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), "puts"); + M->getOrInsertFunction(PutsName, B.getInt32Ty(), B.getInt8PtrTy()); + inferLibFuncAttributes(M, PutsName, *TLI); + CallInst *CI = B.CreateCall(PutS, castToCStr(Str, B), PutsName); if (const Function *F = dyn_cast<Function>(PutS->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); return CI; @@ -993,13 +1062,14 @@ Value *llvm::emitFPutC(Value *Char, Value *File, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); - Constant *F = M->getOrInsertFunction("fputc", B.getInt32Ty(), B.getInt32Ty(), + StringRef FPutcName = TLI->getName(LibFunc_fputc); + Constant *F = M->getOrInsertFunction(FPutcName, B.getInt32Ty(), B.getInt32Ty(), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction("fputc"), 
*TLI); + inferLibFuncAttributes(M, FPutcName, *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/true, "chari"); - CallInst *CI = B.CreateCall(F, {Char, File}, "fputc"); + CallInst *CI = B.CreateCall(F, {Char, File}, FPutcName); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); @@ -1012,12 +1082,13 @@ Value *llvm::emitFPutCUnlocked(Value *Char, Value *File, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); - Constant *F = M->getOrInsertFunction("fputc_unlocked", B.getInt32Ty(), + StringRef FPutcUnlockedName = TLI->getName(LibFunc_fputc_unlocked); + Constant *F = M->getOrInsertFunction(FPutcUnlockedName, B.getInt32Ty(), B.getInt32Ty(), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction("fputc_unlocked"), *TLI); + inferLibFuncAttributes(M, FPutcUnlockedName, *TLI); Char = B.CreateIntCast(Char, B.getInt32Ty(), /*isSigned*/ true, "chari"); - CallInst *CI = B.CreateCall(F, {Char, File}, "fputc_unlocked"); + CallInst *CI = B.CreateCall(F, {Char, File}, FPutcUnlockedName); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); @@ -1034,8 +1105,8 @@ Value *llvm::emitFPutS(Value *Str, Value *File, IRBuilder<> &B, Constant *F = M->getOrInsertFunction( FPutsName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction(FPutsName), *TLI); - CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs"); + inferLibFuncAttributes(M, FPutsName, *TLI); + CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsName); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); @@ -1052,8 +1123,8 @@ Value *llvm::emitFPutSUnlocked(Value *Str, Value *File, IRBuilder<> &B, Constant *F = M->getOrInsertFunction(FPutsUnlockedName, B.getInt32Ty(), B.getInt8PtrTy(), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction(FPutsUnlockedName), *TLI); - CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, "fputs_unlocked"); + inferLibFuncAttributes(M, FPutsUnlockedName, *TLI); + CallInst *CI = B.CreateCall(F, {castToCStr(Str, B), File}, FPutsUnlockedName); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); @@ -1073,7 +1144,7 @@ Value *llvm::emitFWrite(Value *Ptr, Value *Size, Value *File, IRBuilder<> &B, DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction(FWriteName), *TLI); + inferLibFuncAttributes(M, FWriteName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, ConstantInt::get(DL.getIntPtrType(Context), 1), File}); @@ -1089,11 +1160,12 @@ Value *llvm::emitMalloc(Value *Num, IRBuilder<> &B, const DataLayout &DL, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef MallocName = TLI->getName(LibFunc_malloc); LLVMContext &Context = B.GetInsertBlock()->getContext(); - Value *Malloc = M->getOrInsertFunction("malloc", B.getInt8PtrTy(), + Value *Malloc = M->getOrInsertFunction(MallocName, B.getInt8PtrTy(), DL.getIntPtrType(Context)); - inferLibFuncAttributes(*M->getFunction("malloc"), *TLI); - CallInst *CI = B.CreateCall(Malloc, Num, "malloc"); + inferLibFuncAttributes(M, MallocName, *TLI); + CallInst *CI = B.CreateCall(Malloc, Num, MallocName); if (const 
Function *F = dyn_cast<Function>(Malloc->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1107,12 +1179,13 @@ Value *llvm::emitCalloc(Value *Num, Value *Size, const AttributeList &Attrs, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef CallocName = TLI.getName(LibFunc_calloc); const DataLayout &DL = M->getDataLayout(); IntegerType *PtrType = DL.getIntPtrType((B.GetInsertBlock()->getContext())); - Value *Calloc = M->getOrInsertFunction("calloc", Attrs, B.getInt8PtrTy(), + Value *Calloc = M->getOrInsertFunction(CallocName, Attrs, B.getInt8PtrTy(), PtrType, PtrType); - inferLibFuncAttributes(*M->getFunction("calloc"), TLI); - CallInst *CI = B.CreateCall(Calloc, {Num, Size}, "calloc"); + inferLibFuncAttributes(M, CallocName, TLI); + CallInst *CI = B.CreateCall(Calloc, {Num, Size}, CallocName); if (const auto *F = dyn_cast<Function>(Calloc->stripPointerCasts())) CI->setCallingConv(F->getCallingConv()); @@ -1134,7 +1207,7 @@ Value *llvm::emitFWriteUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction(FWriteUnlockedName), *TLI); + inferLibFuncAttributes(M, FWriteUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, N, File}); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) @@ -1148,11 +1221,12 @@ Value *llvm::emitFGetCUnlocked(Value *File, IRBuilder<> &B, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef FGetCUnlockedName = TLI->getName(LibFunc_fgetc_unlocked); Constant *F = - M->getOrInsertFunction("fgetc_unlocked", B.getInt32Ty(), File->getType()); + M->getOrInsertFunction(FGetCUnlockedName, B.getInt32Ty(), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction("fgetc_unlocked"), *TLI); - CallInst *CI = B.CreateCall(F, File, "fgetc_unlocked"); + inferLibFuncAttributes(M, FGetCUnlockedName, *TLI); + CallInst *CI = B.CreateCall(F, File, FGetCUnlockedName); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); @@ -1165,12 +1239,13 @@ Value *llvm::emitFGetSUnlocked(Value *Str, Value *Size, Value *File, return nullptr; Module *M = B.GetInsertBlock()->getModule(); + StringRef FGetSUnlockedName = TLI->getName(LibFunc_fgets_unlocked); Constant *F = - M->getOrInsertFunction("fgets_unlocked", B.getInt8PtrTy(), + M->getOrInsertFunction(FGetSUnlockedName, B.getInt8PtrTy(), B.getInt8PtrTy(), B.getInt32Ty(), File->getType()); - inferLibFuncAttributes(*M->getFunction("fgets_unlocked"), *TLI); + inferLibFuncAttributes(M, FGetSUnlockedName, *TLI); CallInst *CI = - B.CreateCall(F, {castToCStr(Str, B), Size, File}, "fgets_unlocked"); + B.CreateCall(F, {castToCStr(Str, B), Size, File}, FGetSUnlockedName); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) CI->setCallingConv(Fn->getCallingConv()); @@ -1191,7 +1266,7 @@ Value *llvm::emitFReadUnlocked(Value *Ptr, Value *Size, Value *N, Value *File, DL.getIntPtrType(Context), DL.getIntPtrType(Context), File->getType()); if (File->getType()->isPointerTy()) - inferLibFuncAttributes(*M->getFunction(FReadUnlockedName), *TLI); + inferLibFuncAttributes(M, FReadUnlockedName, *TLI); CallInst *CI = B.CreateCall(F, {castToCStr(Ptr, B), Size, N, File}); if (const Function *Fn = dyn_cast<Function>(F->stripPointerCasts())) diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp 
b/lib/Transforms/Utils/BypassSlowDivision.cpp index 05512a6dff3e..e7828af648a9 100644 --- a/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -388,6 +388,15 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { return None; } + // After Constant Hoisting pass, long constants may be represented as + // bitcast instructions. As a result, some constants may look like an + // instruction at first, and an additional check is necessary to find out if + // an operand is actually a constant. + if (auto *BCI = dyn_cast<BitCastInst>(Divisor)) + if (BCI->getParent() == SlowDivOrRem->getParent() && + isa<ConstantInt>(BCI->getOperand(0))) + return None; + if (DividendShort && !isSignedOp()) { // If the division is unsigned and Dividend is known to be short, then // either diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index c87b74f739f4..cb3dc17c03ad 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -6,6 +6,7 @@ add_llvm_library(LLVMTransformUtils BuildLibCalls.cpp BypassSlowDivision.cpp CallPromotionUtils.cpp + CanonicalizeAliases.cpp CloneFunction.cpp CloneModule.cpp CodeExtractor.cpp @@ -18,6 +19,7 @@ add_llvm_library(LLVMTransformUtils FunctionComparator.cpp FunctionImportUtils.cpp GlobalStatus.cpp + GuardUtils.cpp InlineFunction.cpp ImportedFunctionsInliningStatistics.cpp InstructionNamer.cpp @@ -40,7 +42,6 @@ add_llvm_library(LLVMTransformUtils MetaRenamer.cpp ModuleUtils.cpp NameAnonGlobals.cpp - OrderedInstructions.cpp PredicateInfo.cpp PromoteMemoryToRegister.cpp StripGCRelocates.cpp diff --git a/lib/Transforms/Utils/CallPromotionUtils.cpp b/lib/Transforms/Utils/CallPromotionUtils.cpp index 6d18d0614611..e58ddcf34667 100644 --- a/lib/Transforms/Utils/CallPromotionUtils.cpp +++ b/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -177,8 +177,8 @@ static void createRetBitCast(CallSite CS, Type *RetTy, CastInst **RetBitCast) { InsertBefore = &*std::next(CS.getInstruction()->getIterator()); // Bitcast the return value to the correct type. - auto *Cast = CastInst::Create(Instruction::BitCast, CS.getInstruction(), - RetTy, "", InsertBefore); + auto *Cast = CastInst::CreateBitOrPointerCast(CS.getInstruction(), RetTy, "", + InsertBefore); if (RetBitCast) *RetBitCast = Cast; @@ -270,8 +270,8 @@ static Instruction *versionCallSite(CallSite CS, Value *Callee, // Create an if-then-else structure. The original instruction is moved into // the "else" block, and a clone of the original instruction is placed in the // "then" block. - TerminatorInst *ThenTerm = nullptr; - TerminatorInst *ElseTerm = nullptr; + Instruction *ThenTerm = nullptr; + Instruction *ElseTerm = nullptr; SplitBlockAndInsertIfThenElse(Cond, CS.getInstruction(), &ThenTerm, &ElseTerm, BranchWeights); BasicBlock *ThenBlock = ThenTerm->getParent(); @@ -321,12 +321,14 @@ bool llvm::isLegalToPromote(CallSite CS, Function *Callee, const char **FailureReason) { assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted"); + auto &DL = Callee->getParent()->getDataLayout(); + // Check the return type. The callee's return value type must be bitcast // compatible with the call site's type. 
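For reference, the public entry points whose internals the next hunks relax are isLegalToPromote and promoteCall. A minimal caller-side sketch, assuming only the declarations in llvm/Transforms/Utils/CallPromotionUtils.h (illustrative; not part of this patch):

#include "llvm/IR/CallSite.h"
#include "llvm/IR/Function.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
using namespace llvm;

// Attempt to turn an indirect call site into a direct call to DirectCallee.
static bool tryPromote(CallSite CS, Function *DirectCallee) {
  const char *Reason = nullptr;
  if (!isLegalToPromote(CS, DirectCallee, &Reason))
    return false; // e.g. "Return type mismatch" or "Argument type mismatch"
  promoteCall(CS, DirectCallee);
  return true;
}

With this change, legality is decided by CastInst::isBitOrNoopPointerCastable rather than isBitCastable, so some pointer/integer type mismatches that were previously rejected are now accepted. In-tree users typically go through promoteCallWithIfThenElse instead, which also versions the call site with a runtime callee check.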
Type *CallRetTy = CS.getInstruction()->getType(); Type *FuncRetTy = Callee->getReturnType(); if (CallRetTy != FuncRetTy) - if (!CastInst::isBitCastable(FuncRetTy, CallRetTy)) { + if (!CastInst::isBitOrNoopPointerCastable(FuncRetTy, CallRetTy, DL)) { if (FailureReason) *FailureReason = "Return type mismatch"; return false; @@ -351,7 +353,7 @@ bool llvm::isLegalToPromote(CallSite CS, Function *Callee, Type *ActualTy = CS.getArgument(I)->getType(); if (FormalTy == ActualTy) continue; - if (!CastInst::isBitCastable(ActualTy, FormalTy)) { + if (!CastInst::isBitOrNoopPointerCastable(ActualTy, FormalTy, DL)) { if (FailureReason) *FailureReason = "Argument type mismatch"; return false; @@ -391,21 +393,46 @@ Instruction *llvm::promoteCall(CallSite CS, Function *Callee, // to the correct type. auto CalleeType = Callee->getFunctionType(); auto CalleeParamNum = CalleeType->getNumParams(); + + LLVMContext &Ctx = Callee->getContext(); + const AttributeList &CallerPAL = CS.getAttributes(); + // The new list of argument attributes. + SmallVector<AttributeSet, 4> NewArgAttrs; + bool AttributeChanged = false; + for (unsigned ArgNo = 0; ArgNo < CalleeParamNum; ++ArgNo) { auto *Arg = CS.getArgument(ArgNo); Type *FormalTy = CalleeType->getParamType(ArgNo); Type *ActualTy = Arg->getType(); if (FormalTy != ActualTy) { - auto *Cast = CastInst::Create(Instruction::BitCast, Arg, FormalTy, "", - CS.getInstruction()); + auto *Cast = CastInst::CreateBitOrPointerCast(Arg, FormalTy, "", + CS.getInstruction()); CS.setArgument(ArgNo, Cast); - } + + // Remove any incompatible attributes for the argument. + AttrBuilder ArgAttrs(CallerPAL.getParamAttributes(ArgNo)); + ArgAttrs.remove(AttributeFuncs::typeIncompatible(FormalTy)); + NewArgAttrs.push_back(AttributeSet::get(Ctx, ArgAttrs)); + AttributeChanged = true; + } else + NewArgAttrs.push_back(CallerPAL.getParamAttributes(ArgNo)); } // If the return type of the call site doesn't match that of the callee, cast // the returned value to the appropriate type. - if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) + // Remove any incompatible return value attribute. + AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex); + if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) { createRetBitCast(CS, CallSiteRetTy, RetBitCast); + RAttrs.remove(AttributeFuncs::typeIncompatible(CalleeRetTy)); + AttributeChanged = true; + } + + // Set the new callsite attribute. + if (AttributeChanged) + CS.setAttributes(AttributeList::get(Ctx, CallerPAL.getFnAttributes(), + AttributeSet::get(Ctx, RAttrs), + NewArgAttrs)); return CS.getInstruction(); } diff --git a/lib/Transforms/Utils/CanonicalizeAliases.cpp b/lib/Transforms/Utils/CanonicalizeAliases.cpp new file mode 100644 index 000000000000..cf41fd2e14c0 --- /dev/null +++ b/lib/Transforms/Utils/CanonicalizeAliases.cpp @@ -0,0 +1,105 @@ +//===- CanonicalizeAliases.cpp - ThinLTO Support: Canonicalize Aliases ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Currently this file implements partial alias canonicalization, to +// flatten chains of aliases (also done by GlobalOpt, but not on for +// O0 compiles). E.g. 
+// @a = alias i8, i8 *@b +// @b = alias i8, i8 *@g +// +// will be converted to: +// @a = alias i8, i8 *@g <-- @a is now an alias to base object @g +// @b = alias i8, i8 *@g +// +// Eventually this file will implement full alias canonicalation, so that +// all aliasees are private anonymous values. E.g. +// @a = alias i8, i8 *@g +// @g = global i8 0 +// +// will be converted to: +// @0 = private global +// @a = alias i8, i8* @0 +// @g = alias i8, i8* @0 +// +// This simplifies optimization and ThinLTO linking of the original symbols. +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/CanonicalizeAliases.h" + +#include "llvm/IR/Operator.h" +#include "llvm/IR/ValueHandle.h" + +using namespace llvm; + +namespace { + +static Constant *canonicalizeAlias(Constant *C, bool &Changed) { + if (auto *GA = dyn_cast<GlobalAlias>(C)) { + auto *NewAliasee = canonicalizeAlias(GA->getAliasee(), Changed); + if (NewAliasee != GA->getAliasee()) { + GA->setAliasee(NewAliasee); + Changed = true; + } + return NewAliasee; + } + + auto *CE = dyn_cast<ConstantExpr>(C); + if (!CE) + return C; + + std::vector<Constant *> Ops; + for (Use &U : CE->operands()) + Ops.push_back(canonicalizeAlias(cast<Constant>(U), Changed)); + return CE->getWithOperands(Ops); +} + +/// Convert aliases to canonical form. +static bool canonicalizeAliases(Module &M) { + bool Changed = false; + for (auto &GA : M.aliases()) + canonicalizeAlias(&GA, Changed); + return Changed; +} + +// Legacy pass that canonicalizes aliases. +class CanonicalizeAliasesLegacyPass : public ModulePass { + +public: + /// Pass identification, replacement for typeid + static char ID; + + /// Specify pass name for debug output + StringRef getPassName() const override { return "Canonicalize Aliases"; } + + explicit CanonicalizeAliasesLegacyPass() : ModulePass(ID) {} + + bool runOnModule(Module &M) override { return canonicalizeAliases(M); } +}; +char CanonicalizeAliasesLegacyPass::ID = 0; + +} // anonymous namespace + +PreservedAnalyses CanonicalizeAliasesPass::run(Module &M, + ModuleAnalysisManager &AM) { + if (!canonicalizeAliases(M)) + return PreservedAnalyses::all(); + + return PreservedAnalyses::none(); +} + +INITIALIZE_PASS_BEGIN(CanonicalizeAliasesLegacyPass, "canonicalize-aliases", + "Canonicalize aliases", false, false) +INITIALIZE_PASS_END(CanonicalizeAliasesLegacyPass, "canonicalize-aliases", + "Canonicalize aliases", false, false) + +namespace llvm { +ModulePass *createCanonicalizeAliasesPass() { + return new CanonicalizeAliasesLegacyPass(); +} +} // namespace llvm diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 807360340055..8f8c601f5f13 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -18,11 +18,11 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" @@ -32,6 +32,7 @@ #include "llvm/IR/Module.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" #include <map> using namespace llvm; @@ -235,8 
+236,8 @@ Function *llvm::CloneFunction(Function *F, ValueToValueMapTy &VMap, ArgTypes, F->getFunctionType()->isVarArg()); // Create the new function... - Function *NewF = - Function::Create(FTy, F->getLinkage(), F->getName(), F->getParent()); + Function *NewF = Function::Create(FTy, F->getLinkage(), F->getAddressSpace(), + F->getName(), F->getParent()); // Loop over the arguments, copying the names of the mapped arguments over... Function::arg_iterator DestI = NewF->arg_begin(); @@ -365,7 +366,7 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, } // Finally, clone over the terminator. - const TerminatorInst *OldTI = BB->getTerminator(); + const Instruction *OldTI = BB->getTerminator(); bool TerminatorDone = false; if (const BranchInst *BI = dyn_cast<BranchInst>(OldTI)) { if (BI->isConditional()) { @@ -414,8 +415,8 @@ void PruningFunctionCloner::CloneBlock(const BasicBlock *BB, CodeInfo->OperandBundleCallSites.push_back(NewInst); // Recursively clone any reachable successor blocks. - const TerminatorInst *TI = BB->getTerminator(); - for (const BasicBlock *Succ : TI->successors()) + const Instruction *TI = BB->getTerminator(); + for (const BasicBlock *Succ : successors(TI)) ToClone.push_back(Succ); } @@ -636,6 +637,22 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, Function::iterator Begin = cast<BasicBlock>(VMap[StartingBB])->getIterator(); Function::iterator I = Begin; while (I != NewFunc->end()) { + // We need to simplify conditional branches and switches with a constant + // operand. We try to prune these out when cloning, but if the + // simplification required looking through PHI nodes, those are only + // available after forming the full basic block. That may leave some here, + // and we still want to prune the dead code as early as possible. + // + // Do the folding before we check if the block is dead since we want code + // like + // bb: + // br i1 undef, label %bb, label %bb + // to be simplified to + // bb: + // br label %bb + // before we call I->getSinglePredecessor(). + ConstantFoldTerminator(&*I); + // Check if this block has become dead during inlining or other // simplifications. Note that the first block will appear dead, as it has // not yet been wired up properly. @@ -646,13 +663,6 @@ void llvm::CloneAndPruneIntoFromInst(Function *NewFunc, const Function *OldFunc, continue; } - // We need to simplify conditional branches and switches with a constant - // operand. We try to prune these out when cloning, but if the - // simplification required looking through PHI nodes, those are only - // available after forming the full basic block. That may leave some here, - // and we still want to prune the dead code as early as possible. - ConstantFoldTerminator(&*I); - BranchInst *BI = dyn_cast<BranchInst>(I->getTerminator()); if (!BI || BI->isConditional()) { ++I; continue; } @@ -786,11 +796,12 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, /// Duplicate non-Phi instructions from the beginning of block up to /// StopAt instruction into a split block between BB and its predecessor. 
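The next hunk switches DuplicateInstructionsInSplitBetween from a raw DominatorTree* to a DomTreeUpdater. A caller-side sketch of the updated interface, assuming the declarations in llvm/Transforms/Utils/Cloning.h and llvm/IR/DomTreeUpdater.h (illustrative; not part of this patch):

#include "llvm/IR/DomTreeUpdater.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

static BasicBlock *duplicateUpTo(BasicBlock *BB, BasicBlock *PredBB,
                                 Instruction *StopAt, DominatorTree &DT) {
  // Precondition (asserted by the callee): exactly one edge from PredBB to BB.
  // CFG updates are queued on the DomTreeUpdater and applied lazily.
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  ValueToValueMapTy ValueMapping;
  BasicBlock *NewBB = DuplicateInstructionsInSplitBetween(BB, PredBB, StopAt,
                                                          ValueMapping, DTU);
  DTU.flush(); // Materialize the queued dominator tree updates.
  return NewBB;
}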
-BasicBlock * -llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB, - Instruction *StopAt, - ValueToValueMapTy &ValueMapping, - DominatorTree *DT) { +BasicBlock *llvm::DuplicateInstructionsInSplitBetween( + BasicBlock *BB, BasicBlock *PredBB, Instruction *StopAt, + ValueToValueMapTy &ValueMapping, DomTreeUpdater &DTU) { + + assert(count(successors(PredBB), BB) == 1 && + "There must be a single edge between PredBB and BB!"); // We are going to have to map operands from the original BB block to the new // copy of the block 'NewBB'. If there are PHI nodes in BB, evaluate them to // account for entry from PredBB. @@ -798,10 +809,16 @@ llvm::DuplicateInstructionsInSplitBetween(BasicBlock *BB, BasicBlock *PredBB, for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); - BasicBlock *NewBB = SplitEdge(PredBB, BB, DT); + BasicBlock *NewBB = SplitEdge(PredBB, BB); NewBB->setName(PredBB->getName() + ".split"); Instruction *NewTerm = NewBB->getTerminator(); + // FIXME: SplitEdge does not yet take a DTU, so we include the split edge + // in the update set here. + DTU.applyUpdates({{DominatorTree::Delete, PredBB, BB}, + {DominatorTree::Insert, PredBB, NewBB}, + {DominatorTree::Insert, NewBB, BB}}); + // Clone the non-phi instructions of BB into NewBB, keeping track of the // mapping and using it to remap operands in the cloned instructions. // Stop once we see the terminator too. This covers the case where BB's diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index c7d68bab8170..659993aa5478 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -74,8 +74,9 @@ std::unique_ptr<Module> llvm::CloneModule( // Loop over the functions in the module, making external functions as before for (const Function &I : M) { - Function *NF = Function::Create(cast<FunctionType>(I.getValueType()), - I.getLinkage(), I.getName(), New.get()); + Function *NF = + Function::Create(cast<FunctionType>(I.getValueType()), I.getLinkage(), + I.getAddressSpace(), I.getName(), New.get()); NF->copyAttributesFrom(&I); VMap[&I] = NF; } @@ -91,8 +92,8 @@ std::unique_ptr<Module> llvm::CloneModule( GlobalValue *GV; if (I->getValueType()->isFunctionTy()) GV = Function::Create(cast<FunctionType>(I->getValueType()), - GlobalValue::ExternalLinkage, I->getName(), - New.get()); + GlobalValue::ExternalLinkage, + I->getAddressSpace(), I->getName(), New.get()); else GV = new GlobalVariable( *New, I->getValueType(), false, GlobalValue::ExternalLinkage, diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index cb349e34606c..25d4ae583ecc 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -57,6 +57,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include <cassert> #include <cstdint> #include <iterator> @@ -167,14 +168,22 @@ static bool isBlockValidForExtraction(const BasicBlock &BB, continue; } - if (const CallInst *CI = dyn_cast<CallInst>(I)) - if (const Function *F = CI->getCalledFunction()) - if (F->getIntrinsicID() == Intrinsic::vastart) { + if (const CallInst *CI = dyn_cast<CallInst>(I)) { + if (const Function *F = CI->getCalledFunction()) { + auto IID = F->getIntrinsicID(); + if (IID == Intrinsic::vastart) { if (AllowVarArgs) continue; else return false; } + + // Currently, we 
miscompile outlined copies of eh_typid_for. There are + // proposals for fixing this in llvm.org/PR39545. + if (IID == Intrinsic::eh_typeid_for) + return false; + } + } } return true; @@ -228,19 +237,21 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI, bool AllowVarArgs, - bool AllowAlloca) + bool AllowAlloca, std::string Suffix) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AllowVarArgs(AllowVarArgs), - Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)) {} + Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)), + Suffix(Suffix) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) + BranchProbabilityInfo *BPI, std::string Suffix) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), BPI(BPI), AllowVarArgs(false), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, /* AllowVarArgs */ false, - /* AllowAlloca */ false)) {} + /* AllowAlloca */ false)), + Suffix(Suffix) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. @@ -321,8 +332,7 @@ bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( default: { IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II); if (IntrInst) { - if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start || - IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) + if (IntrInst->isLifetimeStartOrEnd()) break; return false; } @@ -520,10 +530,10 @@ void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs, } } -/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the -/// region, we need to split the entry block of the region so that the PHI node -/// is easier to deal with. -void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { +/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside +/// of the region, we need to split the entry block of the region so that the +/// PHI node is easier to deal with. +void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) { unsigned NumPredsFromRegion = 0; unsigned NumPredsOutsideRegion = 0; @@ -566,7 +576,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // changing them to branch to NewBB instead. for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (Blocks.count(PN->getIncomingBlock(i))) { - TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); + Instruction *TI = PN->getIncomingBlock(i)->getTerminator(); TI->replaceUsesOfWith(OldPred, NewBB); } @@ -595,6 +605,56 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { } } +/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from +/// outlined region, we split these PHIs on two: one with inputs from region +/// and other with remaining incoming blocks; then first PHIs are placed in +/// outlined region. +void CodeExtractor::severSplitPHINodesOfExits( + const SmallPtrSetImpl<BasicBlock *> &Exits) { + for (BasicBlock *ExitBB : Exits) { + BasicBlock *NewBB = nullptr; + + for (PHINode &PN : ExitBB->phis()) { + // Find all incoming values from the outlining region. 
+ SmallVector<unsigned, 2> IncomingVals; + for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i) + if (Blocks.count(PN.getIncomingBlock(i))) + IncomingVals.push_back(i); + + // Do not process PHI if there is one (or fewer) predecessor from region. + // If PHI has exactly one predecessor from region, only this one incoming + // will be replaced on codeRepl block, so it should be safe to skip PHI. + if (IncomingVals.size() <= 1) + continue; + + // Create block for new PHIs and add it to the list of outlined if it + // wasn't done before. + if (!NewBB) { + NewBB = BasicBlock::Create(ExitBB->getContext(), + ExitBB->getName() + ".split", + ExitBB->getParent(), ExitBB); + SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB), + pred_end(ExitBB)); + for (BasicBlock *PredBB : Preds) + if (Blocks.count(PredBB)) + PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB); + BranchInst::Create(ExitBB, NewBB); + Blocks.insert(NewBB); + } + + // Split this PHI. + PHINode *NewPN = + PHINode::Create(PN.getType(), IncomingVals.size(), + PN.getName() + ".ce", NewBB->getFirstNonPHI()); + for (unsigned i : IncomingVals) + NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i)); + for (unsigned i : reverse(IncomingVals)) + PN.removeIncomingValue(i, false); + PN.addIncoming(NewPN, NewBB); + } + } +} + void CodeExtractor::splitReturnBlocks() { for (BasicBlock *Block : Blocks) if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) { @@ -669,11 +729,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, FunctionType::get(RetTy, paramTy, AllowVarArgs && oldFunction->isVarArg()); + std::string SuffixToUse = + Suffix.empty() + ? (header->getName().empty() ? "extracted" : header->getName().str()) + : Suffix; // Create the new function - Function *newFunction = Function::Create(funcType, - GlobalValue::InternalLinkage, - oldFunction->getName() + "_" + - header->getName(), M); + Function *newFunction = Function::Create( + funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(), + oldFunction->getName() + "." + SuffixToUse, M); // If the old function is no-throw, so is the new one. 
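The constructor changes above thread an optional Suffix through CodeExtractor, and constructFunction now names the outlined body <parent>.<suffix>, falling back to the header block name or "extracted" when no suffix is given. A hedged usage sketch, assuming llvm/Transforms/Utils/CodeExtractor.h (the "cold" suffix and region are placeholders, not part of this patch):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
using namespace llvm;

static Function *outlineRegion(ArrayRef<BasicBlock *> Region,
                               DominatorTree &DT) {
  CodeExtractor CE(Region, &DT, /*AggregateArgs=*/false, /*BFI=*/nullptr,
                   /*BPI=*/nullptr, /*AllowVarArgs=*/false,
                   /*AllowAlloca=*/false, /*Suffix=*/"cold");
  // Returns the outlined function (named e.g. "foo.cold" when extracting from
  // foo()), or nullptr if the region is not eligible for extraction.
  return CE.extractCodeRegion();
}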
if (oldFunction->doesNotThrow()) newFunction->setDoesNotThrow(); @@ -754,6 +817,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, case Attribute::SanitizeMemory: case Attribute::SanitizeThread: case Attribute::SanitizeHWAddress: + case Attribute::SpeculativeLoadHardening: case Attribute::StackProtect: case Attribute::StackProtectReq: case Attribute::StackProtectStrong: @@ -778,7 +842,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, Value *Idx[2]; Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i); - TerminatorInst *TI = newFunction->begin()->getTerminator(); + Instruction *TI = newFunction->begin()->getTerminator(); GetElementPtrInst *GEP = GetElementPtrInst::Create( StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI); RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); @@ -808,10 +872,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, for (unsigned i = 0, e = Users.size(); i != e; ++i) // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block - if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i])) - if (!Blocks.count(TI->getParent()) && - TI->getParent()->getParent() == oldFunction) - TI->replaceUsesOfWith(header, newHeader); + if (Instruction *I = dyn_cast<Instruction>(Users[i])) + if (I->isTerminator() && !Blocks.count(I->getParent()) && + I->getParent()->getParent() == oldFunction) + I->replaceUsesOfWith(header, newHeader); return newFunction; } @@ -819,9 +883,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, /// emitCallAndSwitchStatement - This method sets up the caller side by adding /// the call instruction, splitting any PHI nodes in the header block as /// necessary. -void CodeExtractor:: -emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, - ValueSet &inputs, ValueSet &outputs) { +CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, + BasicBlock *codeReplacer, + ValueSet &inputs, + ValueSet &outputs) { // Emit a call to the new function, passing in: *pointer to struct (if // aggregating parameters), or plan inputs and allocated memory for outputs std::vector<Value *> params, StructValues, ReloadOutputs, Reloads; @@ -829,6 +894,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, Module *M = newFunction->getParent(); LLVMContext &Context = M->getContext(); const DataLayout &DL = M->getDataLayout(); + CallInst *call = nullptr; // Add inputs as params, or to be filled into the struct for (Value *input : inputs) @@ -879,8 +945,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, } // Emit the call to the function - CallInst *call = CallInst::Create(newFunction, params, - NumExitBlocks > 1 ? "targetBlock" : ""); + call = CallInst::Create(newFunction, params, + NumExitBlocks > 1 ? "targetBlock" : ""); // Add debug location to the new call, if the original function has debug // info. In that case, the terminator of the entry block of the extracted // function contains the first debug location of the extracted function, @@ -925,11 +991,17 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, auto *OutI = dyn_cast<Instruction>(outputs[i]); if (!OutI) continue; + // Find proper insertion point. 
- Instruction *InsertPt = OutI->getNextNode(); - // Let's assume that there is no other guy interleave non-PHI in PHIs. - if (isa<PHINode>(InsertPt)) - InsertPt = InsertPt->getParent()->getFirstNonPHI(); + BasicBlock::iterator InsertPt; + // In case OutI is an invoke, we insert the store at the beginning in the + // 'normal destination' BB. Otherwise we insert the store right after OutI. + if (auto *InvokeI = dyn_cast<InvokeInst>(OutI)) + InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt(); + else if (auto *Phi = dyn_cast<PHINode>(OutI)) + InsertPt = Phi->getParent()->getFirstInsertionPt(); + else + InsertPt = std::next(OutI->getIterator()); assert(OAI != newFunction->arg_end() && "Number of output arguments should match " @@ -939,13 +1011,13 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), InsertPt); - new StoreInst(outputs[i], GEP, InsertPt); + StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt); + new StoreInst(outputs[i], GEP, &*InsertPt); // Since there should be only one struct argument aggregating // all the output values, we shouldn't increment OAI, which always // points to the struct argument, in this case. } else { - new StoreInst(outputs[i], &*OAI, InsertPt); + new StoreInst(outputs[i], &*OAI, &*InsertPt); ++OAI; } } @@ -964,7 +1036,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, unsigned switchVal = 0; for (BasicBlock *Block : Blocks) { - TerminatorInst *TI = Block->getTerminator(); + Instruction *TI = Block->getTerminator(); for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) if (!Blocks.count(TI->getSuccessor(i))) { BasicBlock *OldTarget = TI->getSuccessor(i); @@ -1046,6 +1118,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); break; } + + return call; } void CodeExtractor::moveCodeToFunction(Function *newFunction) { @@ -1070,7 +1144,7 @@ void CodeExtractor::calculateNewCallTerminatorWeights( using BlockNode = BlockFrequencyInfoImplBase::BlockNode; // Update the branch weights for the exit block. - TerminatorInst *TI = CodeReplacer->getTerminator(); + Instruction *TI = CodeReplacer->getTerminator(); SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0); // Block Frequency distribution with dummy node. @@ -1107,6 +1181,71 @@ void CodeExtractor::calculateNewCallTerminatorWeights( MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); } +/// Scan the extraction region for lifetime markers which reference inputs. +/// Erase these markers. Return the inputs which were referenced. +/// +/// The extraction region is defined by a set of blocks (\p Blocks), and a set +/// of allocas which will be moved from the caller function into the extracted +/// function (\p SunkAllocas). +static SetVector<Value *> +eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks, + const SetVector<Value *> &SunkAllocas) { + SetVector<Value *> InputObjectsWithLifetime; + for (BasicBlock *BB : Blocks) { + for (auto It = BB->begin(), End = BB->end(); It != End;) { + auto *II = dyn_cast<IntrinsicInst>(&*It); + ++It; + if (!II || !II->isLifetimeStartOrEnd()) + continue; + + // Get the memory operand of the lifetime marker. 
If the underlying + // object is a sunk alloca, or is otherwise defined in the extraction + // region, the lifetime marker must not be erased. + Value *Mem = II->getOperand(1)->stripInBoundsOffsets(); + if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem)) + continue; + + InputObjectsWithLifetime.insert(Mem); + II->eraseFromParent(); + } + } + return InputObjectsWithLifetime; +} + +/// Insert lifetime start/end markers surrounding the call to the new function +/// for objects defined in the caller. +static void insertLifetimeMarkersSurroundingCall( + Module *M, const SetVector<Value *> &InputObjectsWithLifetime, + CallInst *TheCall) { + if (InputObjectsWithLifetime.empty()) + return; + + LLVMContext &Ctx = M->getContext(); + auto Int8PtrTy = Type::getInt8PtrTy(Ctx); + auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1); + auto LifetimeStartFn = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::lifetime_start, Int8PtrTy); + auto LifetimeEndFn = llvm::Intrinsic::getDeclaration( + M, llvm::Intrinsic::lifetime_end, Int8PtrTy); + for (Value *Mem : InputObjectsWithLifetime) { + assert((!isa<Instruction>(Mem) || + cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) && + "Input memory not defined in original function"); + Value *MemAsI8Ptr = nullptr; + if (Mem->getType() == Int8PtrTy) + MemAsI8Ptr = Mem; + else + MemAsI8Ptr = + CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall); + + auto StartMarker = + CallInst::Create(LifetimeStartFn, {NegativeOne, MemAsI8Ptr}); + StartMarker->insertBefore(TheCall); + auto EndMarker = CallInst::Create(LifetimeEndFn, {NegativeOne, MemAsI8Ptr}); + EndMarker->insertAfter(TheCall); + } +} + Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; @@ -1150,13 +1289,33 @@ Function *CodeExtractor::extractCodeRegion() { } } - // If we have to split PHI nodes or the entry block, do so now. - severSplitPHINodes(header); - // If we have any return instructions in the region, split those blocks so // that the return is not in the region. splitReturnBlocks(); + // Calculate the exit blocks for the extracted region and the total exit + // weights for each of those blocks. + DenseMap<BasicBlock *, BlockFrequency> ExitWeights; + SmallPtrSet<BasicBlock *, 1> ExitBlocks; + for (BasicBlock *Block : Blocks) { + for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; + ++SI) { + if (!Blocks.count(*SI)) { + // Update the branch weight for this successor. + if (BFI) { + BlockFrequency &BF = ExitWeights[*SI]; + BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); + } + ExitBlocks.insert(*SI); + } + } + } + NumExitBlocks = ExitBlocks.size(); + + // If we have to split PHI nodes of the entry or exit blocks, do so now. + severSplitPHINodesOfEntry(header); + severSplitPHINodesOfExits(ExitBlocks); + // This takes place of the original loop BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, @@ -1201,30 +1360,17 @@ Function *CodeExtractor::extractCodeRegion() { cast<Instruction>(II)->moveBefore(TI); } - // Calculate the exit blocks for the extracted region and the total exit - // weights for each of those blocks. - DenseMap<BasicBlock *, BlockFrequency> ExitWeights; - SmallPtrSet<BasicBlock *, 1> ExitBlocks; - for (BasicBlock *Block : Blocks) { - for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; - ++SI) { - if (!Blocks.count(*SI)) { - // Update the branch weight for this successor. 
- if (BFI) { - BlockFrequency &BF = ExitWeights[*SI]; - BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); - } - ExitBlocks.insert(*SI); - } - } - } - NumExitBlocks = ExitBlocks.size(); + // Collect objects which are inputs to the extraction region and also + // referenced by lifetime start/end markers within it. The effects of these + // markers must be replicated in the calling function to prevent the stack + // coloring pass from merging slots which store input objects. + ValueSet InputObjectsWithLifetime = + eraseLifetimeMarkersOnInputs(Blocks, SinkingCands); // Construct new function based on inputs/outputs & add allocas for all defs. - Function *newFunction = constructFunction(inputs, outputs, header, - newFuncRoot, - codeReplacer, oldFunction, - oldFunction->getParent()); + Function *newFunction = + constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer, + oldFunction, oldFunction->getParent()); // Update the entry count of the function. if (BFI) { @@ -1235,10 +1381,16 @@ Function *CodeExtractor::extractCodeRegion() { BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); } - emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); + CallInst *TheCall = + emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); moveCodeToFunction(newFunction); + // Replicate the effects of any lifetime start/end markers which referenced + // input objects in the extraction region by placing markers around the call. + insertLifetimeMarkersSurroundingCall(oldFunction->getParent(), + InputObjectsWithLifetime, TheCall); + // Propagate personality info to the new function if there is one. if (oldFunction->hasPersonalityFn()) newFunction->setPersonalityFn(oldFunction->getPersonalityFn()); @@ -1247,8 +1399,8 @@ Function *CodeExtractor::extractCodeRegion() { if (BFI && NumExitBlocks > 1) calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); - // Loop over all of the PHI nodes in the header block, and change any - // references to the old incoming edge to be the new incoming edge. + // Loop over all of the PHI nodes in the header and exit blocks, and change + // any references to the old incoming edge to be the new incoming edge. for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) @@ -1256,29 +1408,60 @@ Function *CodeExtractor::extractCodeRegion() { PN->setIncomingBlock(i, newFuncRoot); } - // Look at all successors of the codeReplacer block. If any of these blocks - // had PHI nodes in them, we need to update the "from" block to be the code - // replacer, not the original block in the extracted region. - std::vector<BasicBlock *> Succs(succ_begin(codeReplacer), - succ_end(codeReplacer)); - for (unsigned i = 0, e = Succs.size(); i != e; ++i) - for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) { - PHINode *PN = cast<PHINode>(I); - std::set<BasicBlock*> ProcessedPreds; - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (Blocks.count(PN->getIncomingBlock(i))) { - if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) - PN->setIncomingBlock(i, codeReplacer); - else { - // There were multiple entries in the PHI for this block, now there - // is only one, so remove the duplicated entries. 
- PN->removeIncomingValue(i, false); - --i; --e; - } - } + for (BasicBlock *ExitBB : ExitBlocks) + for (PHINode &PN : ExitBB->phis()) { + Value *IncomingCodeReplacerVal = nullptr; + for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { + // Ignore incoming values from outside of the extracted region. + if (!Blocks.count(PN.getIncomingBlock(i))) + continue; + + // Ensure that there is only one incoming value from codeReplacer. + if (!IncomingCodeReplacerVal) { + PN.setIncomingBlock(i, codeReplacer); + IncomingCodeReplacerVal = PN.getIncomingValue(i); + } else + assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) && + "PHI has two incompatbile incoming values from codeRepl"); + } + } + + // Erase debug info intrinsics. Variable updates within the new function are + // invisible to debuggers. This could be improved by defining a DISubprogram + // for the new function. + for (BasicBlock &BB : *newFunction) { + auto BlockIt = BB.begin(); + // Remove debug info intrinsics from the new function. + while (BlockIt != BB.end()) { + Instruction *Inst = &*BlockIt; + ++BlockIt; + if (isa<DbgInfoIntrinsic>(Inst)) + Inst->eraseFromParent(); } + // Remove debug info intrinsics which refer to values in the new function + // from the old function. + SmallVector<DbgVariableIntrinsic *, 4> DbgUsers; + for (Instruction &I : BB) + findDbgUsers(DbgUsers, &I); + for (DbgVariableIntrinsic *DVI : DbgUsers) + DVI->eraseFromParent(); + } - LLVM_DEBUG(if (verifyFunction(*newFunction)) - report_fatal_error("verifyFunction failed!")); + // Mark the new function `noreturn` if applicable. Terminators which resume + // exception propagation are treated as returning instructions. This is to + // avoid inserting traps after calls to outlined functions which unwind. + bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) { + const Instruction *Term = BB.getTerminator(); + return isa<ReturnInst>(Term) || isa<ResumeInst>(Term); + }); + if (doesNotReturn) + newFunction->setDoesNotReturn(); + + LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) { + newFunction->dump(); + report_fatal_error("verification of newFunction failed!"); + }); + LLVM_DEBUG(if (verifyFunction(*oldFunction)) + report_fatal_error("verification of oldFunction failed!")); return newFunction; } diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index 9a0240144d08..4e7da7d0449f 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -22,11 +22,10 @@ #define DEBUG_TYPE "ctor_utils" -namespace llvm { +using namespace llvm; -namespace { /// Given a specified llvm.global_ctors list, remove the listed elements. -void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemove) { +static void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemove) { // Filter out the initializer elements to remove. ConstantArray *OldCA = cast<ConstantArray>(GCL->getInitializer()); SmallVector<Constant *, 10> CAList; @@ -64,7 +63,7 @@ void removeGlobalCtors(GlobalVariable *GCL, const BitVector &CtorsToRemove) { /// Given a llvm.global_ctors list that we can understand, /// return a list of the functions and null terminator as a vector. 
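For context, the llvm.global_ctors array parsed by parseGlobalCtors and validated by findGlobalCtors below is normally populated through the ModuleUtils helper; a short sketch, assuming llvm/Transforms/Utils/ModuleUtils.h (illustrative; not part of this patch):

#include "llvm/Transforms/Utils/ModuleUtils.h"
using namespace llvm;

// Appends { i32 65535, void ()* @InitFn, i8* null } to @llvm.global_ctors;
// 65535 is the priority that findGlobalCtors requires for every entry.
static void registerCtor(Module &M, Function *InitFn) {
  appendToGlobalCtors(M, InitFn, /*Priority=*/65535);
}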
-std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { +static std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { if (GV->getInitializer()->isNullValue()) return std::vector<Function *>(); ConstantArray *CA = cast<ConstantArray>(GV->getInitializer()); @@ -79,7 +78,7 @@ std::vector<Function *> parseGlobalCtors(GlobalVariable *GV) { /// Find the llvm.global_ctors list, verifying that all initializers have an /// init priority of 65535. -GlobalVariable *findGlobalCtors(Module &M) { +static GlobalVariable *findGlobalCtors(Module &M) { GlobalVariable *GV = M.getGlobalVariable("llvm.global_ctors"); if (!GV) return nullptr; @@ -112,12 +111,11 @@ GlobalVariable *findGlobalCtors(Module &M) { return GV; } -} // namespace /// Call "ShouldRemove" for every entry in M's global_ctor list and remove the /// entries for which it returns true. Return true if anything changed. -bool optimizeGlobalCtorsList(Module &M, - function_ref<bool(Function *)> ShouldRemove) { +bool llvm::optimizeGlobalCtorsList( + Module &M, function_ref<bool(Function *)> ShouldRemove) { GlobalVariable *GlobalCtors = findGlobalCtors(M); if (!GlobalCtors) return false; @@ -160,5 +158,3 @@ bool optimizeGlobalCtorsList(Module &M, removeGlobalCtors(GlobalCtors, CtorsToRemove); return true; } - -} // End llvm namespace diff --git a/lib/Transforms/Utils/DemoteRegToStack.cpp b/lib/Transforms/Utils/DemoteRegToStack.cpp index 56ff03c7f5e1..975b363859a9 100644 --- a/lib/Transforms/Utils/DemoteRegToStack.cpp +++ b/lib/Transforms/Utils/DemoteRegToStack.cpp @@ -90,7 +90,7 @@ AllocaInst *llvm::DemoteRegToStack(Instruction &I, bool VolatileLoads, // careful if I is an invoke instruction, because we can't insert the store // AFTER the terminator instruction. BasicBlock::iterator InsertPt; - if (!isa<TerminatorInst>(I)) { + if (!I.isTerminator()) { InsertPt = ++I.getIterator(); for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) /* empty */; // Don't insert before PHI nodes or landingpad instrs. diff --git a/lib/Transforms/Utils/EscapeEnumerator.cpp b/lib/Transforms/Utils/EscapeEnumerator.cpp index c9c96fbe5da0..762a374c135c 100644 --- a/lib/Transforms/Utils/EscapeEnumerator.cpp +++ b/lib/Transforms/Utils/EscapeEnumerator.cpp @@ -37,7 +37,7 @@ IRBuilder<> *EscapeEnumerator::Next() { // Branches and invokes do not escape, only unwind, resume, and return // do. - TerminatorInst *TI = CurBB->getTerminator(); + Instruction *TI = CurBB->getTerminator(); if (!isa<ReturnInst>(TI) && !isa<ResumeInst>(TI)) continue; diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp index 7fd9425efed3..e875cd686b00 100644 --- a/lib/Transforms/Utils/Evaluator.cpp +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -483,8 +483,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, } } - if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) { + if (II->isLifetimeStartOrEnd()) { LLVM_DEBUG(dbgs() << "Ignoring lifetime intrinsic.\n"); ++CurInst; continue; @@ -578,7 +577,7 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, << "Successfully evaluated function. 
Result: 0\n\n"); } } - } else if (isa<TerminatorInst>(CurInst)) { + } else if (CurInst->isTerminator()) { LLVM_DEBUG(dbgs() << "Found a terminator instruction.\n"); if (BranchInst *BI = dyn_cast<BranchInst>(CurInst)) { diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 3c6c9c9a5df4..d9778f4a1fb7 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -232,7 +232,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { if (!FirstCondBlock || !LastCondBlock || (FirstCondBlock == LastCondBlock)) return false; - TerminatorInst *TBB = LastCondBlock->getTerminator(); + Instruction *TBB = LastCondBlock->getTerminator(); BasicBlock *PS1 = TBB->getSuccessor(0); BasicBlock *PS2 = TBB->getSuccessor(1); BranchInst *PBI1 = dyn_cast<BranchInst>(PS1->getTerminator()); @@ -325,7 +325,7 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { bool FlattenCFGOpt::CompareIfRegionBlock(BasicBlock *Head1, BasicBlock *Head2, BasicBlock *Block1, BasicBlock *Block2) { - TerminatorInst *PTI2 = Head2->getTerminator(); + Instruction *PTI2 = Head2->getTerminator(); Instruction *PBI2 = &Head2->front(); bool eq1 = (Block1 == Head1); @@ -421,7 +421,7 @@ bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { if ((IfTrue2 != SecondEntryBlock) && (IfFalse2 != SecondEntryBlock)) return false; - TerminatorInst *PTI2 = SecondEntryBlock->getTerminator(); + Instruction *PTI2 = SecondEntryBlock->getTerminator(); Instruction *PBI2 = &SecondEntryBlock->front(); if (!CompareIfRegionBlock(FirstEntryBlock, SecondEntryBlock, IfTrue1, diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp index 69203f9f2485..a717d9b72819 100644 --- a/lib/Transforms/Utils/FunctionComparator.cpp +++ b/lib/Transforms/Utils/FunctionComparator.cpp @@ -410,8 +410,6 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { switch (TyL->getTypeID()) { default: llvm_unreachable("Unknown type!"); - // Fall through in Release mode. 
- LLVM_FALLTHROUGH; case Type::IntegerTyID: return cmpNumbers(cast<IntegerType>(TyL)->getBitWidth(), cast<IntegerType>(TyR)->getBitWidth()); @@ -867,8 +865,8 @@ int FunctionComparator::compare() { if (int Res = cmpBasicBlocks(BBL, BBR)) return Res; - const TerminatorInst *TermL = BBL->getTerminator(); - const TerminatorInst *TermR = BBR->getTerminator(); + const Instruction *TermL = BBL->getTerminator(); + const Instruction *TermR = BBR->getTerminator(); assert(TermL->getNumSuccessors() == TermR->getNumSuccessors()); for (unsigned i = 0, e = TermL->getNumSuccessors(); i != e; ++i) { @@ -938,7 +936,7 @@ FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { for (auto &Inst : *BB) { H.add(Inst.getOpcode()); } - const TerminatorInst *Term = BB->getTerminator(); + const Instruction *Term = BB->getTerminator(); for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { if (!VisitedBBs.insert(Term->getSuccessor(i)).second) continue; diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp index 479816a339d0..a9772e31da50 100644 --- a/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -124,7 +124,6 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, return SGV->getLinkage(); switch (SGV->getLinkage()) { - case GlobalValue::LinkOnceAnyLinkage: case GlobalValue::LinkOnceODRLinkage: case GlobalValue::ExternalLinkage: // External and linkonce definitions are converted to available_externally @@ -144,11 +143,13 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, // An imported available_externally declaration stays that way. return SGV->getLinkage(); + case GlobalValue::LinkOnceAnyLinkage: case GlobalValue::WeakAnyLinkage: - // Can't import weak_any definitions correctly, or we might change the - // program semantics, since the linker will pick the first weak_any - // definition and importing would change the order they are seen by the - // linker. The module linking caller needs to enforce this. + // Can't import linkonce_any/weak_any definitions correctly, or we might + // change the program semantics, since the linker will pick the first + // linkonce_any/weak_any definition and importing would change the order + // they are seen by the linker. The module linking caller needs to enforce + // this. assert(!doImportAsDefinition(SGV)); // If imported as a declaration, it becomes external_weak. return SGV->getLinkage(); @@ -202,10 +203,26 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { - // Check the summaries to see if the symbol gets resolved to a known local - // definition. + ValueInfo VI; if (GV.hasName()) { - ValueInfo VI = ImportIndex.getValueInfo(GV.getGUID()); + VI = ImportIndex.getValueInfo(GV.getGUID()); + // Set synthetic function entry counts. + if (VI && ImportIndex.hasSyntheticEntryCounts()) { + if (Function *F = dyn_cast<Function>(&GV)) { + if (!F->isDeclaration()) { + for (auto &S : VI.getSummaryList()) { + FunctionSummary *FS = dyn_cast<FunctionSummary>(S->getBaseObject()); + if (FS->modulePath() == M.getModuleIdentifier()) { + F->setEntryCount(Function::ProfileCount(FS->entryCount(), + Function::PCT_Synthetic)); + break; + } + } + } + } + } + // Check the summaries to see if the symbol gets resolved to a known local + // definition. 
if (VI && VI.isDSOLocal()) { GV.setDSOLocal(true); if (GV.hasDLLImportStorageClass()) @@ -213,6 +230,22 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { } } + // Mark read-only variables which can be imported with specific attribute. + // We can't internalize them now because IRMover will fail to link variable + // definitions to their external declarations during ThinLTO import. We'll + // internalize read-only variables later, after import is finished. + // See internalizeImmutableGVs. + // + // If global value dead stripping is not enabled in summary then + // propagateConstants hasn't been run. We can't internalize GV + // in such case. + if (!GV.isDeclaration() && VI && ImportIndex.withGlobalValueDeadStripping()) { + const auto &SL = VI.getSummaryList(); + auto *GVS = SL.empty() ? nullptr : dyn_cast<GlobalVarSummary>(SL[0].get()); + if (GVS && GVS->isReadOnly()) + cast<GlobalVariable>(&GV)->addAttribute("thinlto-internalize"); + } + bool DoPromote = false; if (GV.hasLocalLinkage() && ((DoPromote = shouldPromoteLocalToGlobal(&GV)) || isPerformingImport())) { @@ -230,7 +263,7 @@ void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { // Remove functions imported as available externally defs from comdats, // as this is a declaration for the linker, and will be dropped eventually. // It is illegal for comdats to contain declarations. - auto *GO = dyn_cast_or_null<GlobalObject>(&GV); + auto *GO = dyn_cast<GlobalObject>(&GV); if (GO && GO->isDeclarationForLinker() && GO->hasComdat()) { // The IRMover should not have placed any imported declarations in // a comdat, so the only declaration that should be in a comdat diff --git a/lib/Transforms/Utils/GuardUtils.cpp b/lib/Transforms/Utils/GuardUtils.cpp new file mode 100644 index 000000000000..08de0a4c53e9 --- /dev/null +++ b/lib/Transforms/Utils/GuardUtils.cpp @@ -0,0 +1,64 @@ +//===-- GuardUtils.cpp - Utils for work with guards -------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// Utils that are used to perform transformations related to guards and their +// conditions. +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/GuardUtils.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +static cl::opt<uint32_t> PredicatePassBranchWeight( + "guards-predicate-pass-branch-weight", cl::Hidden, cl::init(1 << 20), + cl::desc("The probability of a guard failing is assumed to be the " + "reciprocal of this value (default = 1 << 20)")); + +void llvm::makeGuardControlFlowExplicit(Function *DeoptIntrinsic, + CallInst *Guard) { + OperandBundleDef DeoptOB(*Guard->getOperandBundle(LLVMContext::OB_deopt)); + SmallVector<Value *, 4> Args(std::next(Guard->arg_begin()), Guard->arg_end()); + + auto *CheckBB = Guard->getParent(); + auto *DeoptBlockTerm = + SplitBlockAndInsertIfThen(Guard->getArgOperand(0), Guard, true); + + auto *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); + + // SplitBlockAndInsertIfThen inserts control flow that branches to + // DeoptBlockTerm if the condition is true. We want the opposite. 
+ CheckBI->swapSuccessors(); + + CheckBI->getSuccessor(0)->setName("guarded"); + CheckBI->getSuccessor(1)->setName("deopt"); + + if (auto *MD = Guard->getMetadata(LLVMContext::MD_make_implicit)) + CheckBI->setMetadata(LLVMContext::MD_make_implicit, MD); + + MDBuilder MDB(Guard->getContext()); + CheckBI->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(PredicatePassBranchWeight, 1)); + + IRBuilder<> B(DeoptBlockTerm); + auto *DeoptCall = B.CreateCall(DeoptIntrinsic, Args, {DeoptOB}, ""); + + if (DeoptIntrinsic->getReturnType()->isVoidTy()) { + B.CreateRetVoid(); + } else { + DeoptCall->setName("deoptcall"); + B.CreateRet(DeoptCall); + } + + DeoptCall->setCallingConv(Guard->getCallingConv()); + DeoptBlockTerm->eraseFromParent(); +} diff --git a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp index 8382220fc9e1..02482c550321 100644 --- a/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp +++ b/lib/Transforms/Utils/ImportedFunctionsInliningStatistics.cpp @@ -161,7 +161,7 @@ void ImportedFunctionsInliningStatistics::dump(const bool Verbose) { void ImportedFunctionsInliningStatistics::calculateRealInlines() { // Removing duplicated Callers. - llvm::sort(NonImportedCallers.begin(), NonImportedCallers.end()); + llvm::sort(NonImportedCallers); NonImportedCallers.erase( std::unique(NonImportedCallers.begin(), NonImportedCallers.end()), NonImportedCallers.end()); @@ -190,17 +190,14 @@ ImportedFunctionsInliningStatistics::getSortedNodes() { for (const NodesMapTy::value_type& Node : NodesMap) SortedNodes.push_back(&Node); - llvm::sort( - SortedNodes.begin(), SortedNodes.end(), - [&](const SortedNodesTy::value_type &Lhs, - const SortedNodesTy::value_type &Rhs) { - if (Lhs->second->NumberOfInlines != Rhs->second->NumberOfInlines) - return Lhs->second->NumberOfInlines > Rhs->second->NumberOfInlines; - if (Lhs->second->NumberOfRealInlines != - Rhs->second->NumberOfRealInlines) - return Lhs->second->NumberOfRealInlines > - Rhs->second->NumberOfRealInlines; - return Lhs->first() < Rhs->first(); - }); + llvm::sort(SortedNodes, [&](const SortedNodesTy::value_type &Lhs, + const SortedNodesTy::value_type &Rhs) { + if (Lhs->second->NumberOfInlines != Rhs->second->NumberOfInlines) + return Lhs->second->NumberOfInlines > Rhs->second->NumberOfInlines; + if (Lhs->second->NumberOfRealInlines != Rhs->second->NumberOfRealInlines) + return Lhs->second->NumberOfRealInlines > + Rhs->second->NumberOfRealInlines; + return Lhs->first() < Rhs->first(); + }); return SortedNodes; } diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index ddc6e07e2f59..623fe91a5a60 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -31,6 +31,7 @@ #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -84,13 +85,15 @@ PreserveAlignmentAssumptions("preserve-alignment-assumptions-during-inlining", cl::init(true), cl::Hidden, cl::desc("Convert align attributes to assumptions during inlining.")); -bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, - AAResults *CalleeAAR, bool InsertLifetime) { +llvm::InlineResult llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, + AAResults *CalleeAAR, + bool InsertLifetime) { return 
InlineFunction(CallSite(CI), IFI, CalleeAAR, InsertLifetime); } -bool llvm::InlineFunction(InvokeInst *II, InlineFunctionInfo &IFI, - AAResults *CalleeAAR, bool InsertLifetime) { +llvm::InlineResult llvm::InlineFunction(InvokeInst *II, InlineFunctionInfo &IFI, + AAResults *CalleeAAR, + bool InsertLifetime) { return InlineFunction(CallSite(II), IFI, CalleeAAR, InsertLifetime); } @@ -768,14 +771,16 @@ static void HandleInlinedEHPad(InvokeInst *II, BasicBlock *FirstNewBlock, UnwindDest->removePredecessor(InvokeBB); } -/// When inlining a call site that has !llvm.mem.parallel_loop_access metadata, -/// that metadata should be propagated to all memory-accessing cloned -/// instructions. +/// When inlining a call site that has !llvm.mem.parallel_loop_access or +/// llvm.access.group metadata, that metadata should be propagated to all +/// memory-accessing cloned instructions. static void PropagateParallelLoopAccessMetadata(CallSite CS, ValueToValueMapTy &VMap) { MDNode *M = CS.getInstruction()->getMetadata(LLVMContext::MD_mem_parallel_loop_access); - if (!M) + MDNode *CallAccessGroup = + CS.getInstruction()->getMetadata(LLVMContext::MD_access_group); + if (!M && !CallAccessGroup) return; for (ValueToValueMapTy::iterator VMI = VMap.begin(), VMIE = VMap.end(); @@ -787,11 +792,20 @@ static void PropagateParallelLoopAccessMetadata(CallSite CS, if (!NI) continue; - if (MDNode *PM = NI->getMetadata(LLVMContext::MD_mem_parallel_loop_access)) { + if (M) { + if (MDNode *PM = + NI->getMetadata(LLVMContext::MD_mem_parallel_loop_access)) { M = MDNode::concatenate(PM, M); NI->setMetadata(LLVMContext::MD_mem_parallel_loop_access, M); - } else if (NI->mayReadOrWriteMemory()) { - NI->setMetadata(LLVMContext::MD_mem_parallel_loop_access, M); + } else if (NI->mayReadOrWriteMemory()) { + NI->setMetadata(LLVMContext::MD_mem_parallel_loop_access, M); + } + } + + if (NI->mayReadOrWriteMemory()) { + MDNode *UnitedAccGroups = uniteAccessGroups( + NI->getMetadata(LLVMContext::MD_access_group), CallAccessGroup); + NI->setMetadata(LLVMContext::MD_access_group, UnitedAccGroups); } } } @@ -985,22 +999,22 @@ static void AddAliasScopeMetadata(CallSite CS, ValueToValueMapTy &VMap, PtrArgs.push_back(CXI->getPointerOperand()); else if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) PtrArgs.push_back(RMWI->getPointerOperand()); - else if (ImmutableCallSite ICS = ImmutableCallSite(I)) { + else if (const auto *Call = dyn_cast<CallBase>(I)) { // If we know that the call does not access memory, then we'll still // know that about the inlined clone of this call site, and we don't // need to add metadata. - if (ICS.doesNotAccessMemory()) + if (Call->doesNotAccessMemory()) continue; IsFuncCall = true; if (CalleeAAR) { - FunctionModRefBehavior MRB = CalleeAAR->getModRefBehavior(ICS); + FunctionModRefBehavior MRB = CalleeAAR->getModRefBehavior(Call); if (MRB == FMRB_OnlyAccessesArgumentPointees || MRB == FMRB_OnlyReadsArgumentPointees) IsArgMemOnlyCall = true; } - for (Value *Arg : ICS.args()) { + for (Value *Arg : Call->args()) { // We need to check the underlying objects of all arguments, not just // the pointer arguments, because we might be passing pointers as // integers, etc. @@ -1306,16 +1320,10 @@ static Value *HandleByValArgument(Value *Arg, Instruction *TheCall, // Check whether this Value is used by a lifetime intrinsic. 
static bool isUsedByLifetimeMarker(Value *V) { - for (User *U : V->users()) { - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) { - switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::lifetime_start: - case Intrinsic::lifetime_end: + for (User *U : V->users()) + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) + if (II->isLifetimeStartOrEnd()) return true; - } - } - } return false; } @@ -1491,9 +1499,10 @@ static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, /// instruction 'call B' is inlined, and 'B' calls 'C', then the call to 'C' now /// exists in the instruction stream. Similarly this will inline a recursive /// function by one level. -bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, - AAResults *CalleeAAR, bool InsertLifetime, - Function *ForwardVarArgsTo) { +llvm::InlineResult llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, + AAResults *CalleeAAR, + bool InsertLifetime, + Function *ForwardVarArgsTo) { Instruction *TheCall = CS.getInstruction(); assert(TheCall->getParent() && TheCall->getFunction() && "Instruction not in function!"); @@ -1504,7 +1513,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, Function *CalledFunc = CS.getCalledFunction(); if (!CalledFunc || // Can't inline external function or indirect CalledFunc->isDeclaration()) // call! - return false; + return "external or indirect"; // The inliner does not know how to inline through calls with operand bundles // in general ... @@ -1518,7 +1527,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, if (Tag == LLVMContext::OB_funclet) continue; - return false; + return "unsupported operand bundle"; } } @@ -1537,7 +1546,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, if (!Caller->hasGC()) Caller->setGC(CalledFunc->getGC()); else if (CalledFunc->getGC() != Caller->getGC()) - return false; + return "incompatible GC"; } // Get the personality function from the callee if it contains a landing pad. @@ -1561,7 +1570,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // TODO: This isn't 100% true. Some personality functions are proper // supersets of others and can be used in place of the other. else if (CalledPersonality != CallerPersonality) - return false; + return "incompatible personality"; } // We need to figure out which funclet the callsite was in so that we may @@ -1586,7 +1595,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // for catchpads. for (const BasicBlock &CalledBB : *CalledFunc) { if (isa<CatchSwitchInst>(CalledBB.getFirstNonPHI())) - return false; + return "catch in cleanup funclet"; } } } else if (isAsynchronousEHPersonality(Personality)) { @@ -1594,7 +1603,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // funclet in the callee. for (const BasicBlock &CalledBB : *CalledFunc) { if (CalledBB.isEHPad()) - return false; + return "SEH in cleanup funclet"; } } } @@ -2244,7 +2253,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Change the branch that used to go to AfterCallBB to branch to the first // basic block of the inlined function. 
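// Illustrative sketch (not from the patch): with InlineFunction now returning
// llvm::InlineResult instead of bool, a caller can surface why a call site was
// refused. CS, IFI and CalleeAAR are assumed to come from the calling pass,
// and the message field is assumed per the InlineResult type in InlineCost.h.
InlineResult IR = InlineFunction(CS, IFI, CalleeAAR, /*InsertLifetime=*/true);
if (!IR)
  errs() << "not inlining " << *CS.getInstruction() << ": " << IR.message
         << "\n";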
// - TerminatorInst *Br = OrigBB->getTerminator(); + Instruction *Br = OrigBB->getTerminator(); assert(Br && Br->getOpcode() == Instruction::Br && "splitBasicBlock broken!"); Br->setOperand(0, &*FirstNewBlock); diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index a1f8e7484bcf..53d444b309d5 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -41,6 +41,7 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PredIteratorCache.h" #include "llvm/Pass.h" #include "llvm/Transforms/Utils.h" @@ -201,6 +202,21 @@ bool llvm::formLCSSAForInstructions(SmallVectorImpl<Instruction *> &Worklist, SSAUpdate.RewriteUse(*UseToRewrite); } + SmallVector<DbgValueInst *, 4> DbgValues; + llvm::findDbgValues(DbgValues, I); + + // Update pre-existing debug value uses that reside outside the loop. + auto &Ctx = I->getContext(); + for (auto DVI : DbgValues) { + BasicBlock *UserBB = DVI->getParent(); + if (InstBB == UserBB || L->contains(UserBB)) + continue; + // We currently only handle debug values residing in blocks where we have + // inserted a PHI instruction. + if (Value *V = SSAUpdate.FindValueForBlock(UserBB)) + DVI->setOperand(0, MetadataAsValue::get(Ctx, ValueAsMetadata::get(V))); + } + // SSAUpdater might have inserted phi-nodes inside other loops. We'll need // to post-process them to keep LCSSA form. for (PHINode *InsertedPN : InsertedPHIs) { diff --git a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp index 9832a6f24e1f..e1592c867636 100644 --- a/lib/Transforms/Utils/LibCallsShrinkWrap.cpp +++ b/lib/Transforms/Utils/LibCallsShrinkWrap.cpp @@ -487,7 +487,7 @@ void LibCallsShrinkWrap::shrinkWrapCI(CallInst *CI, Value *Cond) { MDNode *BranchWeights = MDBuilder(CI->getContext()).createBranchWeights(1, 2000); - TerminatorInst *NewInst = + Instruction *NewInst = SplitBlockAndInsertIfThen(Cond, CI, false, BranchWeights, DT); BasicBlock *CallBB = NewInst->getParent(); CallBB->setName("cdce.call"); diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index ae3cb077a3af..499e611acb57 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -31,8 +31,10 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/BinaryFormat/Dwarf.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" @@ -47,6 +49,7 @@ #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -102,8 +105,8 @@ STATISTIC(NumRemoved, "Number of unreachable basic blocks removed"); /// DeleteDeadConditions is true. 
bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, const TargetLibraryInfo *TLI, - DeferredDominance *DDT) { - TerminatorInst *T = BB->getTerminator(); + DomTreeUpdater *DTU) { + Instruction *T = BB->getTerminator(); IRBuilder<> Builder(T); // Branch - See if we are conditional jumping on constant @@ -125,8 +128,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, // Replace the conditional branch with an unconditional one. Builder.CreateBr(Destination); BI->eraseFromParent(); - if (DDT) - DDT->deleteEdge(BB, OldDest); + if (DTU) + DTU->deleteEdgeRelaxed(BB, OldDest); return true; } @@ -201,8 +204,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, DefaultDest->removePredecessor(ParentBB); i = SI->removeCase(i); e = SI->case_end(); - if (DDT) - DDT->deleteEdge(ParentBB, DefaultDest); + if (DTU) + DTU->deleteEdgeRelaxed(ParentBB, DefaultDest); continue; } @@ -229,17 +232,17 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, Builder.CreateBr(TheOnlyDest); BasicBlock *BB = SI->getParent(); std::vector <DominatorTree::UpdateType> Updates; - if (DDT) + if (DTU) Updates.reserve(SI->getNumSuccessors() - 1); // Remove entries from PHI nodes which we no longer branch to... - for (BasicBlock *Succ : SI->successors()) { + for (BasicBlock *Succ : successors(SI)) { // Found case matching a constant operand? if (Succ == TheOnlyDest) { TheOnlyDest = nullptr; // Don't modify the first branch to TheOnlyDest } else { Succ->removePredecessor(BB); - if (DDT) + if (DTU) Updates.push_back({DominatorTree::Delete, BB, Succ}); } } @@ -249,8 +252,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, SI->eraseFromParent(); if (DeleteDeadConditions) RecursivelyDeleteTriviallyDeadInstructions(Cond, TLI); - if (DDT) - DDT->applyUpdates(Updates); + if (DTU) + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); return true; } @@ -297,7 +300,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, dyn_cast<BlockAddress>(IBI->getAddress()->stripPointerCasts())) { BasicBlock *TheOnlyDest = BA->getBasicBlock(); std::vector <DominatorTree::UpdateType> Updates; - if (DDT) + if (DTU) Updates.reserve(IBI->getNumDestinations() - 1); // Insert the new branch. @@ -310,7 +313,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, BasicBlock *ParentBB = IBI->getParent(); BasicBlock *DestBB = IBI->getDestination(i); DestBB->removePredecessor(ParentBB); - if (DDT) + if (DTU) Updates.push_back({DominatorTree::Delete, ParentBB, DestBB}); } } @@ -327,8 +330,8 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, new UnreachableInst(BB->getContext(), BB); } - if (DDT) - DDT->applyUpdates(Updates); + if (DTU) + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); return true; } } @@ -352,7 +355,7 @@ bool llvm::isInstructionTriviallyDead(Instruction *I, bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, const TargetLibraryInfo *TLI) { - if (isa<TerminatorInst>(I)) + if (I->isTerminator()) return false; // We don't want the landingpad-like instructions removed by anything this @@ -390,8 +393,7 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, return true; // Lifetime intrinsics are dead when their right-hand is undef. 
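// Illustrative sketch (not from the patch) of the new calling convention for
// these utilities: a DomTreeUpdater, here in lazy mode, replaces the old
// DeferredDominance object. F, DT and TLI are assumed to be provided by the
// enclosing pass.
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
bool Changed = false;
for (BasicBlock &BB : F)
  Changed |= ConstantFoldTerminator(&BB, /*DeleteDeadConditions=*/true, TLI,
                                    &DTU);
DTU.flush(); // Apply the queued dominator tree updates before using DT again.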
- if (II->getIntrinsicID() == Intrinsic::lifetime_start || - II->getIntrinsicID() == Intrinsic::lifetime_end) + if (II->isLifetimeStartOrEnd()) return isa<UndefValue>(II->getArgOperand(1)); // Assumptions are dead if their condition is trivially true. Guards on @@ -425,22 +427,22 @@ bool llvm::wouldInstructionBeTriviallyDead(Instruction *I, /// trivially dead instruction, delete it. If that makes any of its operands /// trivially dead, delete them too, recursively. Return true if any /// instructions were deleted. -bool -llvm::RecursivelyDeleteTriviallyDeadInstructions(Value *V, - const TargetLibraryInfo *TLI) { +bool llvm::RecursivelyDeleteTriviallyDeadInstructions( + Value *V, const TargetLibraryInfo *TLI, MemorySSAUpdater *MSSAU) { Instruction *I = dyn_cast<Instruction>(V); if (!I || !I->use_empty() || !isInstructionTriviallyDead(I, TLI)) return false; SmallVector<Instruction*, 16> DeadInsts; DeadInsts.push_back(I); - RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, TLI); + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts, TLI, MSSAU); return true; } void llvm::RecursivelyDeleteTriviallyDeadInstructions( - SmallVectorImpl<Instruction *> &DeadInsts, const TargetLibraryInfo *TLI) { + SmallVectorImpl<Instruction *> &DeadInsts, const TargetLibraryInfo *TLI, + MemorySSAUpdater *MSSAU) { // Process the dead instruction list until empty. while (!DeadInsts.empty()) { Instruction &I = *DeadInsts.pop_back_val(); @@ -467,11 +469,24 @@ void llvm::RecursivelyDeleteTriviallyDeadInstructions( if (isInstructionTriviallyDead(OpI, TLI)) DeadInsts.push_back(OpI); } + if (MSSAU) + MSSAU->removeMemoryAccess(&I); I.eraseFromParent(); } } +bool llvm::replaceDbgUsesWithUndef(Instruction *I) { + SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, I); + for (auto *DII : DbgUsers) { + Value *Undef = UndefValue::get(I->getType()); + DII->setOperand(0, MetadataAsValue::get(DII->getContext(), + ValueAsMetadata::get(Undef))); + } + return !DbgUsers.empty(); +} + /// areAllUsesEqual - Check whether the uses of a value are all the same. /// This is similar to Instruction::hasOneUse() except this will also return /// true when there are no uses or multiple uses that all refer to the same @@ -626,7 +641,7 @@ bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, /// .. and delete the predecessor corresponding to the '1', this will attempt to /// recursively fold the and to 0. void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred, - DeferredDominance *DDT) { + DomTreeUpdater *DTU) { // This only adjusts blocks with PHI nodes. if (!isa<PHINode>(BB->begin())) return; @@ -649,17 +664,16 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred, // of the block. if (PhiIt != OldPhiIt) PhiIt = &BB->front(); } - if (DDT) - DDT->deleteEdge(Pred, BB); + if (DTU) + DTU->deleteEdgeRelaxed(Pred, BB); } /// MergeBasicBlockIntoOnlyPred - DestBB is a block with one predecessor and its -/// predecessor is known to have one successor (DestBB!). Eliminate the edge +/// predecessor is known to have one successor (DestBB!). Eliminate the edge /// between them, moving the instructions in the predecessor into DestBB and /// deleting the predecessor block. -void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT, - DeferredDominance *DDT) { - assert(!(DT && DDT) && "Cannot call with both DT and DDT."); +void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, + DomTreeUpdater *DTU) { // If BB has single-entry PHI nodes, fold them. 
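// Illustrative sketch (not from the patch) of how the replaceDbgUsesWithUndef
// helper added above is meant to be used: when a dead instruction I is about
// to be erased and its value cannot be rewritten in terms of its operands,
// point the debug users at undef so the variable reads as optimized-out
// rather than stale. I is an assumed dead instruction.
if (!salvageDebugInfo(*I))
  replaceDbgUsesWithUndef(I);
I->eraseFromParent();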
while (PHINode *PN = dyn_cast<PHINode>(DestBB->begin())) { @@ -677,11 +691,11 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT, if (PredBB == &DestBB->getParent()->getEntryBlock()) ReplaceEntryBB = true; - // Deferred DT update: Collect all the edges that enter PredBB. These - // dominator edges will be redirected to DestBB. - std::vector <DominatorTree::UpdateType> Updates; - if (DDT && !ReplaceEntryBB) { - Updates.reserve(1 + (2 * pred_size(PredBB))); + // DTU updates: Collect all the edges that enter + // PredBB. These dominator edges will be redirected to DestBB. + SmallVector<DominatorTree::UpdateType, 32> Updates; + + if (DTU) { Updates.push_back({DominatorTree::Delete, PredBB, DestBB}); for (auto I = pred_begin(PredBB), E = pred_end(PredBB); I != E; ++I) { Updates.push_back({DominatorTree::Delete, *I, PredBB}); @@ -708,33 +722,32 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT, // Splice all the instructions from PredBB to DestBB. PredBB->getTerminator()->eraseFromParent(); DestBB->getInstList().splice(DestBB->begin(), PredBB->getInstList()); + new UnreachableInst(PredBB->getContext(), PredBB); // If the PredBB is the entry block of the function, move DestBB up to // become the entry block after we erase PredBB. if (ReplaceEntryBB) DestBB->moveAfter(PredBB); - if (DT) { - // For some irreducible CFG we end up having forward-unreachable blocks - // so check if getNode returns a valid node before updating the domtree. - if (DomTreeNode *DTN = DT->getNode(PredBB)) { - BasicBlock *PredBBIDom = DTN->getIDom()->getBlock(); - DT->changeImmediateDominator(DestBB, PredBBIDom); - DT->eraseNode(PredBB); + if (DTU) { + assert(PredBB->getInstList().size() == 1 && + isa<UnreachableInst>(PredBB->getTerminator()) && + "The successor list of PredBB isn't empty before " + "applying corresponding DTU updates."); + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->deleteBB(PredBB); + // Recalculation of DomTree is needed when updating a forward DomTree and + // the Entry BB is replaced. + if (ReplaceEntryBB && DTU->hasDomTree()) { + // The entry block was removed and there is no external interface for + // the dominator tree to be notified of this change. In this corner-case + // we recalculate the entire tree. + DTU->recalculate(*(DestBB->getParent())); } } - if (DDT) { - DDT->deleteBB(PredBB); // Deferred deletion of BB. - if (ReplaceEntryBB) - // The entry block was removed and there is no external interface for the - // dominator tree to be notified of this change. In this corner-case we - // recalculate the entire tree. - DDT->recalculate(*(DestBB->getParent())); - else - DDT->applyUpdates(Updates); - } else { - PredBB->eraseFromParent(); // Nuke BB. + else { + PredBB->eraseFromParent(); // Nuke BB if DTU is nullptr. } } @@ -945,7 +958,7 @@ static void redirectValuesFromPredecessorsToPhi(BasicBlock *BB, /// eliminate BB by rewriting all the predecessors to branch to the successor /// block and return true. If we can't transform, return false. 
bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, - DeferredDominance *DDT) { + DomTreeUpdater *DTU) { assert(BB != &BB->getParent()->getEntryBlock() && "TryToSimplifyUncondBranchFromEmptyBlock called on entry block!"); @@ -986,9 +999,8 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, LLVM_DEBUG(dbgs() << "Killing Trivial BB: \n" << *BB); - std::vector<DominatorTree::UpdateType> Updates; - if (DDT) { - Updates.reserve(1 + (2 * pred_size(BB))); + SmallVector<DominatorTree::UpdateType, 32> Updates; + if (DTU) { Updates.push_back({DominatorTree::Delete, BB, Succ}); // All predecessors of BB will be moved to Succ. for (auto I = pred_begin(BB), E = pred_end(BB); I != E; ++I) { @@ -1044,9 +1056,16 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB, BB->replaceAllUsesWith(Succ); if (!Succ->hasName()) Succ->takeName(BB); - if (DDT) { - DDT->deleteBB(BB); // Deferred deletion of the old basic block. - DDT->applyUpdates(Updates); + // Clear the successor list of BB to match updates applying to DTU later. + if (BB->getTerminator()) + BB->getInstList().pop_back(); + new UnreachableInst(BB->getContext(), BB); + assert(succ_empty(BB) && "The successor list of BB isn't empty before " + "applying corresponding DTU updates."); + + if (DTU) { + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + DTU->deleteBB(BB); } else { BB->eraseFromParent(); // Delete the old basic block. } @@ -1237,7 +1256,7 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, /// alloc size of the value when doing the comparison. E.g. an i1 value will be /// identified as covering an n-bit fragment, if the store size of i1 is at /// least n bits. -static bool valueCoversEntireFragment(Type *ValTy, DbgInfoIntrinsic *DII) { +static bool valueCoversEntireFragment(Type *ValTy, DbgVariableIntrinsic *DII) { const DataLayout &DL = DII->getModule()->getDataLayout(); uint64_t ValueSize = DL.getTypeAllocSizeInBits(ValTy); if (auto FragmentSize = DII->getFragmentSizeInBits()) @@ -1255,7 +1274,7 @@ static bool valueCoversEntireFragment(Type *ValTy, DbgInfoIntrinsic *DII) { /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value /// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. -void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, +void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, StoreInst *SI, DIBuilder &Builder) { assert(DII->isAddressOfVariable()); auto *DIVar = DII->getVariable(); @@ -1278,33 +1297,6 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, return; } - // If an argument is zero extended then use argument directly. The ZExt - // may be zapped by an optimization pass in future. - Argument *ExtendedArg = nullptr; - if (ZExtInst *ZExt = dyn_cast<ZExtInst>(SI->getOperand(0))) - ExtendedArg = dyn_cast<Argument>(ZExt->getOperand(0)); - if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) - ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); - if (ExtendedArg) { - // If this DII was already describing only a fragment of a variable, ensure - // that fragment is appropriately narrowed here. - // But if a fragment wasn't used, describe the value as the original - // argument (rather than the zext or sext) so that it remains described even - // if the sext/zext is optimized away. This widens the variable description, - // leaving it up to the consumer to know how the smaller value may be - // represented in a larger register. 
- if (auto Fragment = DIExpr->getFragmentInfo()) { - unsigned FragmentOffset = Fragment->OffsetInBits; - SmallVector<uint64_t, 3> Ops(DIExpr->elements_begin(), - DIExpr->elements_end() - 3); - Ops.push_back(dwarf::DW_OP_LLVM_fragment); - Ops.push_back(FragmentOffset); - const DataLayout &DL = DII->getModule()->getDataLayout(); - Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); - DIExpr = Builder.createExpression(Ops); - } - DV = ExtendedArg; - } if (!LdStHasDebugValue(DIVar, DIExpr, SI)) Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, DII->getDebugLoc(), SI); @@ -1312,7 +1304,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value /// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. -void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, +void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, LoadInst *LI, DIBuilder &Builder) { auto *DIVar = DII->getVariable(); auto *DIExpr = DII->getExpression(); @@ -1341,7 +1333,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, /// Inserts a llvm.dbg.value intrinsic after a phi that has an associated /// llvm.dbg.declare or llvm.dbg.addr intrinsic. -void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, +void llvm::ConvertDebugDeclareToDebugValue(DbgVariableIntrinsic *DII, PHINode *APN, DIBuilder &Builder) { auto *DIVar = DII->getVariable(); auto *DIExpr = DII->getExpression(); @@ -1443,7 +1435,7 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, // Map existing PHI nodes to their dbg.values. ValueToValueMapTy DbgValueMap; for (auto &I : *BB) { - if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) { + if (auto DbgII = dyn_cast<DbgVariableIntrinsic>(&I)) { if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation())) DbgValueMap.insert({Loc, DbgII}); } @@ -1464,7 +1456,7 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, for (auto VI : PHI->operand_values()) { auto V = DbgValueMap.find(VI); if (V != DbgValueMap.end()) { - auto *DbgII = cast<DbgInfoIntrinsic>(V->second); + auto *DbgII = cast<DbgVariableIntrinsic>(V->second); Instruction *NewDbgII = DbgII->clone(); NewDbgII->setOperand(0, PhiMAV); auto InsertionPt = Parent->getFirstInsertionPt(); @@ -1478,7 +1470,7 @@ void llvm::insertDebugValuesForPHIs(BasicBlock *BB, /// Finds all intrinsics declaring local variables as living in the memory that /// 'V' points to. This may include a mix of dbg.declare and /// dbg.addr intrinsics. -TinyPtrVector<DbgInfoIntrinsic *> llvm::FindDbgAddrUses(Value *V) { +TinyPtrVector<DbgVariableIntrinsic *> llvm::FindDbgAddrUses(Value *V) { // This function is hot. Check whether the value has any metadata to avoid a // DenseMap lookup. if (!V->isUsedByMetadata()) @@ -1490,9 +1482,9 @@ TinyPtrVector<DbgInfoIntrinsic *> llvm::FindDbgAddrUses(Value *V) { if (!MDV) return {}; - TinyPtrVector<DbgInfoIntrinsic *> Declares; + TinyPtrVector<DbgVariableIntrinsic *> Declares; for (User *U : MDV->users()) { - if (auto *DII = dyn_cast<DbgInfoIntrinsic>(U)) + if (auto *DII = dyn_cast<DbgVariableIntrinsic>(U)) if (DII->isAddressOfVariable()) Declares.push_back(DII); } @@ -1512,7 +1504,7 @@ void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) { DbgValues.push_back(DVI); } -void llvm::findDbgUsers(SmallVectorImpl<DbgInfoIntrinsic *> &DbgUsers, +void llvm::findDbgUsers(SmallVectorImpl<DbgVariableIntrinsic *> &DbgUsers, Value *V) { // This function is hot. 
Check whether the value has any metadata to avoid a // DenseMap lookup. @@ -1521,7 +1513,7 @@ void llvm::findDbgUsers(SmallVectorImpl<DbgInfoIntrinsic *> &DbgUsers, if (auto *L = LocalAsMetadata::getIfExists(V)) if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) for (User *U : MDV->users()) - if (DbgInfoIntrinsic *DII = dyn_cast<DbgInfoIntrinsic>(U)) + if (DbgVariableIntrinsic *DII = dyn_cast<DbgVariableIntrinsic>(U)) DbgUsers.push_back(DII); } @@ -1529,7 +1521,7 @@ bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, Instruction *InsertBefore, DIBuilder &Builder, bool DerefBefore, int Offset, bool DerefAfter) { auto DbgAddrs = FindDbgAddrUses(Address); - for (DbgInfoIntrinsic *DII : DbgAddrs) { + for (DbgVariableIntrinsic *DII : DbgAddrs) { DebugLoc Loc = DII->getDebugLoc(); auto *DIVar = DII->getVariable(); auto *DIExpr = DII->getExpression(); @@ -1597,7 +1589,7 @@ static MetadataAsValue *wrapValueInMetadata(LLVMContext &C, Value *V) { } bool llvm::salvageDebugInfo(Instruction &I) { - SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; findDbgUsers(DbgUsers, &I); if (DbgUsers.empty()) return false; @@ -1607,7 +1599,7 @@ bool llvm::salvageDebugInfo(Instruction &I) { auto &Ctx = I.getContext(); auto wrapMD = [&](Value *V) { return wrapValueInMetadata(Ctx, V); }; - auto doSalvage = [&](DbgInfoIntrinsic *DII, SmallVectorImpl<uint64_t> &Ops) { + auto doSalvage = [&](DbgVariableIntrinsic *DII, SmallVectorImpl<uint64_t> &Ops) { auto *DIExpr = DII->getExpression(); if (!Ops.empty()) { // Do not add DW_OP_stack_value for DbgDeclare and DbgAddr, because they @@ -1621,13 +1613,13 @@ bool llvm::salvageDebugInfo(Instruction &I) { LLVM_DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); }; - auto applyOffset = [&](DbgInfoIntrinsic *DII, uint64_t Offset) { + auto applyOffset = [&](DbgVariableIntrinsic *DII, uint64_t Offset) { SmallVector<uint64_t, 8> Ops; DIExpression::appendOffset(Ops, Offset); doSalvage(DII, Ops); }; - auto applyOps = [&](DbgInfoIntrinsic *DII, + auto applyOps = [&](DbgVariableIntrinsic *DII, std::initializer_list<uint64_t> Opcodes) { SmallVector<uint64_t, 8> Ops(Opcodes); doSalvage(DII, Ops); @@ -1726,16 +1718,16 @@ using DbgValReplacement = Optional<DIExpression *>; /// changes are made. static bool rewriteDebugUsers( Instruction &From, Value &To, Instruction &DomPoint, DominatorTree &DT, - function_ref<DbgValReplacement(DbgInfoIntrinsic &DII)> RewriteExpr) { + function_ref<DbgValReplacement(DbgVariableIntrinsic &DII)> RewriteExpr) { // Find debug users of From. - SmallVector<DbgInfoIntrinsic *, 1> Users; + SmallVector<DbgVariableIntrinsic *, 1> Users; findDbgUsers(Users, &From); if (Users.empty()) return false; // Prevent use-before-def of To. bool Changed = false; - SmallPtrSet<DbgInfoIntrinsic *, 1> DeleteOrSalvage; + SmallPtrSet<DbgVariableIntrinsic *, 1> DeleteOrSalvage; if (isa<Instruction>(&To)) { bool DomPointAfterFrom = From.getNextNonDebugInstruction() == &DomPoint; @@ -1824,7 +1816,7 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, Type *FromTy = From.getType(); Type *ToTy = To.getType(); - auto Identity = [&](DbgInfoIntrinsic &DII) -> DbgValReplacement { + auto Identity = [&](DbgVariableIntrinsic &DII) -> DbgValReplacement { return DII.getExpression(); }; @@ -1848,7 +1840,7 @@ bool llvm::replaceAllDbgUsesWith(Instruction &From, Value &To, // The width of the result has shrunk. Use sign/zero extension to describe // the source variable's high bits. 
- auto SignOrZeroExt = [&](DbgInfoIntrinsic &DII) -> DbgValReplacement { + auto SignOrZeroExt = [&](DbgVariableIntrinsic &DII) -> DbgValReplacement { DILocalVariable *Var = DII.getVariable(); // Without knowing signedness, sign/zero extension isn't possible. @@ -1902,17 +1894,17 @@ unsigned llvm::removeAllNonTerminatorAndEHPadInstructions(BasicBlock *BB) { } unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, - bool PreserveLCSSA, DeferredDominance *DDT) { + bool PreserveLCSSA, DomTreeUpdater *DTU) { BasicBlock *BB = I->getParent(); std::vector <DominatorTree::UpdateType> Updates; // Loop over all of the successors, removing BB's entry from any PHI // nodes. - if (DDT) + if (DTU) Updates.reserve(BB->getTerminator()->getNumSuccessors()); for (BasicBlock *Successor : successors(BB)) { Successor->removePredecessor(BB, PreserveLCSSA); - if (DDT) + if (DTU) Updates.push_back({DominatorTree::Delete, BB, Successor}); } // Insert a call to llvm.trap right before this. This turns the undefined @@ -1923,7 +1915,8 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, CallInst *CallTrap = CallInst::Create(TrapFn, "", I); CallTrap->setDebugLoc(I->getDebugLoc()); } - new UnreachableInst(I->getContext(), I); + auto *UI = new UnreachableInst(I->getContext(), I); + UI->setDebugLoc(I->getDebugLoc()); // All instructions after this are dead. unsigned NumInstrsRemoved = 0; @@ -1934,13 +1927,13 @@ unsigned llvm::changeToUnreachable(Instruction *I, bool UseLLVMTrap, BB->getInstList().erase(BBI++); ++NumInstrsRemoved; } - if (DDT) - DDT->applyUpdates(Updates); + if (DTU) + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); return NumInstrsRemoved; } /// changeToCall - Convert the specified invoke into a normal call. -static void changeToCall(InvokeInst *II, DeferredDominance *DDT = nullptr) { +static void changeToCall(InvokeInst *II, DomTreeUpdater *DTU = nullptr) { SmallVector<Value*, 8> Args(II->arg_begin(), II->arg_end()); SmallVector<OperandBundleDef, 1> OpBundles; II->getOperandBundlesAsDefs(OpBundles); @@ -1950,6 +1943,7 @@ static void changeToCall(InvokeInst *II, DeferredDominance *DDT = nullptr) { NewCall->setCallingConv(II->getCallingConv()); NewCall->setAttributes(II->getAttributes()); NewCall->setDebugLoc(II->getDebugLoc()); + NewCall->copyMetadata(*II); II->replaceAllUsesWith(NewCall); // Follow the call by a branch to the normal destination. @@ -1961,8 +1955,8 @@ static void changeToCall(InvokeInst *II, DeferredDominance *DDT = nullptr) { BasicBlock *UnwindDestBB = II->getUnwindDest(); UnwindDestBB->removePredecessor(BB); II->eraseFromParent(); - if (DDT) - DDT->deleteEdge(BB, UnwindDestBB); + if (DTU) + DTU->deleteEdgeRelaxed(BB, UnwindDestBB); } BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, @@ -2003,8 +1997,8 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, } static bool markAliveBlocks(Function &F, - SmallPtrSetImpl<BasicBlock*> &Reachable, - DeferredDominance *DDT = nullptr) { + SmallPtrSetImpl<BasicBlock *> &Reachable, + DomTreeUpdater *DTU = nullptr) { SmallVector<BasicBlock*, 128> Worklist; BasicBlock *BB = &F.front(); Worklist.push_back(BB); @@ -2029,7 +2023,7 @@ static bool markAliveBlocks(Function &F, if (IntrinsicID == Intrinsic::assume) { if (match(CI->getArgOperand(0), m_CombineOr(m_Zero(), m_Undef()))) { // Don't insert a call to llvm.trap right before the unreachable. 
- changeToUnreachable(CI, false, false, DDT); + changeToUnreachable(CI, false, false, DTU); Changed = true; break; } @@ -2046,7 +2040,7 @@ static bool markAliveBlocks(Function &F, if (match(CI->getArgOperand(0), m_Zero())) if (!isa<UnreachableInst>(CI->getNextNode())) { changeToUnreachable(CI->getNextNode(), /*UseLLVMTrap=*/false, - false, DDT); + false, DTU); Changed = true; break; } @@ -2054,7 +2048,7 @@ static bool markAliveBlocks(Function &F, } else if ((isa<ConstantPointerNull>(Callee) && !NullPointerIsDefined(CI->getFunction())) || isa<UndefValue>(Callee)) { - changeToUnreachable(CI, /*UseLLVMTrap=*/false, false, DDT); + changeToUnreachable(CI, /*UseLLVMTrap=*/false, false, DTU); Changed = true; break; } @@ -2064,7 +2058,7 @@ static bool markAliveBlocks(Function &F, // though. if (!isa<UnreachableInst>(CI->getNextNode())) { // Don't insert a call to llvm.trap right before the unreachable. - changeToUnreachable(CI->getNextNode(), false, false, DDT); + changeToUnreachable(CI->getNextNode(), false, false, DTU); Changed = true; } break; @@ -2083,21 +2077,21 @@ static bool markAliveBlocks(Function &F, (isa<ConstantPointerNull>(Ptr) && !NullPointerIsDefined(SI->getFunction(), SI->getPointerAddressSpace()))) { - changeToUnreachable(SI, true, false, DDT); + changeToUnreachable(SI, true, false, DTU); Changed = true; break; } } } - TerminatorInst *Terminator = BB->getTerminator(); + Instruction *Terminator = BB->getTerminator(); if (auto *II = dyn_cast<InvokeInst>(Terminator)) { // Turn invokes that call 'nounwind' functions into ordinary calls. Value *Callee = II->getCalledValue(); if ((isa<ConstantPointerNull>(Callee) && !NullPointerIsDefined(BB->getParent())) || isa<UndefValue>(Callee)) { - changeToUnreachable(II, true, false, DDT); + changeToUnreachable(II, true, false, DTU); Changed = true; } else if (II->doesNotThrow() && canSimplifyInvokeNoUnwind(&F)) { if (II->use_empty() && II->onlyReadsMemory()) { @@ -2107,10 +2101,10 @@ static bool markAliveBlocks(Function &F, BranchInst::Create(NormalDestBB, II); UnwindDestBB->removePredecessor(II->getParent()); II->eraseFromParent(); - if (DDT) - DDT->deleteEdge(BB, UnwindDestBB); + if (DTU) + DTU->deleteEdgeRelaxed(BB, UnwindDestBB); } else - changeToCall(II, DDT); + changeToCall(II, DTU); Changed = true; } } else if (auto *CatchSwitch = dyn_cast<CatchSwitchInst>(Terminator)) { @@ -2156,7 +2150,7 @@ static bool markAliveBlocks(Function &F, } } - Changed |= ConstantFoldTerminator(BB, true, nullptr, DDT); + Changed |= ConstantFoldTerminator(BB, true, nullptr, DTU); for (BasicBlock *Successor : successors(BB)) if (Reachable.insert(Successor).second) Worklist.push_back(Successor); @@ -2164,15 +2158,15 @@ static bool markAliveBlocks(Function &F, return Changed; } -void llvm::removeUnwindEdge(BasicBlock *BB, DeferredDominance *DDT) { - TerminatorInst *TI = BB->getTerminator(); +void llvm::removeUnwindEdge(BasicBlock *BB, DomTreeUpdater *DTU) { + Instruction *TI = BB->getTerminator(); if (auto *II = dyn_cast<InvokeInst>(TI)) { - changeToCall(II, DDT); + changeToCall(II, DTU); return; } - TerminatorInst *NewTI; + Instruction *NewTI; BasicBlock *UnwindDest; if (auto *CRI = dyn_cast<CleanupReturnInst>(TI)) { @@ -2196,8 +2190,8 @@ void llvm::removeUnwindEdge(BasicBlock *BB, DeferredDominance *DDT) { UnwindDest->removePredecessor(BB); TI->replaceAllUsesWith(NewTI); TI->eraseFromParent(); - if (DDT) - DDT->deleteEdge(BB, UnwindDest); + if (DTU) + DTU->deleteEdgeRelaxed(BB, UnwindDest); } /// removeUnreachableBlocks - Remove blocks that are not reachable, even 
@@ -2205,9 +2199,10 @@ void llvm::removeUnwindEdge(BasicBlock *BB, DeferredDominance *DDT) { /// otherwise. If `LVI` is passed, this function preserves LazyValueInfo /// after modifying the CFG. bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, - DeferredDominance *DDT) { + DomTreeUpdater *DTU, + MemorySSAUpdater *MSSAU) { SmallPtrSet<BasicBlock*, 16> Reachable; - bool Changed = markAliveBlocks(F, Reachable, DDT); + bool Changed = markAliveBlocks(F, Reachable, DTU); // If there are unreachable blocks in the CFG... if (Reachable.size() == F.size()) @@ -2216,45 +2211,68 @@ bool llvm::removeUnreachableBlocks(Function &F, LazyValueInfo *LVI, assert(Reachable.size() < F.size()); NumRemoved += F.size()-Reachable.size(); - // Loop over all of the basic blocks that are not reachable, dropping all of - // their internal references. Update DDT and LVI if available. - std::vector <DominatorTree::UpdateType> Updates; + SmallPtrSet<BasicBlock *, 16> DeadBlockSet; for (Function::iterator I = ++F.begin(), E = F.end(); I != E; ++I) { auto *BB = &*I; if (Reachable.count(BB)) continue; + DeadBlockSet.insert(BB); + } + + if (MSSAU) + MSSAU->removeBlocks(DeadBlockSet); + + // Loop over all of the basic blocks that are not reachable, dropping all of + // their internal references. Update DTU and LVI if available. + std::vector<DominatorTree::UpdateType> Updates; + for (auto *BB : DeadBlockSet) { for (BasicBlock *Successor : successors(BB)) { - if (Reachable.count(Successor)) + if (!DeadBlockSet.count(Successor)) Successor->removePredecessor(BB); - if (DDT) + if (DTU) Updates.push_back({DominatorTree::Delete, BB, Successor}); } if (LVI) LVI->eraseBlock(BB); BB->dropAllReferences(); } - for (Function::iterator I = ++F.begin(); I != F.end();) { auto *BB = &*I; if (Reachable.count(BB)) { ++I; continue; } - if (DDT) { - DDT->deleteBB(BB); // deferred deletion of BB. + if (DTU) { + // Remove the terminator of BB to clear the successor list of BB. + if (BB->getTerminator()) + BB->getInstList().pop_back(); + new UnreachableInst(BB->getContext(), BB); + assert(succ_empty(BB) && "The successor list of BB isn't empty before " + "applying corresponding DTU updates."); ++I; } else { I = F.getBasicBlockList().erase(I); } } - if (DDT) - DDT->applyUpdates(Updates); + if (DTU) { + DTU->applyUpdates(Updates, /*ForceRemoveDuplicates*/ true); + bool Deleted = false; + for (auto *BB : DeadBlockSet) { + if (DTU->isBBPendingDeletion(BB)) + --NumRemoved; + else + Deleted = true; + DTU->deleteBB(BB); + } + if (!Deleted) + return false; + } return true; } void llvm::combineMetadata(Instruction *K, const Instruction *J, - ArrayRef<unsigned> KnownIDs) { + ArrayRef<unsigned> KnownIDs, bool DoesKMove) { SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; K->dropUnknownNonDebugMetadata(KnownIDs); K->getAllMetadataOtherThanDebugLoc(Metadata); @@ -2279,8 +2297,20 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, case LLVMContext::MD_mem_parallel_loop_access: K->setMetadata(Kind, MDNode::intersect(JMD, KMD)); break; + case LLVMContext::MD_access_group: + K->setMetadata(LLVMContext::MD_access_group, + intersectAccessGroups(K, J)); + break; case LLVMContext::MD_range: - K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD)); + + // If K does move, use most generic range. Otherwise keep the range of + // K. + if (DoesKMove) + // FIXME: If K does move, we should drop the range info and nonnull. 
+ // Currently this function is used with DoesKMove in passes + // doing hoisting/sinking and the current behavior of using the + // most generic range is correct in those cases. + K->setMetadata(Kind, MDNode::getMostGenericRange(JMD, KMD)); break; case LLVMContext::MD_fpmath: K->setMetadata(Kind, MDNode::getMostGenericFPMath(JMD, KMD)); @@ -2290,8 +2320,9 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, K->setMetadata(Kind, JMD); break; case LLVMContext::MD_nonnull: - // Only set the !nonnull if it is present in both instructions. - K->setMetadata(Kind, JMD); + // If K does move, keep nonull if it is present in both instructions. + if (DoesKMove) + K->setMetadata(Kind, JMD); break; case LLVMContext::MD_invariant_group: // Preserve !invariant.group in K. @@ -2318,15 +2349,49 @@ void llvm::combineMetadata(Instruction *K, const Instruction *J, K->setMetadata(LLVMContext::MD_invariant_group, JMD); } -void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J) { +void llvm::combineMetadataForCSE(Instruction *K, const Instruction *J, + bool KDominatesJ) { unsigned KnownIDs[] = { LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, LLVMContext::MD_noalias, LLVMContext::MD_range, LLVMContext::MD_invariant_load, LLVMContext::MD_nonnull, LLVMContext::MD_invariant_group, LLVMContext::MD_align, LLVMContext::MD_dereferenceable, - LLVMContext::MD_dereferenceable_or_null}; - combineMetadata(K, J, KnownIDs); + LLVMContext::MD_dereferenceable_or_null, + LLVMContext::MD_access_group}; + combineMetadata(K, J, KnownIDs, KDominatesJ); +} + +void llvm::patchReplacementInstruction(Instruction *I, Value *Repl) { + auto *ReplInst = dyn_cast<Instruction>(Repl); + if (!ReplInst) + return; + + // Patch the replacement so that it is not more restrictive than the value + // being replaced. + // Note that if 'I' is a load being replaced by some operation, + // for example, by an arithmetic operation, then andIRFlags() + // would just erase all math flags from the original arithmetic + // operation, which is clearly not wanted and not needed. + if (!isa<LoadInst>(I)) + ReplInst->andIRFlags(I); + + // FIXME: If both the original and replacement value are part of the + // same control-flow region (meaning that the execution of one + // guarantees the execution of the other), then we can combine the + // noalias scopes here and do better than the general conservative + // answer used in combineMetadata(). + + // In general, GVN unifies expressions over different control-flow + // regions, and so we need a conservative combination of the noalias + // scopes. 
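// Illustrative sketch (not from the patch) of the GVN-style call sequence
// patchReplacementInstruction is meant for: patch the surviving instruction
// before redirecting uses so that Repl does not advertise stronger metadata
// or IR flags than I did. I and Repl are assumed to be equivalent values,
// with Repl available at I's program point.
patchReplacementInstruction(I, Repl);
I->replaceAllUsesWith(Repl);
I->eraseFromParent();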
+ static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group, LLVMContext::MD_nonnull, + LLVMContext::MD_access_group}; + combineMetadata(ReplInst, I, KnownIDs, false); } template <typename RootType, typename DominatesFn> @@ -2454,6 +2519,54 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, } } +void llvm::dropDebugUsers(Instruction &I) { + SmallVector<DbgVariableIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, &I); + for (auto *DII : DbgUsers) + DII->eraseFromParent(); +} + +void llvm::hoistAllInstructionsInto(BasicBlock *DomBlock, Instruction *InsertPt, + BasicBlock *BB) { + // Since we are moving the instructions out of its basic block, we do not + // retain their original debug locations (DILocations) and debug intrinsic + // instructions (dbg.values). + // + // Doing so would degrade the debugging experience and adversely affect the + // accuracy of profiling information. + // + // Currently, when hoisting the instructions, we take the following actions: + // - Remove their dbg.values. + // - Set their debug locations to the values from the insertion point. + // + // As per PR39141 (comment #8), the more fundamental reason why the dbg.values + // need to be deleted, is because there will not be any instructions with a + // DILocation in either branch left after performing the transformation. We + // can only insert a dbg.value after the two branches are joined again. + // + // See PR38762, PR39243 for more details. + // + // TODO: Extend llvm.dbg.value to take more than one SSA Value (PR39141) to + // encode predicated DIExpressions that yield different results on different + // code paths. + for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) { + Instruction *I = &*II; + I->dropUnknownNonDebugMetadata(); + if (I->isUsedByMetadata()) + dropDebugUsers(*I); + if (isa<DbgVariableIntrinsic>(I)) { + // Remove DbgInfo Intrinsics. + II = I->eraseFromParent(); + continue; + } + I->setDebugLoc(InsertPt->getDebugLoc()); + ++II; + } + DomBlock->getInstList().splice(InsertPt->getIterator(), BB->getInstList(), + BB->begin(), + BB->getTerminator()->getIterator()); +} + namespace { /// A potential constituent of a bitreverse or bswap expression. 
See diff --git a/lib/Transforms/Utils/LoopRotationUtils.cpp b/lib/Transforms/Utils/LoopRotationUtils.cpp index 6e92e679f999..41f14a834617 100644 --- a/lib/Transforms/Utils/LoopRotationUtils.cpp +++ b/lib/Transforms/Utils/LoopRotationUtils.cpp @@ -20,13 +20,15 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" @@ -35,6 +37,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/ValueMapper.h" @@ -53,6 +56,7 @@ class LoopRotate { AssumptionCache *AC; DominatorTree *DT; ScalarEvolution *SE; + MemorySSAUpdater *MSSAU; const SimplifyQuery &SQ; bool RotationOnly; bool IsUtilMode; @@ -60,10 +64,11 @@ class LoopRotate { public: LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC, - DominatorTree *DT, ScalarEvolution *SE, const SimplifyQuery &SQ, - bool RotationOnly, bool IsUtilMode) + DominatorTree *DT, ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + const SimplifyQuery &SQ, bool RotationOnly, bool IsUtilMode) : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE), - SQ(SQ), RotationOnly(RotationOnly), IsUtilMode(IsUtilMode) {} + MSSAU(MSSAU), SQ(SQ), RotationOnly(RotationOnly), + IsUtilMode(IsUtilMode) {} bool processLoop(Loop *L); private: @@ -268,6 +273,8 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { SE->forgetTopmostLoop(L); LLVM_DEBUG(dbgs() << "LoopRotation: rotating "; L->dump()); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); // Find new Loop header. NewHeader is a Header's one and only successor // that is inside loop. Header's other successor is outside the @@ -298,18 +305,18 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // For the rest of the instructions, either hoist to the OrigPreheader if // possible or create a clone in the OldPreHeader if not. - TerminatorInst *LoopEntryBranch = OrigPreheader->getTerminator(); + Instruction *LoopEntryBranch = OrigPreheader->getTerminator(); // Record all debug intrinsics preceding LoopEntryBranch to avoid duplication. 
using DbgIntrinsicHash = std::pair<std::pair<Value *, DILocalVariable *>, DIExpression *>; - auto makeHash = [](DbgInfoIntrinsic *D) -> DbgIntrinsicHash { + auto makeHash = [](DbgVariableIntrinsic *D) -> DbgIntrinsicHash { return {{D->getVariableLocation(), D->getVariable()}, D->getExpression()}; }; SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics; for (auto I = std::next(OrigPreheader->rbegin()), E = OrigPreheader->rend(); I != E; ++I) { - if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&*I)) + if (auto *DII = dyn_cast<DbgVariableIntrinsic>(&*I)) DbgIntrinsics.insert(makeHash(DII)); else break; @@ -325,7 +332,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // something that might trap, but isn't safe to hoist something that reads // memory (without proving that the loop doesn't write). if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() && - !Inst->mayWriteToMemory() && !isa<TerminatorInst>(Inst) && + !Inst->mayWriteToMemory() && !Inst->isTerminator() && !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) { Inst->moveBefore(LoopEntryBranch); continue; @@ -339,7 +346,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); // Avoid inserting the same intrinsic twice. - if (auto *DII = dyn_cast<DbgInfoIntrinsic>(C)) + if (auto *DII = dyn_cast<DbgVariableIntrinsic>(C)) if (DbgIntrinsics.count(makeHash(DII))) { C->deleteValue(); continue; @@ -374,8 +381,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Along with all the other instructions, we just cloned OrigHeader's // terminator into OrigPreHeader. Fix up the PHI nodes in each of OrigHeader's // successors by duplicating their incoming values for OrigHeader. - TerminatorInst *TI = OrigHeader->getTerminator(); - for (BasicBlock *SuccBB : TI->successors()) + for (BasicBlock *SuccBB : successors(OrigHeader)) for (BasicBlock::iterator BI = SuccBB->begin(); PHINode *PN = dyn_cast<PHINode>(BI); ++BI) PN->addIncoming(PN->getIncomingValueForBlock(OrigHeader), OrigPreheader); @@ -385,6 +391,12 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // remove the corresponding incoming values from the PHI nodes in OrigHeader. LoopEntryBranch->eraseFromParent(); + // Update MemorySSA before the rewrite call below changes the 1:1 + // instruction:cloned_instruction_or_value mapping in ValueMap. + if (MSSAU) { + ValueMap[OrigHeader] = OrigPreheader; + MSSAU->updateForClonedBlockIntoPred(OrigHeader, OrigPreheader, ValueMap); + } SmallVector<PHINode*, 2> InsertedPHIs; // If there were any uses of instructions in the duplicated block outside the @@ -411,6 +423,12 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { Updates.push_back({DominatorTree::Insert, OrigPreheader, NewHeader}); Updates.push_back({DominatorTree::Delete, OrigPreheader, OrigHeader}); DT->applyUpdates(Updates); + + if (MSSAU) { + MSSAU->applyUpdates(Updates, *DT); + if (VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + } } // At this point, we've finished our major CFG changes. As part of cloning @@ -433,7 +451,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // Split the edge to form a real preheader. 
BasicBlock *NewPH = SplitCriticalEdge( OrigPreheader, NewHeader, - CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + CriticalEdgeSplittingOptions(DT, LI, MSSAU).setPreserveLCSSA()); NewPH->setName(NewHeader->getName() + ".lr.ph"); // Preserve canonical loop form, which means that 'Exit' should have only @@ -452,7 +470,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { SplitLatchEdge |= L->getLoopLatch() == ExitPred; BasicBlock *ExitSplit = SplitCriticalEdge( ExitPred, Exit, - CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + CriticalEdgeSplittingOptions(DT, LI, MSSAU).setPreserveLCSSA()); ExitSplit->moveBefore(Exit); } assert(SplitLatchEdge && @@ -467,16 +485,27 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // With our CFG finalized, update DomTree if it is available. if (DT) DT->deleteEdge(OrigPreheader, Exit); + + // Update MSSA too, if available. + if (MSSAU) + MSSAU->removeEdge(OrigPreheader, Exit); } assert(L->getLoopPreheader() && "Invalid loop preheader after loop rotation"); assert(L->getLoopLatch() && "Invalid loop latch after loop rotation"); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // Now that the CFG and DomTree are in a consistent state again, try to merge // the OrigHeader block into OrigLatch. This will succeed if they are // connected by an unconditional branch. This is just a cleanup so the // emitted code isn't too gross in this common case. - MergeBlockIntoPredecessor(OrigHeader, DT, LI); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + MergeBlockIntoPredecessor(OrigHeader, &DTU, LI, MSSAU); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); LLVM_DEBUG(dbgs() << "LoopRotation: into "; L->dump()); @@ -585,9 +614,14 @@ bool LoopRotate::simplifyLoopLatch(Loop *L) { << LastExit->getName() << "\n"); // Hoist the instructions from Latch into LastExit. + Instruction *FirstLatchInst = &*(Latch->begin()); LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), Latch->begin(), Jmp->getIterator()); + // Update MemorySSA + if (MSSAU) + MSSAU->moveAllAfterMergeBlocks(Latch, LastExit, FirstLatchInst); + unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1; BasicBlock *Header = Jmp->getSuccessor(0); assert(Header == L->getHeader() && "expected a backward branch"); @@ -603,6 +637,10 @@ bool LoopRotate::simplifyLoopLatch(Loop *L) { if (DT) DT->eraseNode(Latch); Latch->eraseFromParent(); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + return true; } @@ -635,11 +673,16 @@ bool LoopRotate::processLoop(Loop *L) { /// The utility to convert a loop into a loop with bottom test. 
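// Illustrative sketch (not from the patch) of a MemorySSA-preserving caller
// of the rotation utility defined below; L, LI, TTI, AC, DT, SE and MSSA are
// assumed to come from the calling pass, and the SimplifyQuery is built from
// the module's DataLayout only.
MemorySSAUpdater MSSAU(MSSA);
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
const SimplifyQuery SQ(DL);
if (LoopRotation(L, LI, TTI, AC, DT, SE, &MSSAU, SQ,
                 /*RotationOnly=*/true, /*Threshold=*/unsigned(-1),
                 /*IsUtilMode=*/true))
  MSSAU.getMemorySSA()->verifyMemorySSA();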
bool llvm::LoopRotation(Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC, DominatorTree *DT, - ScalarEvolution *SE, const SimplifyQuery &SQ, - bool RotationOnly = true, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + const SimplifyQuery &SQ, bool RotationOnly = true, unsigned Threshold = unsigned(-1), bool IsUtilMode = true) { - LoopRotate LR(Threshold, LI, TTI, AC, DT, SE, SQ, RotationOnly, IsUtilMode); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + LoopRotate LR(Threshold, LI, TTI, AC, DT, SE, MSSAU, SQ, RotationOnly, + IsUtilMode); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); return LR.processLoop(L); } diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index 970494eb4704..380f4fca54d9 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -137,7 +137,7 @@ BasicBlock *llvm::InsertPreheaderForLoop(Loop *L, DominatorTree *DT, // Split out the loop pre-header. BasicBlock *PreheaderBB; PreheaderBB = SplitBlockPredecessors(Header, OutsideBlocks, ".preheader", DT, - LI, PreserveLCSSA); + LI, nullptr, PreserveLCSSA); if (!PreheaderBB) return nullptr; @@ -251,7 +251,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, SE->forgetLoop(L); BasicBlock *NewBB = SplitBlockPredecessors(Header, OuterLoopPreds, ".outer", - DT, LI, PreserveLCSSA); + DT, LI, nullptr, PreserveLCSSA); // Make sure that NewBB is put someplace intelligent, which doesn't mess up // code layout too horribly. @@ -435,7 +435,7 @@ static BasicBlock *insertUniqueBackedgeBlock(Loop *L, BasicBlock *Preheader, unsigned LoopMDKind = BEBlock->getContext().getMDKindID("llvm.loop"); MDNode *LoopMD = nullptr; for (unsigned i = 0, e = BackedgeBlocks.size(); i != e; ++i) { - TerminatorInst *TI = BackedgeBlocks[i]->getTerminator(); + Instruction *TI = BackedgeBlocks[i]->getTerminator(); if (!LoopMD) LoopMD = TI->getMetadata(LoopMDKind); TI->setMetadata(LoopMDKind, nullptr); @@ -488,7 +488,7 @@ ReprocessLoop: << P->getName() << "\n"); // Zap the dead pred's terminator and replace it with unreachable. - TerminatorInst *TI = P->getTerminator(); + Instruction *TI = P->getTerminator(); changeToUnreachable(TI, /*UseLLVMTrap=*/false, PreserveLCSSA); Changed = true; } diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index 04b8c1417e0a..da7ed2bd1652 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -54,10 +54,10 @@ UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden, static cl::opt<bool> UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden, cl::desc("Verify domtree after unrolling"), -#ifdef NDEBUG - cl::init(false) -#else +#ifdef EXPENSIVE_CHECKS cl::init(true) +#else + cl::init(false) #endif ); @@ -275,8 +275,7 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, // inserted code, doing constant propagation and dead code elimination as we // go. const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - const std::vector<BasicBlock *> &NewLoopBlocks = L->getBlocks(); - for (BasicBlock *BB : NewLoopBlocks) { + for (BasicBlock *BB : L->getBlocks()) { for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { Instruction *Inst = &*I++; @@ -330,12 +329,15 @@ void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, /// /// This utility preserves LoopInfo. 
It will also preserve ScalarEvolution and /// DominatorTree if they are non-null. +/// +/// If RemainderLoop is non-null, it will receive the remainder loop (if +/// required and not fully unrolled). LoopUnrollResult llvm::UnrollLoop( Loop *L, unsigned Count, unsigned TripCount, bool Force, bool AllowRuntime, bool AllowExpensiveTripCount, bool PreserveCondBr, bool PreserveOnlyFirst, unsigned TripMultiple, unsigned PeelCount, bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, - OptimizationRemarkEmitter *ORE, bool PreserveLCSSA) { + OptimizationRemarkEmitter *ORE, bool PreserveLCSSA, Loop **RemainderLoop) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { @@ -469,7 +471,7 @@ LoopUnrollResult llvm::UnrollLoop( if (RuntimeTripCount && TripMultiple % Count != 0 && !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, EpilogProfitability, UnrollRemainder, LI, SE, - DT, AC, PreserveLCSSA)) { + DT, AC, PreserveLCSSA, RemainderLoop)) { if (Force) RuntimeTripCount = false; else { @@ -596,8 +598,15 @@ LoopUnrollResult llvm::UnrollLoop( for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) if (!isa<DbgInfoIntrinsic>(&I)) - if (const DILocation *DIL = I.getDebugLoc()) - I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count)); + if (const DILocation *DIL = I.getDebugLoc()) { + auto NewDIL = DIL->cloneWithDuplicationFactor(Count); + if (NewDIL) + I.setDebugLoc(NewDIL.getValue()); + else + LLVM_DEBUG(dbgs() + << "Failed to create new discriminator: " + << DIL->getFilename() << " Line: " << DIL->getLine()); + } for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; @@ -782,7 +791,7 @@ LoopUnrollResult llvm::UnrollLoop( // there is no such latch. NewIDom = Latches.back(); for (BasicBlock *IterLatch : Latches) { - TerminatorInst *Term = IterLatch->getTerminator(); + Instruction *Term = IterLatch->getTerminator(); if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { NewIDom = IterLatch; break; diff --git a/lib/Transforms/Utils/LoopUnrollAndJam.cpp b/lib/Transforms/Utils/LoopUnrollAndJam.cpp index b919f73c3817..e26762639c13 100644 --- a/lib/Transforms/Utils/LoopUnrollAndJam.cpp +++ b/lib/Transforms/Utils/LoopUnrollAndJam.cpp @@ -72,7 +72,7 @@ static bool partitionOuterLoopBlocks(Loop *L, Loop *SubLoop, for (BasicBlock *BB : ForeBlocks) { if (BB == SubLoopPreHeader) continue; - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) if (!ForeBlocks.count(TI->getSuccessor(i))) return false; @@ -167,12 +167,14 @@ static void moveHeaderPhiOperandsToForeBlocks(BasicBlock *Header, isSafeToUnrollAndJam should be used prior to calling this to make sure the unrolling will be valid. Checking profitablility is also advisable. + + If EpilogueLoop is non-null, it receives the epilogue loop (if it was + necessary to create one and not fully unrolled). 
*/ -LoopUnrollResult -llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, - unsigned TripMultiple, bool UnrollRemainder, - LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, - AssumptionCache *AC, OptimizationRemarkEmitter *ORE) { +LoopUnrollResult llvm::UnrollAndJamLoop( + Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple, + bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, + AssumptionCache *AC, OptimizationRemarkEmitter *ORE, Loop **EpilogueLoop) { // When we enter here we should have already checked that it is safe BasicBlock *Header = L->getHeader(); @@ -181,7 +183,7 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, // Don't enter the unroll code if there is nothing to do. if (TripCount == 0 && Count < 2) { - LLVM_DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; almost nothing to do\n"); return LoopUnrollResult::Unmodified; } @@ -196,7 +198,8 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, if (TripMultiple == 1 || TripMultiple % Count != 0) { if (!UnrollRuntimeLoopRemainder(L, Count, /*AllowExpensiveTripCount*/ false, /*UseEpilogRemainder*/ true, - UnrollRemainder, LI, SE, DT, AC, true)) { + UnrollRemainder, LI, SE, DT, AC, true, + EpilogueLoop)) { LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; remainder loop could not be " "generated when assuming runtime trip count\n"); return LoopUnrollResult::Unmodified; @@ -297,8 +300,15 @@ llvm::UnrollAndJamLoop(Loop *L, unsigned Count, unsigned TripCount, for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) if (!isa<DbgInfoIntrinsic>(&I)) - if (const DILocation *DIL = I.getDebugLoc()) - I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count)); + if (const DILocation *DIL = I.getDebugLoc()) { + auto NewDIL = DIL->cloneWithDuplicationFactor(Count); + if (NewDIL) + I.setDebugLoc(NewDIL.getValue()); + else + LLVM_DEBUG(dbgs() + << "Failed to create new discriminator: " + << DIL->getFilename() << " Line: " << DIL->getLine()); + } // Copy all blocks for (unsigned It = 1; It != Count; ++It) { @@ -619,16 +629,28 @@ static bool checkDependencies(SmallVector<Value *, 4> &Earlier, if (auto D = DI.depends(Src, Dst, true)) { assert(D->isOrdered() && "Expected an output, flow or anti dep."); - if (D->isConfused()) + if (D->isConfused()) { + LLVM_DEBUG(dbgs() << " Confused dependency between:\n" + << " " << *Src << "\n" + << " " << *Dst << "\n"); return false; + } if (!InnerLoop) { - if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) + if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT) { + LLVM_DEBUG(dbgs() << " > dependency between:\n" + << " " << *Src << "\n" + << " " << *Dst << "\n"); return false; + } } else { assert(LoopDepth + 1 <= D->getLevels()); if (D->getDirection(LoopDepth) & Dependence::DVEntry::GT && - D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) + D->getDirection(LoopDepth + 1) & Dependence::DVEntry::LT) { + LLVM_DEBUG(dbgs() << " < > dependency between:\n" + << " " << *Src << "\n" + << " " << *Dst << "\n"); return false; + } } } } @@ -716,38 +738,45 @@ bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT, if (SubLoopLatch != SubLoopExit) return false; - if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) + if (Header->hasAddressTaken() || SubLoopHeader->hasAddressTaken()) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Address taken\n"); return false; + } // Split blocks into Fore/SubLoop/Aft based on dominators 
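// For illustration only: a hypothetical loop nest split into the Fore /
// SubLoop / Aft parts that the partitioning below reasons about, and its shape
// after unrolling the outer loop by 2 and fusing ("jamming") the inner copies.
static void rowDots(float **A, const float *x, float *rowSum, int n, int m) {
  for (int i = 0; i < n; ++i) {
    float s = 0.0f;                 // Fore block: runs before the inner loop
    for (int j = 0; j < m; ++j)     // SubLoop
      s += A[i][j] * x[j];
    rowSum[i] = s;                  // Aft block: runs after the inner loop
  }
}
// Assuming n is even and the dependency checks in this file proved the
// reordering legal, unroll-and-jam by 2 runs both Fore copies, then one fused
// inner loop, then both Aft copies:
static void rowDotsUnrollJam2(float **A, const float *x, float *rowSum,
                              int n, int m) {
  for (int i = 0; i < n; i += 2) {
    float s0 = 0.0f, s1 = 0.0f;     // both Fore copies first
    for (int j = 0; j < m; ++j) {   // jammed SubLoop
      s0 += A[i][j] * x[j];
      s1 += A[i + 1][j] * x[j];
    }
    rowSum[i] = s0;                 // both Aft copies last
    rowSum[i + 1] = s1;
  }
}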
BasicBlockSet SubLoopBlocks; BasicBlockSet ForeBlocks; BasicBlockSet AftBlocks; if (!partitionOuterLoopBlocks(L, SubLoop, ForeBlocks, SubLoopBlocks, - AftBlocks, &DT)) + AftBlocks, &DT)) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Incompatible loop layout\n"); return false; + } // Aft blocks may need to move instructions to fore blocks, which becomes more // difficult if there are multiple (potentially conditionally executed) // blocks. For now we just exclude loops with multiple aft blocks. - if (AftBlocks.size() != 1) + if (AftBlocks.size() != 1) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Can't currently handle " + "multiple blocks after the loop\n"); return false; + } - // Check inner loop IV is consistent between all iterations - const SCEV *SubLoopBECountSC = SE.getExitCount(SubLoop, SubLoopLatch); - if (isa<SCEVCouldNotCompute>(SubLoopBECountSC) || - !SubLoopBECountSC->getType()->isIntegerTy()) - return false; - ScalarEvolution::LoopDisposition LD = - SE.getLoopDisposition(SubLoopBECountSC, L); - if (LD != ScalarEvolution::LoopInvariant) + // Check inner loop backedge count is consistent on all iterations of the + // outer loop + if (!hasIterationCountInvariantInParent(SubLoop, SE)) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Inner loop iteration count is " + "not consistent on each iteration\n"); return false; + } // Check the loop safety info for exceptions. - LoopSafetyInfo LSI; - computeLoopSafetyInfo(&LSI, L); - if (LSI.MayThrow) + SimpleLoopSafetyInfo LSI; + LSI.computeLoopSafetyInfo(L); + if (LSI.anyBlockMayThrow()) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; Something may throw\n"); return false; + } // We've ruled out the easy stuff and now need to check that there are no // interdependencies which may prevent us from moving the: @@ -772,14 +801,19 @@ bool llvm::isSafeToUnrollAndJam(Loop *L, ScalarEvolution &SE, DominatorTree &DT, } // Keep going return true; - })) + })) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't move required " + "instructions after subloop to before it\n"); return false; + } // Check for memory dependencies which prohibit the unrolling we are doing. // Because of the way we are unrolling Fore/Sub/Aft blocks, we need to check // there are no dependencies between Fore-Sub, Fore-Aft, Sub-Aft and Sub-Sub. - if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) + if (!checkDependencies(L, ForeBlocks, SubLoopBlocks, AftBlocks, DI)) { + LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; failed dependency check\n"); return false; + } return true; } diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp index 78afe748e596..151a285af4e9 100644 --- a/lib/Transforms/Utils/LoopUnrollPeel.cpp +++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -615,11 +615,17 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI, // the original loop body. if (Iter == 0) DT->changeImmediateDominator(Exit, cast<BasicBlock>(LVMap[Latch])); +#ifdef EXPENSIVE_CHECKS assert(DT->verify(DominatorTree::VerificationLevel::Fast)); +#endif } - updateBranchWeights(InsertBot, cast<BranchInst>(VMap[LatchBR]), Iter, + auto *LatchBRCopy = cast<BranchInst>(VMap[LatchBR]); + updateBranchWeights(InsertBot, LatchBRCopy, Iter, PeelCount, ExitWeight); + // Remove Loop metadata from the latch branch instruction + // because it is not the Loop's latch branch anymore. 
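// For illustration only: a source-level sketch of peeling one iteration
// (hypothetical example). The peeled copy becomes straight-line code in front
// of the loop, which is why the branch weights are rescaled above and the
// copied latch branch must drop the loop metadata (it no longer terminates a
// loop latch).
static void blur(float *dst, const float *src, int n) {
  for (int i = 0; i < n; ++i) {
    float left = (i == 0) ? src[0] : src[i - 1];  // only differs on iteration 0
    dst[i] = 0.5f * (left + src[i]);
  }
}
// After peeling iteration 0, the (i == 0) test folds away in both the peeled
// copy and the remaining loop:
static void blurPeeled(float *dst, const float *src, int n) {
  if (n > 0) {
    dst[0] = 0.5f * (src[0] + src[0]);            // peeled iteration 0
    for (int i = 1; i < n; ++i)
      dst[i] = 0.5f * (src[i - 1] + src[i]);      // loop body without the test
  }
}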
+ LatchBRCopy->setMetadata(LLVMContext::MD_loop, nullptr); InsertTop = InsertBot; InsertBot = SplitBlock(InsertBot, InsertBot->getTerminator(), DT, LI); diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 0057b4ba7ce1..00d2fd2fdbac 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -70,6 +70,17 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, BasicBlock *PreHeader, BasicBlock *NewPreHeader, ValueToValueMapTy &VMap, DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA) { + // Loop structure should be the following: + // Preheader + // PrologHeader + // ... + // PrologLatch + // PrologExit + // NewPreheader + // Header + // ... + // Latch + // LatchExit BasicBlock *Latch = L->getLoopLatch(); assert(Latch && "Loop must have a latch"); BasicBlock *PrologLatch = cast<BasicBlock>(VMap[Latch]); @@ -83,14 +94,21 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, for (PHINode &PN : Succ->phis()) { // Add a new PHI node to the prolog end block and add the // appropriate incoming values. + // TODO: This code assumes that the PrologExit (or the LatchExit block for + // prolog loop) contains only one predecessor from the loop, i.e. the + // PrologLatch. When supporting multiple-exiting block loops, we can have + // two or more blocks that have the LatchExit as the target in the + // original loop. PHINode *NewPN = PHINode::Create(PN.getType(), 2, PN.getName() + ".unr", PrologExit->getFirstNonPHI()); // Adding a value to the new PHI node from the original loop preheader. // This is the value that skips all the prolog code. if (L->contains(&PN)) { + // Succ is loop header. NewPN->addIncoming(PN.getIncomingValueForBlock(NewPreHeader), PreHeader); } else { + // Succ is LatchExit. NewPN->addIncoming(UndefValue::get(PN.getType()), PreHeader); } @@ -124,7 +142,7 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, PrologExitPreds.push_back(PredBB); SplitBlockPredecessors(PrologExit, PrologExitPreds, ".unr-lcssa", DT, LI, - PreserveLCSSA); + nullptr, PreserveLCSSA); } // Create a branch around the original loop, which is taken if there are no @@ -143,7 +161,7 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, // Split the exit to maintain loop canonicalization guarantees SmallVector<BasicBlock *, 4> Preds(predecessors(OriginalLoopLatchExit)); SplitBlockPredecessors(OriginalLoopLatchExit, Preds, ".unr-lcssa", DT, LI, - PreserveLCSSA); + nullptr, PreserveLCSSA); // Add the branch to the exit block (around the unrolled loop) B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader); InsertPt->eraseFromParent(); @@ -257,7 +275,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, assert(Exit && "Loop must have a single exit block only"); // Split the epilogue exit to maintain loop canonicalization guarantees SmallVector<BasicBlock*, 4> Preds(predecessors(Exit)); - SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, + SplitBlockPredecessors(Exit, Preds, ".epilog-lcssa", DT, LI, nullptr, PreserveLCSSA); // Add the branch to the exit block (around the unrolling loop) B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit); @@ -267,7 +285,7 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // Split the main loop exit to maintain canonicalization guarantees. 
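// For illustration only: the source-level shape of runtime unrolling by 4 when
// the trip count n is unknown (hypothetical example). The n % 4 remainder
// iterations run either after the unrolled loop (the ConnectEpilog case) or
// before it (the ConnectProlog case above).
static long sumEpilogUnrolled(const int *a, long n) {
  long s = 0, i = 0;
  for (; i + 3 < n; i += 4) {       // unrolled main loop, multiple-of-4 trips
    s += a[i];
    s += a[i + 1];
    s += a[i + 2];
    s += a[i + 3];
  }
  for (; i < n; ++i)                // epilog remainder: at most 3 iterations
    s += a[i];
  return s;
}
static long sumPrologUnrolled(const int *a, long n) {
  long s = 0, i = 0;
  for (; i < (n % 4); ++i)          // prolog remainder runs first
    s += a[i];
  for (; i < n; i += 4) {           // main loop now has a multiple-of-4 count
    s += a[i];
    s += a[i + 1];
    s += a[i + 2];
    s += a[i + 3];
  }
  return s;
}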
SmallVector<BasicBlock*, 4> NewExitPreds{Latch}; - SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI, + SplitBlockPredecessors(NewExit, NewExitPreds, ".loopexit", DT, LI, nullptr, PreserveLCSSA); } @@ -380,6 +398,7 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, } if (CreateRemainderLoop) { Loop *NewLoop = NewLoops[L]; + MDNode *LoopID = NewLoop->getLoopID(); assert(NewLoop && "L should have been cloned"); // Only add loop metadata if the loop is not going to be completely @@ -387,6 +406,16 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, if (UnrollRemainder) return NewLoop; + Optional<MDNode *> NewLoopID = makeFollowupLoopID( + LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); + if (NewLoopID.hasValue()) { + NewLoop->setLoopID(NewLoopID.getValue()); + + // Do not setLoopAlreadyUnrolled if loop attributes have been defined + // explicitly. + return NewLoop; + } + // Add unroll disable metadata to disable future unrolling for this loop. NewLoop->setLoopAlreadyUnrolled(); return NewLoop; @@ -525,10 +554,10 @@ static bool canProfitablyUnrollMultiExitLoop( bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, bool AllowExpensiveTripCount, bool UseEpilogRemainder, - bool UnrollRemainder, - LoopInfo *LI, ScalarEvolution *SE, - DominatorTree *DT, AssumptionCache *AC, - bool PreserveLCSSA) { + bool UnrollRemainder, LoopInfo *LI, + ScalarEvolution *SE, DominatorTree *DT, + AssumptionCache *AC, bool PreserveLCSSA, + Loop **ResultLoop) { LLVM_DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n"); LLVM_DEBUG(L->dump()); LLVM_DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n" @@ -545,13 +574,27 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, BasicBlock *Header = L->getHeader(); BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); + + if (!LatchBR || LatchBR->isUnconditional()) { + // The loop-rotate pass can be helpful to avoid this in many cases. + LLVM_DEBUG( + dbgs() + << "Loop latch not terminated by a conditional branch.\n"); + return false; + } + unsigned ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0; BasicBlock *LatchExit = LatchBR->getSuccessor(ExitIndex); - // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the - // targets of the Latch be an exit block out of the loop. This needs - // to be guaranteed by the callers of UnrollRuntimeLoopRemainder. - assert(!L->contains(LatchExit) && - "one of the loop latch successors should be the exit block!"); + + if (L->contains(LatchExit)) { + // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the + // targets of the Latch be an exit block out of the loop. + LLVM_DEBUG( + dbgs() + << "One of the loop latch successors must be the exit block.\n"); + return false; + } + // These are exit blocks other than the target of the latch exiting block. SmallVector<BasicBlock *, 4> OtherExits; bool isMultiExitUnrollingEnabled = @@ -636,8 +679,8 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, NewPreHeader->setName(PreHeader->getName() + ".new"); // Split LatchExit to create phi nodes from branch above. SmallVector<BasicBlock*, 4> Preds(predecessors(LatchExit)); - NewExit = SplitBlockPredecessors(LatchExit, Preds, ".unr-lcssa", - DT, LI, PreserveLCSSA); + NewExit = SplitBlockPredecessors(LatchExit, Preds, ".unr-lcssa", DT, LI, + nullptr, PreserveLCSSA); // NewExit gets its DebugLoc from LatchExit, which is not part of the // original Loop. 
// Fix this by setting Loop's DebugLoc to NewExit. @@ -762,10 +805,7 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Now the loop blocks are cloned and the other exiting blocks from the // remainder are connected to the original Loop's exit blocks. The remaining // work is to update the phi nodes in the original loop, and take in the - // values from the cloned region. Also update the dominator info for - // OtherExits and their immediate successors, since we have new edges into - // OtherExits. - SmallPtrSet<BasicBlock*, 8> ImmediateSuccessorsOfExitBlocks; + // values from the cloned region. for (auto *BB : OtherExits) { for (auto &II : *BB) { @@ -800,27 +840,30 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, "Breaks the definition of dedicated exits!"); } #endif - // Update the dominator info because the immediate dominator is no longer the - // header of the original Loop. BB has edges both from L and remainder code. - // Since the preheader determines which loop is run (L or directly jump to - // the remainder code), we set the immediate dominator as the preheader. - if (DT) { - DT->changeImmediateDominator(BB, PreHeader); - // Also update the IDom for immediate successors of BB. If the current - // IDom is the header, update the IDom to be the preheader because that is - // the nearest common dominator of all predecessors of SuccBB. We need to - // check for IDom being the header because successors of exit blocks can - // have edges from outside the loop, and we should not incorrectly update - // the IDom in that case. - for (BasicBlock *SuccBB: successors(BB)) - if (ImmediateSuccessorsOfExitBlocks.insert(SuccBB).second) { - if (DT->getNode(SuccBB)->getIDom()->getBlock() == Header) { - assert(!SuccBB->getSinglePredecessor() && - "BB should be the IDom then!"); - DT->changeImmediateDominator(SuccBB, PreHeader); - } - } + } + + // Update the immediate dominator of the exit blocks and blocks that are + // reachable from the exit blocks. This is needed because we now have paths + // from both the original loop and the remainder code reaching the exit + // blocks. While the IDom of these exit blocks were from the original loop, + // now the IDom is the preheader (which decides whether the original loop or + // remainder code should run). + if (DT && !L->getExitingBlock()) { + SmallVector<BasicBlock *, 16> ChildrenToUpdate; + // NB! We have to examine the dom children of all loop blocks, not just + // those which are the IDom of the exit blocks. This is because blocks + // reachable from the exit blocks can have their IDom as the nearest common + // dominator of the exit blocks. + for (auto *BB : L->blocks()) { + auto *DomNodeBB = DT->getNode(BB); + for (auto *DomChild : DomNodeBB->getChildren()) { + auto *DomChildBB = DomChild->getBlock(); + if (!L->contains(LI->getLoopFor(DomChildBB))) + ChildrenToUpdate.push_back(DomChildBB); + } } + for (auto *BB : ChildrenToUpdate) + DT->changeImmediateDominator(BB, PreHeader); } // Loop structure should be the following: @@ -884,6 +927,12 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // of its parent loops, so the Scalar Evolution pass needs to be run again. SE->forgetTopmostLoop(L); + // Verify that the Dom Tree is correct. +#if defined(EXPENSIVE_CHECKS) && !defined(NDEBUG) + if (DT) + assert(DT->verify(DominatorTree::VerificationLevel::Full)); +#endif + // Canonicalize to LoopSimplifyForm both original and remainder loops. 
We // cannot rely on the LoopUnrollPass to do this because it only does // canonicalization for parent/subloops and not the sibling loops. @@ -897,16 +946,20 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, formDedicatedExitBlocks(remainderLoop, DT, LI, PreserveLCSSA); } + auto UnrollResult = LoopUnrollResult::Unmodified; if (remainderLoop && UnrollRemainder) { LLVM_DEBUG(dbgs() << "Unrolling remainder loop\n"); - UnrollLoop(remainderLoop, /*Count*/ Count - 1, /*TripCount*/ Count - 1, - /*Force*/ false, /*AllowRuntime*/ false, - /*AllowExpensiveTripCount*/ false, /*PreserveCondBr*/ true, - /*PreserveOnlyFirst*/ false, /*TripMultiple*/ 1, - /*PeelCount*/ 0, /*UnrollRemainder*/ false, LI, SE, DT, AC, - /*ORE*/ nullptr, PreserveLCSSA); + UnrollResult = + UnrollLoop(remainderLoop, /*Count*/ Count - 1, /*TripCount*/ Count - 1, + /*Force*/ false, /*AllowRuntime*/ false, + /*AllowExpensiveTripCount*/ false, /*PreserveCondBr*/ true, + /*PreserveOnlyFirst*/ false, /*TripMultiple*/ 1, + /*PeelCount*/ 0, /*UnrollRemainder*/ false, LI, SE, DT, AC, + /*ORE*/ nullptr, PreserveLCSSA); } + if (ResultLoop && UnrollResult != LoopUnrollResult::FullyUnrolled) + *ResultLoop = remainderLoop; NumRuntimeUnrolled++; return true; } diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp index 46af120a428b..a93d1aeb62ef 100644 --- a/lib/Transforms/Utils/LoopUtils.cpp +++ b/lib/Transforms/Utils/LoopUtils.cpp @@ -26,8 +26,11 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/DomTreeUpdater.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" @@ -41,1104 +44,7 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "loop-utils" -bool RecurrenceDescriptor::areAllUsesIn(Instruction *I, - SmallPtrSetImpl<Instruction *> &Set) { - for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use) - if (!Set.count(dyn_cast<Instruction>(*Use))) - return false; - return true; -} - -bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurrenceKind Kind) { - switch (Kind) { - default: - break; - case RK_IntegerAdd: - case RK_IntegerMult: - case RK_IntegerOr: - case RK_IntegerAnd: - case RK_IntegerXor: - case RK_IntegerMinMax: - return true; - } - return false; -} - -bool RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind Kind) { - return (Kind != RK_NoRecurrence) && !isIntegerRecurrenceKind(Kind); -} - -bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) { - switch (Kind) { - default: - break; - case RK_IntegerAdd: - case RK_IntegerMult: - case RK_FloatAdd: - case RK_FloatMult: - return true; - } - return false; -} - -/// Determines if Phi may have been type-promoted. If Phi has a single user -/// that ANDs the Phi with a type mask, return the user. RT is updated to -/// account for the narrower bit width represented by the mask, and the AND -/// instruction is added to CI. -static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT, - SmallPtrSetImpl<Instruction *> &Visited, - SmallPtrSetImpl<Instruction *> &CI) { - if (!Phi->hasOneUse()) - return Phi; - - const APInt *M = nullptr; - Instruction *I, *J = cast<Instruction>(Phi->use_begin()->getUser()); - - // Matches either I & 2^x-1 or 2^x-1 & I. 
If we find a match, we update RT - // with a new integer type of the corresponding bit width. - if (match(J, m_c_And(m_Instruction(I), m_APInt(M)))) { - int32_t Bits = (*M + 1).exactLogBase2(); - if (Bits > 0) { - RT = IntegerType::get(Phi->getContext(), Bits); - Visited.insert(Phi); - CI.insert(J); - return J; - } - } - return Phi; -} - -/// Compute the minimal bit width needed to represent a reduction whose exit -/// instruction is given by Exit. -static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit, - DemandedBits *DB, - AssumptionCache *AC, - DominatorTree *DT) { - bool IsSigned = false; - const DataLayout &DL = Exit->getModule()->getDataLayout(); - uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType()); - - if (DB) { - // Use the demanded bits analysis to determine the bits that are live out - // of the exit instruction, rounding up to the nearest power of two. If the - // use of demanded bits results in a smaller bit width, we know the value - // must be positive (i.e., IsSigned = false), because if this were not the - // case, the sign bit would have been demanded. - auto Mask = DB->getDemandedBits(Exit); - MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); - } - - if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) { - // If demanded bits wasn't able to limit the bit width, we can try to use - // value tracking instead. This can be the case, for example, if the value - // may be negative. - auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT); - auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType()); - MaxBitWidth = NumTypeBits - NumSignBits; - KnownBits Bits = computeKnownBits(Exit, DL); - if (!Bits.isNonNegative()) { - // If the value is not known to be non-negative, we set IsSigned to true, - // meaning that we will use sext instructions instead of zext - // instructions to restore the original type. - IsSigned = true; - if (!Bits.isNegative()) - // If the value is not known to be negative, we don't known what the - // upper bit is, and therefore, we don't know what kind of extend we - // will need. In this case, just increase the bit width by one bit and - // use sext. - ++MaxBitWidth; - } - } - if (!isPowerOf2_64(MaxBitWidth)) - MaxBitWidth = NextPowerOf2(MaxBitWidth); - - return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth), - IsSigned); -} - -/// Collect cast instructions that can be ignored in the vectorizer's cost -/// model, given a reduction exit value and the minimal type in which the -/// reduction can be represented. -static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit, - Type *RecurrenceType, - SmallPtrSetImpl<Instruction *> &Casts) { - - SmallVector<Instruction *, 8> Worklist; - SmallPtrSet<Instruction *, 8> Visited; - Worklist.push_back(Exit); - - while (!Worklist.empty()) { - Instruction *Val = Worklist.pop_back_val(); - Visited.insert(Val); - if (auto *Cast = dyn_cast<CastInst>(Val)) - if (Cast->getSrcTy() == RecurrenceType) { - // If the source type of a cast instruction is equal to the recurrence - // type, it will be eliminated, and should be ignored in the vectorizer - // cost model. - Casts.insert(Cast); - continue; - } - - // Add all operands to the work list if they are loop-varying values that - // we haven't yet visited. 
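// For illustration only: the kind of type-promoted reduction the helpers above
// look through (hypothetical example). The source accumulates in 8 bits; after
// the integer promotion performed by InstCombine, such a loop can end up in IR
// as a wider phi that is masked with 'and %acc, 255' each iteration -- the
// "I & 2^x-1" shape matched above -- and the demanded-bits analysis then
// proves the reduction only needs 8 bits.
static unsigned char checksum8(const unsigned char *p, int n) {
  unsigned char sum = 0;
  for (int i = 0; i < n; ++i)
    sum = (unsigned char)(sum + p[i]);   // narrowing handled by the mask in IR
  return sum;
}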
- for (Value *O : cast<User>(Val)->operands()) - if (auto *I = dyn_cast<Instruction>(O)) - if (TheLoop->contains(I) && !Visited.count(I)) - Worklist.push_back(I); - } -} - -bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, - Loop *TheLoop, bool HasFunNoNaNAttr, - RecurrenceDescriptor &RedDes, - DemandedBits *DB, - AssumptionCache *AC, - DominatorTree *DT) { - if (Phi->getNumIncomingValues() != 2) - return false; - - // Reduction variables are only found in the loop header block. - if (Phi->getParent() != TheLoop->getHeader()) - return false; - - // Obtain the reduction start value from the value that comes from the loop - // preheader. - Value *RdxStart = Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader()); - - // ExitInstruction is the single value which is used outside the loop. - // We only allow for a single reduction value to be used outside the loop. - // This includes users of the reduction, variables (which form a cycle - // which ends in the phi node). - Instruction *ExitInstruction = nullptr; - // Indicates that we found a reduction operation in our scan. - bool FoundReduxOp = false; - - // We start with the PHI node and scan for all of the users of this - // instruction. All users must be instructions that can be used as reduction - // variables (such as ADD). We must have a single out-of-block user. The cycle - // must include the original PHI. - bool FoundStartPHI = false; - - // To recognize min/max patterns formed by a icmp select sequence, we store - // the number of instruction we saw from the recognized min/max pattern, - // to make sure we only see exactly the two instructions. - unsigned NumCmpSelectPatternInst = 0; - InstDesc ReduxDesc(false, nullptr); - - // Data used for determining if the recurrence has been type-promoted. - Type *RecurrenceType = Phi->getType(); - SmallPtrSet<Instruction *, 4> CastInsts; - Instruction *Start = Phi; - bool IsSigned = false; - - SmallPtrSet<Instruction *, 8> VisitedInsts; - SmallVector<Instruction *, 8> Worklist; - - // Return early if the recurrence kind does not match the type of Phi. If the - // recurrence kind is arithmetic, we attempt to look through AND operations - // resulting from the type promotion performed by InstCombine. Vector - // operations are not limited to the legal integer widths, so we may be able - // to evaluate the reduction in the narrower width. - if (RecurrenceType->isFloatingPointTy()) { - if (!isFloatingPointRecurrenceKind(Kind)) - return false; - } else { - if (!isIntegerRecurrenceKind(Kind)) - return false; - if (isArithmeticRecurrenceKind(Kind)) - Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts); - } - - Worklist.push_back(Start); - VisitedInsts.insert(Start); - - // A value in the reduction can be used: - // - By the reduction: - // - Reduction operation: - // - One use of reduction value (safe). - // - Multiple use of reduction value (not safe). - // - PHI: - // - All uses of the PHI must be the reduction (safe). - // - Otherwise, not safe. - // - By instructions outside of the loop (safe). - // * One value may have several outside users, but all outside - // uses must be of the same value. - // - By an instruction that is not part of the reduction (not safe). - // This is either: - // * An instruction type other than PHI or the reduction operation. - // * A PHI in the header other than the initial PHI. - while (!Worklist.empty()) { - Instruction *Cur = Worklist.back(); - Worklist.pop_back(); - - // No Users. 
- // If the instruction has no users then this is a broken chain and can't be - // a reduction variable. - if (Cur->use_empty()) - return false; - - bool IsAPhi = isa<PHINode>(Cur); - - // A header PHI use other than the original PHI. - if (Cur != Phi && IsAPhi && Cur->getParent() == Phi->getParent()) - return false; - - // Reductions of instructions such as Div, and Sub is only possible if the - // LHS is the reduction variable. - if (!Cur->isCommutative() && !IsAPhi && !isa<SelectInst>(Cur) && - !isa<ICmpInst>(Cur) && !isa<FCmpInst>(Cur) && - !VisitedInsts.count(dyn_cast<Instruction>(Cur->getOperand(0)))) - return false; - - // Any reduction instruction must be of one of the allowed kinds. We ignore - // the starting value (the Phi or an AND instruction if the Phi has been - // type-promoted). - if (Cur != Start) { - ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr); - if (!ReduxDesc.isRecurrence()) - return false; - } - - // A reduction operation must only have one use of the reduction value. - if (!IsAPhi && Kind != RK_IntegerMinMax && Kind != RK_FloatMinMax && - hasMultipleUsesOf(Cur, VisitedInsts)) - return false; - - // All inputs to a PHI node must be a reduction value. - if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts)) - return false; - - if (Kind == RK_IntegerMinMax && - (isa<ICmpInst>(Cur) || isa<SelectInst>(Cur))) - ++NumCmpSelectPatternInst; - if (Kind == RK_FloatMinMax && (isa<FCmpInst>(Cur) || isa<SelectInst>(Cur))) - ++NumCmpSelectPatternInst; - - // Check whether we found a reduction operator. - FoundReduxOp |= !IsAPhi && Cur != Start; - - // Process users of current instruction. Push non-PHI nodes after PHI nodes - // onto the stack. This way we are going to have seen all inputs to PHI - // nodes once we get to them. - SmallVector<Instruction *, 8> NonPHIs; - SmallVector<Instruction *, 8> PHIs; - for (User *U : Cur->users()) { - Instruction *UI = cast<Instruction>(U); - - // Check if we found the exit user. - BasicBlock *Parent = UI->getParent(); - if (!TheLoop->contains(Parent)) { - // If we already know this instruction is used externally, move on to - // the next user. - if (ExitInstruction == Cur) - continue; - - // Exit if you find multiple values used outside or if the header phi - // node is being used. In this case the user uses the value of the - // previous iteration, in which case we would loose "VF-1" iterations of - // the reduction operation if we vectorize. - if (ExitInstruction != nullptr || Cur == Phi) - return false; - - // The instruction used by an outside user must be the last instruction - // before we feed back to the reduction phi. Otherwise, we loose VF-1 - // operations on the value. - if (!is_contained(Phi->operands(), Cur)) - return false; - - ExitInstruction = Cur; - continue; - } - - // Process instructions only once (termination). Each reduction cycle - // value must only be used once, except by phi nodes and min/max - // reductions which are represented as a cmp followed by a select. - InstDesc IgnoredVal(false, nullptr); - if (VisitedInsts.insert(UI).second) { - if (isa<PHINode>(UI)) - PHIs.push_back(UI); - else - NonPHIs.push_back(UI); - } else if (!isa<PHINode>(UI) && - ((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) && - !isa<SelectInst>(UI)) || - !isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence())) - return false; - - // Remember that we completed the cycle. 
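// For illustration only (hypothetical examples): the use scan above accepts a
// reduction only when the running value feeds a single chain that cycles back
// to the phi, with at most one value escaping the loop.
static int sumOk(const int *a, int n) {
  int s = 0;
  for (int i = 0; i < n; ++i)
    s += a[i];              // the only out-of-loop user is the return below
  return s;
}
static int sumRejected(const int *a, int *prefix, int n) {
  int s = 0;
  for (int i = 0; i < n; ++i) {
    s += a[i];
    prefix[i] = s;          // the store is a user of the running sum that is
  }                         // not part of the reduction cycle, so the scan
  return s;                 // above gives up on this phi
}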
- if (UI == Phi) - FoundStartPHI = true; - } - Worklist.append(PHIs.begin(), PHIs.end()); - Worklist.append(NonPHIs.begin(), NonPHIs.end()); - } - - // This means we have seen one but not the other instruction of the - // pattern or more than just a select and cmp. - if ((Kind == RK_IntegerMinMax || Kind == RK_FloatMinMax) && - NumCmpSelectPatternInst != 2) - return false; - - if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction) - return false; - - if (Start != Phi) { - // If the starting value is not the same as the phi node, we speculatively - // looked through an 'and' instruction when evaluating a potential - // arithmetic reduction to determine if it may have been type-promoted. - // - // We now compute the minimal bit width that is required to represent the - // reduction. If this is the same width that was indicated by the 'and', we - // can represent the reduction in the smaller type. The 'and' instruction - // will be eliminated since it will essentially be a cast instruction that - // can be ignore in the cost model. If we compute a different type than we - // did when evaluating the 'and', the 'and' will not be eliminated, and we - // will end up with different kinds of operations in the recurrence - // expression (e.g., RK_IntegerAND, RK_IntegerADD). We give up if this is - // the case. - // - // The vectorizer relies on InstCombine to perform the actual - // type-shrinking. It does this by inserting instructions to truncate the - // exit value of the reduction to the width indicated by RecurrenceType and - // then extend this value back to the original width. If IsSigned is false, - // a 'zext' instruction will be generated; otherwise, a 'sext' will be - // used. - // - // TODO: We should not rely on InstCombine to rewrite the reduction in the - // smaller type. We should just generate a correctly typed expression - // to begin with. - Type *ComputedType; - std::tie(ComputedType, IsSigned) = - computeRecurrenceType(ExitInstruction, DB, AC, DT); - if (ComputedType != RecurrenceType) - return false; - - // The recurrence expression will be represented in a narrower type. If - // there are any cast instructions that will be unnecessary, collect them - // in CastInsts. Note that the 'and' instruction was already included in - // this list. - // - // TODO: A better way to represent this may be to tag in some way all the - // instructions that are a part of the reduction. The vectorizer cost - // model could then apply the recurrence type to these instructions, - // without needing a white list of instructions to ignore. - collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts); - } - - // We found a reduction var if we have reached the original phi node and we - // only have a single instruction with out-of-loop users. - - // The ExitInstruction(Instruction which is allowed to have out-of-loop users) - // is saved as part of the RecurrenceDescriptor. - - // Save the description of this reduction variable. - RecurrenceDescriptor RD( - RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(), - ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts); - RedDes = RD; - - return true; -} - -/// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction -/// pattern corresponding to a min(X, Y) or max(X, Y). 
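// For illustration only: the source-level idiom that produces the
// select(cmp(X, Y), X, Y) shape matched below (hypothetical example).
static int minElement(const int *a, int n, int init) {
  int m = init;
  for (int i = 0; i < n; ++i)
    m = (a[i] < m) ? a[i] : m;   // icmp slt + select -> signed-min recurrence
  return m;
}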
-RecurrenceDescriptor::InstDesc -RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev) { - - assert((isa<ICmpInst>(I) || isa<FCmpInst>(I) || isa<SelectInst>(I)) && - "Expect a select instruction"); - Instruction *Cmp = nullptr; - SelectInst *Select = nullptr; - - // We must handle the select(cmp()) as a single instruction. Advance to the - // select. - if ((Cmp = dyn_cast<ICmpInst>(I)) || (Cmp = dyn_cast<FCmpInst>(I))) { - if (!Cmp->hasOneUse() || !(Select = dyn_cast<SelectInst>(*I->user_begin()))) - return InstDesc(false, I); - return InstDesc(Select, Prev.getMinMaxKind()); - } - - // Only handle single use cases for now. - if (!(Select = dyn_cast<SelectInst>(I))) - return InstDesc(false, I); - if (!(Cmp = dyn_cast<ICmpInst>(I->getOperand(0))) && - !(Cmp = dyn_cast<FCmpInst>(I->getOperand(0)))) - return InstDesc(false, I); - if (!Cmp->hasOneUse()) - return InstDesc(false, I); - - Value *CmpLeft; - Value *CmpRight; - - // Look for a min/max pattern. - if (m_UMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_UIntMin); - else if (m_UMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_UIntMax); - else if (m_SMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_SIntMax); - else if (m_SMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_SIntMin); - else if (m_OrdFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_FloatMin); - else if (m_OrdFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_FloatMax); - else if (m_UnordFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_FloatMin); - else if (m_UnordFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) - return InstDesc(Select, MRK_FloatMax); - - return InstDesc(false, I); -} - -RecurrenceDescriptor::InstDesc -RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind, - InstDesc &Prev, bool HasFunNoNaNAttr) { - bool FP = I->getType()->isFloatingPointTy(); - Instruction *UAI = Prev.getUnsafeAlgebraInst(); - if (!UAI && FP && !I->isFast()) - UAI = I; // Found an unsafe (unvectorizable) algebra instruction. 
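// For illustration only: why non-'fast' FP operations are recorded as unsafe
// algebra above (hypothetical example). Vectorizing this sum reassociates the
// additions -- several partial sums are kept and combined at the end -- which
// is not value-preserving for IEEE floats, so it is only done when the
// operations carry fast-math flags (e.g. built with -ffast-math).
static float fsum(const float *a, int n) {
  float s = 0.0f;
  for (int i = 0; i < n; ++i)
    s += a[i];        // strict IEEE order: (((0 + a[0]) + a[1]) + a[2]) ...
  return s;
}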
- - switch (I->getOpcode()) { - default: - return InstDesc(false, I); - case Instruction::PHI: - return InstDesc(I, Prev.getMinMaxKind(), Prev.getUnsafeAlgebraInst()); - case Instruction::Sub: - case Instruction::Add: - return InstDesc(Kind == RK_IntegerAdd, I); - case Instruction::Mul: - return InstDesc(Kind == RK_IntegerMult, I); - case Instruction::And: - return InstDesc(Kind == RK_IntegerAnd, I); - case Instruction::Or: - return InstDesc(Kind == RK_IntegerOr, I); - case Instruction::Xor: - return InstDesc(Kind == RK_IntegerXor, I); - case Instruction::FMul: - return InstDesc(Kind == RK_FloatMult, I, UAI); - case Instruction::FSub: - case Instruction::FAdd: - return InstDesc(Kind == RK_FloatAdd, I, UAI); - case Instruction::FCmp: - case Instruction::ICmp: - case Instruction::Select: - if (Kind != RK_IntegerMinMax && - (!HasFunNoNaNAttr || Kind != RK_FloatMinMax)) - return InstDesc(false, I); - return isMinMaxSelectCmpPattern(I, Prev); - } -} - -bool RecurrenceDescriptor::hasMultipleUsesOf( - Instruction *I, SmallPtrSetImpl<Instruction *> &Insts) { - unsigned NumUses = 0; - for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; - ++Use) { - if (Insts.count(dyn_cast<Instruction>(*Use))) - ++NumUses; - if (NumUses > 1) - return true; - } - - return false; -} -bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, - RecurrenceDescriptor &RedDes, - DemandedBits *DB, AssumptionCache *AC, - DominatorTree *DT) { - - BasicBlock *Header = TheLoop->getHeader(); - Function &F = *Header->getParent(); - bool HasFunNoNaNAttr = - F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; - - if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes, - DB, AC, DT)) { - LLVM_DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); - return true; - } - if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB, - AC, DT)) { - LLVM_DEBUG(dbgs() << "Found an float MINMAX reduction PHI." << *Phi - << "\n"); - return true; - } - // Not a reduction of known type. 
- return false; -} - -bool RecurrenceDescriptor::isFirstOrderRecurrence( - PHINode *Phi, Loop *TheLoop, - DenseMap<Instruction *, Instruction *> &SinkAfter, DominatorTree *DT) { - - // Ensure the phi node is in the loop header and has two incoming values. - if (Phi->getParent() != TheLoop->getHeader() || - Phi->getNumIncomingValues() != 2) - return false; - - // Ensure the loop has a preheader and a single latch block. The loop - // vectorizer will need the latch to set up the next iteration of the loop. - auto *Preheader = TheLoop->getLoopPreheader(); - auto *Latch = TheLoop->getLoopLatch(); - if (!Preheader || !Latch) - return false; - - // Ensure the phi node's incoming blocks are the loop preheader and latch. - if (Phi->getBasicBlockIndex(Preheader) < 0 || - Phi->getBasicBlockIndex(Latch) < 0) - return false; - - // Get the previous value. The previous value comes from the latch edge while - // the initial value comes form the preheader edge. - auto *Previous = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch)); - if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous) || - SinkAfter.count(Previous)) // Cannot rely on dominance due to motion. - return false; - - // Ensure every user of the phi node is dominated by the previous value. - // The dominance requirement ensures the loop vectorizer will not need to - // vectorize the initial value prior to the first iteration of the loop. - // TODO: Consider extending this sinking to handle other kinds of instructions - // and expressions, beyond sinking a single cast past Previous. - if (Phi->hasOneUse()) { - auto *I = Phi->user_back(); - if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() && - DT->dominates(Previous, I->user_back())) { - if (!DT->dominates(Previous, I)) // Otherwise we're good w/o sinking. - SinkAfter[I] = Previous; - return true; - } - } - - for (User *U : Phi->users()) - if (auto *I = dyn_cast<Instruction>(U)) { - if (!DT->dominates(Previous, I)) - return false; - } - - return true; -} - -/// This function returns the identity element (or neutral element) for -/// the operation K. -Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurrenceKind K, - Type *Tp) { - switch (K) { - case RK_IntegerXor: - case RK_IntegerAdd: - case RK_IntegerOr: - // Adding, Xoring, Oring zero to a number does not change it. - return ConstantInt::get(Tp, 0); - case RK_IntegerMult: - // Multiplying a number by 1 does not change it. - return ConstantInt::get(Tp, 1); - case RK_IntegerAnd: - // AND-ing a number with an all-1 value does not change it. - return ConstantInt::get(Tp, -1, true); - case RK_FloatMult: - // Multiplying a number by 1 does not change it. - return ConstantFP::get(Tp, 1.0L); - case RK_FloatAdd: - // Adding zero to a number does not change it. - return ConstantFP::get(Tp, 0.0L); - default: - llvm_unreachable("Unknown recurrence kind"); - } -} - -/// This function translates the recurrence kind to an LLVM binary operator. 
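// For illustration only: how the identity values returned by
// getRecurrenceIdentity above are used, written as a scalar model of a
// vectorized reduction (all names here are hypothetical). The accumulator
// lanes start at the identity of the operation so the extra lanes do not
// change the result, and the lanes are folded together after the loop.
static int sumByFourLanes(const int *a, int n) {
  int lane[4] = {0, 0, 0, 0};          // identity of integer add is 0
  int i = 0;
  for (; i + 3 < n; i += 4)
    for (int l = 0; l < 4; ++l)
      lane[l] += a[i + l];
  int s = (lane[0] + lane[1]) + (lane[2] + lane[3]);  // horizontal reduce
  for (; i < n; ++i)                   // scalar remainder
    s += a[i];
  return s;
}
// For a multiply reduction the lanes would start at 1, and for an 'and'
// reduction at all-ones, matching the identities returned above.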
-unsigned RecurrenceDescriptor::getRecurrenceBinOp(RecurrenceKind Kind) { - switch (Kind) { - case RK_IntegerAdd: - return Instruction::Add; - case RK_IntegerMult: - return Instruction::Mul; - case RK_IntegerOr: - return Instruction::Or; - case RK_IntegerAnd: - return Instruction::And; - case RK_IntegerXor: - return Instruction::Xor; - case RK_FloatMult: - return Instruction::FMul; - case RK_FloatAdd: - return Instruction::FAdd; - case RK_IntegerMinMax: - return Instruction::ICmp; - case RK_FloatMinMax: - return Instruction::FCmp; - default: - llvm_unreachable("Unknown recurrence operation"); - } -} - -Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder, - MinMaxRecurrenceKind RK, - Value *Left, Value *Right) { - CmpInst::Predicate P = CmpInst::ICMP_NE; - switch (RK) { - default: - llvm_unreachable("Unknown min/max recurrence kind"); - case MRK_UIntMin: - P = CmpInst::ICMP_ULT; - break; - case MRK_UIntMax: - P = CmpInst::ICMP_UGT; - break; - case MRK_SIntMin: - P = CmpInst::ICMP_SLT; - break; - case MRK_SIntMax: - P = CmpInst::ICMP_SGT; - break; - case MRK_FloatMin: - P = CmpInst::FCMP_OLT; - break; - case MRK_FloatMax: - P = CmpInst::FCMP_OGT; - break; - } - - // We only match FP sequences that are 'fast', so we can unconditionally - // set it on any generated instructions. - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - FastMathFlags FMF; - FMF.setFast(); - Builder.setFastMathFlags(FMF); - - Value *Cmp; - if (RK == MRK_FloatMin || RK == MRK_FloatMax) - Cmp = Builder.CreateFCmp(P, Left, Right, "rdx.minmax.cmp"); - else - Cmp = Builder.CreateICmp(P, Left, Right, "rdx.minmax.cmp"); - - Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); - return Select; -} - -InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, - const SCEV *Step, BinaryOperator *BOp, - SmallVectorImpl<Instruction *> *Casts) - : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) { - assert(IK != IK_NoInduction && "Not an induction"); - - // Start value type should match the induction kind and the value - // itself should not be null. - assert(StartValue && "StartValue is null"); - assert((IK != IK_PtrInduction || StartValue->getType()->isPointerTy()) && - "StartValue is not a pointer for pointer induction"); - assert((IK != IK_IntInduction || StartValue->getType()->isIntegerTy()) && - "StartValue is not an integer for integer induction"); - - // Check the Step Value. It should be non-zero integer value. 
- assert((!getConstIntStepValue() || !getConstIntStepValue()->isZero()) && - "Step value is zero"); - - assert((IK != IK_PtrInduction || getConstIntStepValue()) && - "Step value should be constant for pointer induction"); - assert((IK == IK_FpInduction || Step->getType()->isIntegerTy()) && - "StepValue is not an integer"); - - assert((IK != IK_FpInduction || Step->getType()->isFloatingPointTy()) && - "StepValue is not FP for FpInduction"); - assert((IK != IK_FpInduction || (InductionBinOp && - (InductionBinOp->getOpcode() == Instruction::FAdd || - InductionBinOp->getOpcode() == Instruction::FSub))) && - "Binary opcode should be specified for FP induction"); - - if (Casts) { - for (auto &Inst : *Casts) { - RedundantCasts.push_back(Inst); - } - } -} - -int InductionDescriptor::getConsecutiveDirection() const { - ConstantInt *ConstStep = getConstIntStepValue(); - if (ConstStep && (ConstStep->isOne() || ConstStep->isMinusOne())) - return ConstStep->getSExtValue(); - return 0; -} - -ConstantInt *InductionDescriptor::getConstIntStepValue() const { - if (isa<SCEVConstant>(Step)) - return dyn_cast<ConstantInt>(cast<SCEVConstant>(Step)->getValue()); - return nullptr; -} - -Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index, - ScalarEvolution *SE, - const DataLayout& DL) const { - - SCEVExpander Exp(*SE, DL, "induction"); - assert(Index->getType() == Step->getType() && - "Index type does not match StepValue type"); - switch (IK) { - case IK_IntInduction: { - assert(Index->getType() == StartValue->getType() && - "Index type does not match StartValue type"); - - // FIXME: Theoretically, we can call getAddExpr() of ScalarEvolution - // and calculate (Start + Index * Step) for all cases, without - // special handling for "isOne" and "isMinusOne". - // But in the real life the result code getting worse. We mix SCEV - // expressions and ADD/SUB operations and receive redundant - // intermediate values being calculated in different ways and - // Instcombine is unable to reduce them all. - - if (getConstIntStepValue() && - getConstIntStepValue()->isMinusOne()) - return B.CreateSub(StartValue, Index); - if (getConstIntStepValue() && - getConstIntStepValue()->isOne()) - return B.CreateAdd(StartValue, Index); - const SCEV *S = SE->getAddExpr(SE->getSCEV(StartValue), - SE->getMulExpr(Step, SE->getSCEV(Index))); - return Exp.expandCodeFor(S, StartValue->getType(), &*B.GetInsertPoint()); - } - case IK_PtrInduction: { - assert(isa<SCEVConstant>(Step) && - "Expected constant step for pointer induction"); - const SCEV *S = SE->getMulExpr(SE->getSCEV(Index), Step); - Index = Exp.expandCodeFor(S, Index->getType(), &*B.GetInsertPoint()); - return B.CreateGEP(nullptr, StartValue, Index); - } - case IK_FpInduction: { - assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value"); - assert(InductionBinOp && - (InductionBinOp->getOpcode() == Instruction::FAdd || - InductionBinOp->getOpcode() == Instruction::FSub) && - "Original bin op should be defined for FP induction"); - - Value *StepValue = cast<SCEVUnknown>(Step)->getValue(); - - // Floating point operations had to be 'fast' to enable the induction. - FastMathFlags Flags; - Flags.setFast(); - - Value *MulExp = B.CreateFMul(StepValue, Index); - if (isa<Instruction>(MulExp)) - // We have to check, the MulExp may be a constant. 
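// For illustration only: what transform() above computes, written as plain C++
// (hypothetical helpers). The value of the induction at iteration Index is
// Start + Index * Step; for a pointer induction the step counts elements, so
// the expansion is a GEP.
static int intIVAt(int Start, int Step, int Index) {
  return Start + Index * Step;          // IK_IntInduction
}
static const int *ptrIVAt(const int *Start, int Step, int Index) {
  return Start + Index * Step;          // IK_PtrInduction: pointer arithmetic
}
// The FP case is the same formula built with the recorded fadd/fsub operation,
// and the generated operations are marked 'fast', as in the code above.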
- cast<Instruction>(MulExp)->setFastMathFlags(Flags); - - Value *BOp = B.CreateBinOp(InductionBinOp->getOpcode() , StartValue, - MulExp, "induction"); - if (isa<Instruction>(BOp)) - cast<Instruction>(BOp)->setFastMathFlags(Flags); - - return BOp; - } - case IK_NoInduction: - return nullptr; - } - llvm_unreachable("invalid enum"); -} - -bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop, - ScalarEvolution *SE, - InductionDescriptor &D) { - - // Here we only handle FP induction variables. - assert(Phi->getType()->isFloatingPointTy() && "Unexpected Phi type"); - - if (TheLoop->getHeader() != Phi->getParent()) - return false; - - // The loop may have multiple entrances or multiple exits; we can analyze - // this phi if it has a unique entry value and a unique backedge value. - if (Phi->getNumIncomingValues() != 2) - return false; - Value *BEValue = nullptr, *StartValue = nullptr; - if (TheLoop->contains(Phi->getIncomingBlock(0))) { - BEValue = Phi->getIncomingValue(0); - StartValue = Phi->getIncomingValue(1); - } else { - assert(TheLoop->contains(Phi->getIncomingBlock(1)) && - "Unexpected Phi node in the loop"); - BEValue = Phi->getIncomingValue(1); - StartValue = Phi->getIncomingValue(0); - } - - BinaryOperator *BOp = dyn_cast<BinaryOperator>(BEValue); - if (!BOp) - return false; - - Value *Addend = nullptr; - if (BOp->getOpcode() == Instruction::FAdd) { - if (BOp->getOperand(0) == Phi) - Addend = BOp->getOperand(1); - else if (BOp->getOperand(1) == Phi) - Addend = BOp->getOperand(0); - } else if (BOp->getOpcode() == Instruction::FSub) - if (BOp->getOperand(0) == Phi) - Addend = BOp->getOperand(1); - - if (!Addend) - return false; - - // The addend should be loop invariant - if (auto *I = dyn_cast<Instruction>(Addend)) - if (TheLoop->contains(I)) - return false; - - // FP Step has unknown SCEV - const SCEV *Step = SE->getUnknown(Addend); - D = InductionDescriptor(StartValue, IK_FpInduction, Step, BOp); - return true; -} - -/// This function is called when we suspect that the update-chain of a phi node -/// (whose symbolic SCEV expression sin \p PhiScev) contains redundant casts, -/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime -/// predicate P under which the SCEV expression for the phi can be the -/// AddRecurrence \p AR; See createAddRecFromPHIWithCast). We want to find the -/// cast instructions that are involved in the update-chain of this induction. -/// A caller that adds the required runtime predicate can be free to drop these -/// cast instructions, and compute the phi using \p AR (instead of some scev -/// expression with casts). -/// -/// For example, without a predicate the scev expression can take the following -/// form: -/// (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy) -/// -/// It corresponds to the following IR sequence: -/// %for.body: -/// %x = phi i64 [ 0, %ph ], [ %add, %for.body ] -/// %casted_phi = "ExtTrunc i64 %x" -/// %add = add i64 %casted_phi, %step -/// -/// where %x is given in \p PN, -/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate, -/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of -/// several forms, for example, such as: -/// ExtTrunc1: %casted_phi = and %x, 2^n-1 -/// or: -/// ExtTrunc2: %t = shl %x, m -/// %casted_phi = ashr %t, m -/// -/// If we are able to find such sequence, we return the instructions -/// we found, namely %casted_phi and the instructions on its use-def chain up -/// to the phi (not including the phi). 
-static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE, - const SCEVUnknown *PhiScev, - const SCEVAddRecExpr *AR, - SmallVectorImpl<Instruction *> &CastInsts) { - - assert(CastInsts.empty() && "CastInsts is expected to be empty."); - auto *PN = cast<PHINode>(PhiScev->getValue()); - assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression"); - const Loop *L = AR->getLoop(); - - // Find any cast instructions that participate in the def-use chain of - // PhiScev in the loop. - // FORNOW/TODO: We currently expect the def-use chain to include only - // two-operand instructions, where one of the operands is an invariant. - // createAddRecFromPHIWithCasts() currently does not support anything more - // involved than that, so we keep the search simple. This can be - // extended/generalized as needed. - - auto getDef = [&](const Value *Val) -> Value * { - const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val); - if (!BinOp) - return nullptr; - Value *Op0 = BinOp->getOperand(0); - Value *Op1 = BinOp->getOperand(1); - Value *Def = nullptr; - if (L->isLoopInvariant(Op0)) - Def = Op1; - else if (L->isLoopInvariant(Op1)) - Def = Op0; - return Def; - }; - - // Look for the instruction that defines the induction via the - // loop backedge. - BasicBlock *Latch = L->getLoopLatch(); - if (!Latch) - return false; - Value *Val = PN->getIncomingValueForBlock(Latch); - if (!Val) - return false; - - // Follow the def-use chain until the induction phi is reached. - // If on the way we encounter a Value that has the same SCEV Expr as the - // phi node, we can consider the instructions we visit from that point - // as part of the cast-sequence that can be ignored. - bool InCastSequence = false; - auto *Inst = dyn_cast<Instruction>(Val); - while (Val != PN) { - // If we encountered a phi node other than PN, or if we left the loop, - // we bail out. - if (!Inst || !L->contains(Inst)) { - return false; - } - auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val)); - if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR)) - InCastSequence = true; - if (InCastSequence) { - // Only the last instruction in the cast sequence is expected to have - // uses outside the induction def-use chain. - if (!CastInsts.empty()) - if (!Inst->hasOneUse()) - return false; - CastInsts.push_back(Inst); - } - Val = getDef(Val); - if (!Val) - return false; - Inst = dyn_cast<Instruction>(Val); - } - - return InCastSequence; -} - -bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, - PredicatedScalarEvolution &PSE, - InductionDescriptor &D, - bool Assume) { - Type *PhiTy = Phi->getType(); - - // Handle integer and pointer inductions variables. - // Now we handle also FP induction but not trying to make a - // recurrent expression from the PHI node in-place. - - if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy() && - !PhiTy->isFloatTy() && !PhiTy->isDoubleTy() && !PhiTy->isHalfTy()) - return false; - - if (PhiTy->isFloatingPointTy()) - return isFPInductionPHI(Phi, TheLoop, PSE.getSE(), D); - - const SCEV *PhiScev = PSE.getSCEV(Phi); - const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); - - // We need this expression to be an AddRecExpr. 
- if (Assume && !AR) - AR = PSE.getAsAddRec(Phi); - - if (!AR) { - LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); - return false; - } - - // Record any Cast instructions that participate in the induction update - const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev); - // If we started from an UnknownSCEV, and managed to build an addRecurrence - // only after enabling Assume with PSCEV, this means we may have encountered - // cast instructions that required adding a runtime check in order to - // guarantee the correctness of the AddRecurence respresentation of the - // induction. - if (PhiScev != AR && SymbolicPhi) { - SmallVector<Instruction *, 2> Casts; - if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts)) - return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts); - } - - return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR); -} - -bool InductionDescriptor::isInductionPHI( - PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE, - InductionDescriptor &D, const SCEV *Expr, - SmallVectorImpl<Instruction *> *CastsToIgnore) { - Type *PhiTy = Phi->getType(); - // We only handle integer and pointer inductions variables. - if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy()) - return false; - - // Check that the PHI is consecutive. - const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi); - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); - - if (!AR) { - LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); - return false; - } - - if (AR->getLoop() != TheLoop) { - // FIXME: We should treat this as a uniform. Unfortunately, we - // don't currently know how to handled uniform PHIs. - LLVM_DEBUG( - dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n"); - return false; - } - - Value *StartValue = - Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader()); - const SCEV *Step = AR->getStepRecurrence(*SE); - // Calculate the pointer stride and check if it is consecutive. - // The stride may be a constant or a loop invariant integer value. - const SCEVConstant *ConstStep = dyn_cast<SCEVConstant>(Step); - if (!ConstStep && !SE->isLoopInvariant(Step, TheLoop)) - return false; - - if (PhiTy->isIntegerTy()) { - D = InductionDescriptor(StartValue, IK_IntInduction, Step, /*BOp=*/ nullptr, - CastsToIgnore); - return true; - } - - assert(PhiTy->isPointerTy() && "The PHI must be a pointer"); - // Pointer induction should be a constant. - if (!ConstStep) - return false; - - ConstantInt *CV = ConstStep->getValue(); - Type *PointerElementType = PhiTy->getPointerElementType(); - // The pointer stride cannot be determined if the pointer element type is not - // sized. 
- if (!PointerElementType->isSized()) - return false; - - const DataLayout &DL = Phi->getModule()->getDataLayout(); - int64_t Size = static_cast<int64_t>(DL.getTypeAllocSize(PointerElementType)); - if (!Size) - return false; - - int64_t CVSize = CV->getSExtValue(); - if (CVSize % Size) - return false; - auto *StepValue = SE->getConstant(CV->getType(), CVSize / Size, - true /* signed */); - D = InductionDescriptor(StartValue, IK_PtrInduction, StepValue); - return true; -} +static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced"; bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, bool PreserveLCSSA) { @@ -1173,7 +79,7 @@ bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI, return false; auto *NewExitBB = SplitBlockPredecessors( - BB, InLoopPredecessors, ".loopexit", DT, LI, PreserveLCSSA); + BB, InLoopPredecessors, ".loopexit", DT, LI, nullptr, PreserveLCSSA); if (!NewExitBB) LLVM_DEBUG( @@ -1286,37 +192,231 @@ void llvm::initializeLoopPassPass(PassRegistry &Registry) { /// If it has a value (e.g. {"llvm.distribute", 1} return the value as an /// operand or null otherwise. If the string metadata is not found return /// Optional's not-a-value. -Optional<const MDOperand *> llvm::findStringMetadataForLoop(Loop *TheLoop, +Optional<const MDOperand *> llvm::findStringMetadataForLoop(const Loop *TheLoop, StringRef Name) { - MDNode *LoopID = TheLoop->getLoopID(); - // Return none if LoopID is false. - if (!LoopID) + MDNode *MD = findOptionMDForLoop(TheLoop, Name); + if (!MD) return None; + switch (MD->getNumOperands()) { + case 1: + return nullptr; + case 2: + return &MD->getOperand(1); + default: + llvm_unreachable("loop metadata has 0 or 1 operand"); + } +} - // First operand should refer to the loop id itself. - assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); - assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); +static Optional<bool> getOptionalBoolLoopAttribute(const Loop *TheLoop, + StringRef Name) { + MDNode *MD = findOptionMDForLoop(TheLoop, Name); + if (!MD) + return None; + switch (MD->getNumOperands()) { + case 1: + // When the value is absent it is interpreted as 'attribute set'. 
+ return true; + case 2: + return mdconst::extract_or_null<ConstantInt>(MD->getOperand(1).get()); + } + llvm_unreachable("unexpected number of options"); +} - // Iterate over LoopID operands and look for MDString Metadata - for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); - if (!MD) - continue; - MDString *S = dyn_cast<MDString>(MD->getOperand(0)); - if (!S) +static bool getBooleanLoopAttribute(const Loop *TheLoop, StringRef Name) { + return getOptionalBoolLoopAttribute(TheLoop, Name).getValueOr(false); +} + +llvm::Optional<int> llvm::getOptionalIntLoopAttribute(Loop *TheLoop, + StringRef Name) { + const MDOperand *AttrMD = + findStringMetadataForLoop(TheLoop, Name).getValueOr(nullptr); + if (!AttrMD) + return None; + + ConstantInt *IntMD = mdconst::extract_or_null<ConstantInt>(AttrMD->get()); + if (!IntMD) + return None; + + return IntMD->getSExtValue(); +} + +Optional<MDNode *> llvm::makeFollowupLoopID( + MDNode *OrigLoopID, ArrayRef<StringRef> FollowupOptions, + const char *InheritOptionsExceptPrefix, bool AlwaysNew) { + if (!OrigLoopID) { + if (AlwaysNew) + return nullptr; + return None; + } + + assert(OrigLoopID->getOperand(0) == OrigLoopID); + + bool InheritAllAttrs = !InheritOptionsExceptPrefix; + bool InheritSomeAttrs = + InheritOptionsExceptPrefix && InheritOptionsExceptPrefix[0] != '\0'; + SmallVector<Metadata *, 8> MDs; + MDs.push_back(nullptr); + + bool Changed = false; + if (InheritAllAttrs || InheritSomeAttrs) { + for (const MDOperand &Existing : drop_begin(OrigLoopID->operands(), 1)) { + MDNode *Op = cast<MDNode>(Existing.get()); + + auto InheritThisAttribute = [InheritSomeAttrs, + InheritOptionsExceptPrefix](MDNode *Op) { + if (!InheritSomeAttrs) + return false; + + // Skip malformatted attribute metadata nodes. + if (Op->getNumOperands() == 0) + return true; + Metadata *NameMD = Op->getOperand(0).get(); + if (!isa<MDString>(NameMD)) + return true; + StringRef AttrName = cast<MDString>(NameMD)->getString(); + + // Do not inherit excluded attributes. + return !AttrName.startswith(InheritOptionsExceptPrefix); + }; + + if (InheritThisAttribute(Op)) + MDs.push_back(Op); + else + Changed = true; + } + } else { + // Modified if we dropped at least one attribute. + Changed = OrigLoopID->getNumOperands() > 1; + } + + bool HasAnyFollowup = false; + for (StringRef OptionName : FollowupOptions) { + MDNode *FollowupNode = findOptionMDForLoopID(OrigLoopID, OptionName); + if (!FollowupNode) continue; - // Return true if MDString holds expected MetaData. - if (Name.equals(S->getString())) - switch (MD->getNumOperands()) { - case 1: - return nullptr; - case 2: - return &MD->getOperand(1); - default: - llvm_unreachable("loop metadata has 0 or 1 operand"); - } + + HasAnyFollowup = true; + for (const MDOperand &Option : drop_begin(FollowupNode->operands(), 1)) { + MDs.push_back(Option.get()); + Changed = true; + } } - return None; + + // Attributes of the followup loop not specified explicity, so signal to the + // transformation pass to add suitable attributes. + if (!AlwaysNew && !HasAnyFollowup) + return None; + + // If no attributes were added or remove, the previous loop Id can be reused. + if (!AlwaysNew && !Changed) + return OrigLoopID; + + // No attributes is equivalent to having no !llvm.loop metadata at all. + if (MDs.size() == 1) + return nullptr; + + // Build the new loop ID. 
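
The option helpers in this hunk (findStringMetadataForLoop, getOptionalBoolLoopAttribute, getOptionalIntLoopAttribute) all read entries out of a loop's !llvm.loop metadata. A minimal sketch of a caller, assuming the declarations come from llvm/Transforms/Utils/LoopUtils.h; the wrapper name and default value below are invented for illustration:

    // The IR being queried looks roughly like:
    //   br i1 %c, label %body, label %exit, !llvm.loop !0
    //   !0 = distinct !{!0, !1, !2}
    //   !1 = !{!"llvm.loop.unroll.count", i32 4}
    //   !2 = !{!"llvm.loop.unroll.disable"}
    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/Transforms/Utils/LoopUtils.h"
    using namespace llvm;

    static int getUnrollCountOr(Loop *L, int Default) {
      Optional<int> Count =
          getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count");
      // None means the option is absent or its operand is not an integer.
      return Count.hasValue() ? Count.getValue() : Default;
    }
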
+ MDTuple *FollowupLoopID = MDNode::get(OrigLoopID->getContext(), MDs); + FollowupLoopID->replaceOperandWith(0, FollowupLoopID); + return FollowupLoopID; +} + +bool llvm::hasDisableAllTransformsHint(const Loop *L) { + return getBooleanLoopAttribute(L, LLVMLoopDisableNonforced); +} + +TransformationMode llvm::hasUnrollTransformation(Loop *L) { + if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) + return TM_SuppressedByUser; + + Optional<int> Count = + getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count"); + if (Count.hasValue()) + return Count.getValue() == 1 ? TM_SuppressedByUser : TM_ForcedByUser; + + if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable")) + return TM_ForcedByUser; + + if (getBooleanLoopAttribute(L, "llvm.loop.unroll.full")) + return TM_ForcedByUser; + + if (hasDisableAllTransformsHint(L)) + return TM_Disable; + + return TM_Unspecified; +} + +TransformationMode llvm::hasUnrollAndJamTransformation(Loop *L) { + if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.disable")) + return TM_SuppressedByUser; + + Optional<int> Count = + getOptionalIntLoopAttribute(L, "llvm.loop.unroll_and_jam.count"); + if (Count.hasValue()) + return Count.getValue() == 1 ? TM_SuppressedByUser : TM_ForcedByUser; + + if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.enable")) + return TM_ForcedByUser; + + if (hasDisableAllTransformsHint(L)) + return TM_Disable; + + return TM_Unspecified; +} + +TransformationMode llvm::hasVectorizeTransformation(Loop *L) { + Optional<bool> Enable = + getOptionalBoolLoopAttribute(L, "llvm.loop.vectorize.enable"); + + if (Enable == false) + return TM_SuppressedByUser; + + Optional<int> VectorizeWidth = + getOptionalIntLoopAttribute(L, "llvm.loop.vectorize.width"); + Optional<int> InterleaveCount = + getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count"); + + if (Enable == true) { + // 'Forcing' vector width and interleave count to one effectively disables + // this tranformation. + if (VectorizeWidth == 1 && InterleaveCount == 1) + return TM_SuppressedByUser; + return TM_ForcedByUser; + } + + if (getBooleanLoopAttribute(L, "llvm.loop.isvectorized")) + return TM_Disable; + + if (VectorizeWidth == 1 && InterleaveCount == 1) + return TM_Disable; + + if (VectorizeWidth > 1 || InterleaveCount > 1) + return TM_Enable; + + if (hasDisableAllTransformsHint(L)) + return TM_Disable; + + return TM_Unspecified; +} + +TransformationMode llvm::hasDistributeTransformation(Loop *L) { + if (getBooleanLoopAttribute(L, "llvm.loop.distribute.enable")) + return TM_ForcedByUser; + + if (hasDisableAllTransformsHint(L)) + return TM_Disable; + + return TM_Unspecified; +} + +TransformationMode llvm::hasLICMVersioningTransformation(Loop *L) { + if (getBooleanLoopAttribute(L, "llvm.loop.licm_versioning.disable")) + return TM_SuppressedByUser; + + if (hasDisableAllTransformsHint(L)) + return TM_Disable; + + return TM_Unspecified; } /// Does a BFS from a given node to all of its children inside a given loop. @@ -1425,14 +525,19 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, // Remove the old branch. Preheader->getTerminator()->eraseFromParent(); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); if (DT) { // Update the dominator tree by informing it about the new edge from the // preheader to the exit. - DT->insertEdge(Preheader, ExitBlock); + DTU.insertEdge(Preheader, ExitBlock); // Inform the dominator tree about the removed edge. 
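
The hasUnrollTransformation/hasVectorizeTransformation family added above folds the user-facing loop attributes into a single TransformationMode answer. A sketch of how a transform might consult it before spending any compile time; the wrapper name is a placeholder, the enumerators are the ones used in this hunk:

    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/Transforms/Utils/LoopUtils.h"
    using namespace llvm;

    static bool unrollIsForced(Loop *L) {
      switch (hasUnrollTransformation(L)) {
      case TM_ForcedByUser:     // llvm.loop.unroll.enable/.full or count > 1
        return true;
      case TM_SuppressedByUser: // llvm.loop.unroll.disable or count == 1
      case TM_Disable:          // llvm.loop.disable_nonforced
      default:                  // TM_Unspecified: nothing was requested
        return false;
      }
    }
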
- DT->deleteEdge(Preheader, L->getHeader()); + DTU.deleteEdge(Preheader, L->getHeader()); } + // Use a map to unique and a vector to guarantee deterministic ordering. + llvm::SmallDenseSet<std::pair<DIVariable *, DIExpression *>, 4> DeadDebugSet; + llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst; + // Given LCSSA form is satisfied, we should not have users of instructions // within the dead loop outside of the loop. However, LCSSA doesn't take // unreachable uses into account. We handle them here. @@ -1457,8 +562,27 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, "Unexpected user in reachable block"); U.set(Undef); } + auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I); + if (!DVI) + continue; + auto Key = DeadDebugSet.find({DVI->getVariable(), DVI->getExpression()}); + if (Key != DeadDebugSet.end()) + continue; + DeadDebugSet.insert({DVI->getVariable(), DVI->getExpression()}); + DeadDebugInst.push_back(DVI); } + // After the loop has been deleted all the values defined and modified + // inside the loop are going to be unavailable. + // Since debug values in the loop have been deleted, inserting an undef + // dbg.value truncates the range of any dbg.value before the loop where the + // loop used to be. This is particularly important for constant values. + DIBuilder DIB(*ExitBlock->getModule()); + for (auto *DVI : DeadDebugInst) + DIB.insertDbgValueIntrinsic( + UndefValue::get(Builder.getInt32Ty()), DVI->getVariable(), + DVI->getExpression(), DVI->getDebugLoc(), ExitBlock->getFirstNonPHI()); + // Remove the block from the reference counting scheme, so that we can // delete it freely later. for (auto *Block : L->blocks()) @@ -1519,6 +643,28 @@ Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { return (FalseVal + (TrueVal / 2)) / TrueVal; } +bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, + ScalarEvolution &SE) { + Loop *OuterL = InnerLoop->getParentLoop(); + if (!OuterL) + return true; + + // Get the backedge taken count for the inner loop + BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); + const SCEV *InnerLoopBECountSC = SE.getExitCount(InnerLoop, InnerLoopLatch); + if (isa<SCEVCouldNotCompute>(InnerLoopBECountSC) || + !InnerLoopBECountSC->getType()->isIntegerTy()) + return false; + + // Get whether count is invariant to the outer loop + ScalarEvolution::LoopDisposition LD = + SE.getLoopDisposition(InnerLoopBECountSC, OuterL); + if (LD != ScalarEvolution::LoopInvariant) + return false; + + return true; +} + /// Adds a 'fast' flag to floating point operations. static Value *addFastMathFlag(Value *V) { if (isa<FPMathOperator>(V)) { @@ -1529,6 +675,51 @@ static Value *addFastMathFlag(Value *V) { return V; } +Value *llvm::createMinMaxOp(IRBuilder<> &Builder, + RecurrenceDescriptor::MinMaxRecurrenceKind RK, + Value *Left, Value *Right) { + CmpInst::Predicate P = CmpInst::ICMP_NE; + switch (RK) { + default: + llvm_unreachable("Unknown min/max recurrence kind"); + case RecurrenceDescriptor::MRK_UIntMin: + P = CmpInst::ICMP_ULT; + break; + case RecurrenceDescriptor::MRK_UIntMax: + P = CmpInst::ICMP_UGT; + break; + case RecurrenceDescriptor::MRK_SIntMin: + P = CmpInst::ICMP_SLT; + break; + case RecurrenceDescriptor::MRK_SIntMax: + P = CmpInst::ICMP_SGT; + break; + case RecurrenceDescriptor::MRK_FloatMin: + P = CmpInst::FCMP_OLT; + break; + case RecurrenceDescriptor::MRK_FloatMax: + P = CmpInst::FCMP_OGT; + break; + } + + // We only match FP sequences that are 'fast', so we can unconditionally + // set it on any generated instructions. 
+ IRBuilder<>::FastMathFlagGuard FMFG(Builder); + FastMathFlags FMF; + FMF.setFast(); + Builder.setFastMathFlags(FMF); + + Value *Cmp; + if (RK == RecurrenceDescriptor::MRK_FloatMin || + RK == RecurrenceDescriptor::MRK_FloatMax) + Cmp = Builder.CreateFCmp(P, Left, Right, "rdx.minmax.cmp"); + else + Cmp = Builder.CreateICmp(P, Left, Right, "rdx.minmax.cmp"); + + Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); + return Select; +} + // Helper to generate an ordered reduction. Value * llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src, @@ -1550,8 +741,7 @@ llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src, } else { assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid && "Invalid min/max"); - Result = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, Result, - Ext); + Result = createMinMaxOp(Builder, MinMaxKind, Result, Ext); } if (!RedOps.empty()) @@ -1594,8 +784,7 @@ llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op, } else { assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid && "Invalid min/max"); - TmpVec = RecurrenceDescriptor::createMinMaxOp(Builder, MinMaxKind, TmpVec, - Shuf); + TmpVec = createMinMaxOp(Builder, MinMaxKind, TmpVec, Shuf); } if (!RedOps.empty()) propagateIRFlags(TmpVec, RedOps); @@ -1613,7 +802,7 @@ Value *llvm::createSimpleTargetReduction( assert(isa<VectorType>(Src->getType()) && "Type must be a vector"); Value *ScalarUdf = UndefValue::get(Src->getType()->getVectorElementType()); - std::function<Value*()> BuildFunc; + std::function<Value *()> BuildFunc; using RD = RecurrenceDescriptor; RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid; // TODO: Support creating ordered reductions. @@ -1739,3 +928,39 @@ void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) { VecOp->andIRFlags(V); } } + +bool llvm::isKnownNegativeInLoop(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(S->getType()); + return SE.isAvailableAtLoopEntry(S, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SLT, S, Zero); +} + +bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + const SCEV *Zero = SE.getZero(S->getType()); + return SE.isAvailableAtLoopEntry(S, L) && + SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_SGE, S, Zero); +} + +bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE, + bool Signed) { + unsigned BitWidth = cast<IntegerType>(S->getType())->getBitWidth(); + APInt Min = Signed ? APInt::getSignedMinValue(BitWidth) : + APInt::getMinValue(BitWidth); + auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; + return SE.isAvailableAtLoopEntry(S, L) && + SE.isLoopEntryGuardedByCond(L, Predicate, S, + SE.getConstant(Min)); +} + +bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE, + bool Signed) { + unsigned BitWidth = cast<IntegerType>(S->getType())->getBitWidth(); + APInt Max = Signed ? APInt::getSignedMaxValue(BitWidth) : + APInt::getMaxValue(BitWidth); + auto Predicate = Signed ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; + return SE.isAvailableAtLoopEntry(S, L) && + SE.isLoopEntryGuardedByCond(L, Predicate, S, + SE.getConstant(Max)); +} diff --git a/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 03006ef3a2d3..661b4fa5bcb7 100644 --- a/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -301,7 +301,7 @@ static void createMemMoveLoop(Instruction *InsertBefore, // the appropriate conditional branches when the loop is built. ICmpInst *PtrCompare = new ICmpInst(InsertBefore, ICmpInst::ICMP_ULT, SrcAddr, DstAddr, "compare_src_dst"); - TerminatorInst *ThenTerm, *ElseTerm; + Instruction *ThenTerm, *ElseTerm; SplitBlockAndInsertIfThenElse(PtrCompare, InsertBefore, &ThenTerm, &ElseTerm); diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index e99ecfef19cd..d019a44fc705 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -372,7 +372,7 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), Case.getCaseSuccessor())); - llvm::sort(Cases.begin(), Cases.end(), CaseCmp()); + llvm::sort(Cases, CaseCmp()); // Merge case into clusters if (Cases.size() >= 2) { diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index ba4b7f3cc263..ae5e72ea4d30 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -174,6 +174,49 @@ std::pair<Function *, Function *> llvm::createSanitizerCtorAndInitFunctions( return std::make_pair(Ctor, InitFunction); } +std::pair<Function *, Function *> +llvm::getOrCreateSanitizerCtorAndInitFunctions( + Module &M, StringRef CtorName, StringRef InitName, + ArrayRef<Type *> InitArgTypes, ArrayRef<Value *> InitArgs, + function_ref<void(Function *, Function *)> FunctionsCreatedCallback, + StringRef VersionCheckName) { + assert(!CtorName.empty() && "Expected ctor function name"); + + if (Function *Ctor = M.getFunction(CtorName)) + // FIXME: Sink this logic into the module, similar to the handling of + // globals. This will make moving to a concurrent model much easier. 
+ if (Ctor->arg_size() == 0 || + Ctor->getReturnType() == Type::getVoidTy(M.getContext())) + return {Ctor, declareSanitizerInitFunction(M, InitName, InitArgTypes)}; + + Function *Ctor, *InitFunction; + std::tie(Ctor, InitFunction) = llvm::createSanitizerCtorAndInitFunctions( + M, CtorName, InitName, InitArgTypes, InitArgs, VersionCheckName); + FunctionsCreatedCallback(Ctor, InitFunction); + return std::make_pair(Ctor, InitFunction); +} + +Function *llvm::getOrCreateInitFunction(Module &M, StringRef Name) { + assert(!Name.empty() && "Expected init function name"); + if (Function *F = M.getFunction(Name)) { + if (F->arg_size() != 0 || + F->getReturnType() != Type::getVoidTy(M.getContext())) { + std::string Err; + raw_string_ostream Stream(Err); + Stream << "Sanitizer interface function defined with wrong type: " << *F; + report_fatal_error(Err); + } + return F; + } + Function *F = checkSanitizerInterfaceFunction(M.getOrInsertFunction( + Name, AttributeList(), Type::getVoidTy(M.getContext()))); + F->setLinkage(Function::ExternalLinkage); + + appendToGlobalCtors(M, F, 0); + + return F; +} + void llvm::filterDeadComdatFunctions( Module &M, SmallVectorImpl<Function *> &DeadComdatFunctions) { // Build a map from the comdat to the number of entries in that comdat we diff --git a/lib/Transforms/Utils/OrderedInstructions.cpp b/lib/Transforms/Utils/OrderedInstructions.cpp deleted file mode 100644 index 6d0b96f6aa8a..000000000000 --- a/lib/Transforms/Utils/OrderedInstructions.cpp +++ /dev/null @@ -1,51 +0,0 @@ -//===-- OrderedInstructions.cpp - Instruction dominance function ---------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file defines utility to check dominance relation of 2 instructions. -// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Utils/OrderedInstructions.h" -using namespace llvm; - -bool OrderedInstructions::localDominates(const Instruction *InstA, - const Instruction *InstB) const { - assert(InstA->getParent() == InstB->getParent() && - "Instructions must be in the same basic block"); - - const BasicBlock *IBB = InstA->getParent(); - auto OBB = OBBMap.find(IBB); - if (OBB == OBBMap.end()) - OBB = OBBMap.insert({IBB, make_unique<OrderedBasicBlock>(IBB)}).first; - return OBB->second->dominates(InstA, InstB); -} - -/// Given 2 instructions, use OrderedBasicBlock to check for dominance relation -/// if the instructions are in the same basic block, Otherwise, use dominator -/// tree. -bool OrderedInstructions::dominates(const Instruction *InstA, - const Instruction *InstB) const { - // Use ordered basic block to do dominance check in case the 2 instructions - // are in the same basic block. - if (InstA->getParent() == InstB->getParent()) - return localDominates(InstA, InstB); - return DT->dominates(InstA->getParent(), InstB->getParent()); -} - -bool OrderedInstructions::dfsBefore(const Instruction *InstA, - const Instruction *InstB) const { - // Use ordered basic block in case the 2 instructions are in the same basic - // block. 
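
getOrCreateSanitizerCtorAndInitFunctions above only invokes the callback when the ctor/init pair is freshly created, so an instrumentation pass can defer work such as registering the ctor. A sketch with made-up sanitizer names ("mysan.module_ctor", "__mysan_init"); only the calls visible in this hunk are assumed:

    #include "llvm/IR/Module.h"
    #include "llvm/Transforms/Utils/ModuleUtils.h"
    #include <tuple>
    using namespace llvm;

    static void ensureMySanInit(Module &M) {
      Function *Ctor, *Init;
      std::tie(Ctor, Init) = getOrCreateSanitizerCtorAndInitFunctions(
          M, "mysan.module_ctor", "__mysan_init",
          /*InitArgTypes=*/{}, /*InitArgs=*/{},
          // Runs only if the pair did not already exist in the module.
          [&M](Function *NewCtor, Function *) {
            appendToGlobalCtors(M, NewCtor, /*Priority=*/0);
          },
          /*VersionCheckName=*/StringRef());
      (void)Init; // the declared __mysan_init, callable from the ctor
    }
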
- if (InstA->getParent() == InstB->getParent()) - return localDominates(InstA, InstB); - - DomTreeNode *DA = DT->getNode(InstA->getParent()); - DomTreeNode *DB = DT->getNode(InstB->getParent()); - return DA->getDFSNumIn() < DB->getDFSNumIn(); -} diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp index 2923977b791a..585ce6b4c118 100644 --- a/lib/Transforms/Utils/PredicateInfo.cpp +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -35,7 +35,6 @@ #include "llvm/Support/DebugCounter.h" #include "llvm/Support/FormattedStream.h" #include "llvm/Transforms/Utils.h" -#include "llvm/Transforms/Utils/OrderedInstructions.h" #include <algorithm> #define DEBUG_TYPE "predicateinfo" using namespace llvm; @@ -523,7 +522,7 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, if (isa<PredicateWithEdge>(ValInfo)) { IRBuilder<> B(getBranchTerminator(ValInfo)); Function *IF = getCopyDeclaration(F.getParent(), Op->getType()); - if (IF->user_begin() == IF->user_end()) + if (empty(IF->users())) CreatedDeclarations.insert(IF); CallInst *PIC = B.CreateCall(IF, Op, Op->getName() + "." + Twine(Counter++)); @@ -535,7 +534,7 @@ Value *PredicateInfo::materializeStack(unsigned int &Counter, "Should not have gotten here without it being an assume"); IRBuilder<> B(PAssume->AssumeInst); Function *IF = getCopyDeclaration(F.getParent(), Op->getType()); - if (IF->user_begin() == IF->user_end()) + if (empty(IF->users())) CreatedDeclarations.insert(IF); CallInst *PIC = B.CreateCall(IF, Op); PredicateMap.insert({PIC, ValInfo}); @@ -570,7 +569,7 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { auto Comparator = [&](const Value *A, const Value *B) { return valueComesBefore(OI, A, B); }; - llvm::sort(OpsToRename.begin(), OpsToRename.end(), Comparator); + llvm::sort(OpsToRename, Comparator); ValueDFS_Compare Compare(OI); // Compute liveness, and rename in O(uses) per Op. for (auto *Op : OpsToRename) { diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index 86e15bbd7f22..91e4f4254b3e 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -82,8 +82,7 @@ bool llvm::isAllocaPromotable(const AllocaInst *AI) { if (SI->isVolatile()) return false; } else if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U)) { - if (II->getIntrinsicID() != Intrinsic::lifetime_start && - II->getIntrinsicID() != Intrinsic::lifetime_end) + if (!II->isLifetimeStartOrEnd()) return false; } else if (const BitCastInst *BCI = dyn_cast<BitCastInst>(U)) { if (BCI->getType() != Type::getInt8PtrTy(U->getContext(), AS)) @@ -116,7 +115,7 @@ struct AllocaInfo { bool OnlyUsedInOneBlock; Value *AllocaPointerVal; - TinyPtrVector<DbgInfoIntrinsic *> DbgDeclares; + TinyPtrVector<DbgVariableIntrinsic *> DbgDeclares; void clear() { DefiningBlocks.clear(); @@ -263,7 +262,7 @@ struct PromoteMem2Reg { /// For each alloca, we keep track of the dbg.declare intrinsic that /// describes it, if any, so that we can convert it to a dbg.value /// intrinsic if the alloca gets promoted. - SmallVector<TinyPtrVector<DbgInfoIntrinsic *>, 8> AllocaDbgDeclares; + SmallVector<TinyPtrVector<DbgVariableIntrinsic *>, 8> AllocaDbgDeclares; /// The set of basic blocks the renamer has already visited. SmallPtrSet<BasicBlock *, 16> Visited; @@ -426,7 +425,7 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, // Record debuginfo for the store and remove the declaration's // debuginfo. 
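
The switch from DbgInfoIntrinsic to DbgVariableIntrinsic in this file narrows the type to the intrinsics that actually describe a variable (dbg.declare, dbg.addr, dbg.value), excluding dbg.label. A stand-alone sketch of collecting them, with invented helper and container names:

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/Function.h"
    #include "llvm/IR/InstIterator.h"
    #include "llvm/IR/IntrinsicInst.h"
    using namespace llvm;

    static void collectVariableDbgUsers(
        Function &F, SmallVectorImpl<DbgVariableIntrinsic *> &Out) {
      for (Instruction &I : instructions(F))
        if (auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I))
          Out.push_back(DVI); // a dbg.label would not match this cast
    }
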
- for (DbgInfoIntrinsic *DII : Info.DbgDeclares) { + for (DbgVariableIntrinsic *DII : Info.DbgDeclares) { DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DII, Info.OnlyStore, DIB); DII->eraseFromParent(); @@ -477,7 +476,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, // Sort the stores by their index, making it efficient to do a lookup with a // binary search. - llvm::sort(StoresByIndex.begin(), StoresByIndex.end(), less_first()); + llvm::sort(StoresByIndex, less_first()); // Walk all of the loads from this alloca, replacing them with the nearest // store above them, if any. @@ -527,7 +526,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, while (!AI->use_empty()) { StoreInst *SI = cast<StoreInst>(AI->user_back()); // Record debuginfo for the store before removing it. - for (DbgInfoIntrinsic *DII : Info.DbgDeclares) { + for (DbgVariableIntrinsic *DII : Info.DbgDeclares) { DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); ConvertDebugDeclareToDebugValue(DII, SI, DIB); } @@ -539,7 +538,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LBI.deleteValue(AI); // The alloca's debuginfo can be removed as well. - for (DbgInfoIntrinsic *DII : Info.DbgDeclares) { + for (DbgVariableIntrinsic *DII : Info.DbgDeclares) { DII->eraseFromParent(); LBI.deleteValue(DII); } @@ -638,10 +637,9 @@ void PromoteMem2Reg::run() { SmallVector<BasicBlock *, 32> PHIBlocks; IDF.calculate(PHIBlocks); if (PHIBlocks.size() > 1) - llvm::sort(PHIBlocks.begin(), PHIBlocks.end(), - [this](BasicBlock *A, BasicBlock *B) { - return BBNumbers.lookup(A) < BBNumbers.lookup(B); - }); + llvm::sort(PHIBlocks, [this](BasicBlock *A, BasicBlock *B) { + return BBNumbers.lookup(A) < BBNumbers.lookup(B); + }); unsigned CurrentVersion = 0; for (BasicBlock *BB : PHIBlocks) @@ -752,14 +750,18 @@ void PromoteMem2Reg::run() { // Ok, now we know that all of the PHI nodes are missing entries for some // basic blocks. Start by sorting the incoming predecessors for efficient // access. - llvm::sort(Preds.begin(), Preds.end()); + auto CompareBBNumbers = [this](BasicBlock *A, BasicBlock *B) { + return BBNumbers.lookup(A) < BBNumbers.lookup(B); + }; + llvm::sort(Preds, CompareBBNumbers); // Now we loop through all BB's which have entries in SomePHI and remove // them from the Preds list. for (unsigned i = 0, e = SomePHI->getNumIncomingValues(); i != e; ++i) { // Do a log(n) search of the Preds list for the entry we want. SmallVectorImpl<BasicBlock *>::iterator EntIt = std::lower_bound( - Preds.begin(), Preds.end(), SomePHI->getIncomingBlock(i)); + Preds.begin(), Preds.end(), SomePHI->getIncomingBlock(i), + CompareBBNumbers); assert(EntIt != Preds.end() && *EntIt == SomePHI->getIncomingBlock(i) && "PHI node has entry for a block which is not a predecessor!"); @@ -932,7 +934,7 @@ NextIteration: // The currently active variable for this block is now the PHI. IncomingVals[AllocaNo] = APN; - for (DbgInfoIntrinsic *DII : AllocaDbgDeclares[AllocaNo]) + for (DbgVariableIntrinsic *DII : AllocaDbgDeclares[AllocaNo]) ConvertDebugDeclareToDebugValue(DII, APN, DIB); // Get the next phi node. 
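
The Preds change above replaces a sort by raw pointer value with a sort by block number (deterministic across runs), and the accompanying std::lower_bound must use the very same comparator, otherwise the binary search precondition is broken. A self-contained sketch of the idiom; the function and parameter names are illustrative:

    #include "llvm/ADT/DenseMap.h"
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/BasicBlock.h"
    #include <algorithm>
    using namespace llvm;

    static bool containsPred(SmallVectorImpl<BasicBlock *> &Preds,
                             BasicBlock *Wanted,
                             const DenseMap<BasicBlock *, unsigned> &BBNumbers) {
      auto ByNumber = [&](BasicBlock *A, BasicBlock *B) {
        return BBNumbers.lookup(A) < BBNumbers.lookup(B);
      };
      llvm::sort(Preds, ByNumber); // deterministic order, not pointer order
      auto It = std::lower_bound(Preds.begin(), Preds.end(), Wanted,
                                 ByNumber); // same predicate as the sort
      return It != Preds.end() && *It == Wanted;
    }
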
@@ -951,7 +953,7 @@ NextIteration: if (!Visited.insert(BB).second) return; - for (BasicBlock::iterator II = BB->begin(); !isa<TerminatorInst>(II);) { + for (BasicBlock::iterator II = BB->begin(); !II->isTerminator();) { Instruction *I = &*II++; // get the instruction, increment iterator if (LoadInst *LI = dyn_cast<LoadInst>(I)) { @@ -992,7 +994,7 @@ NextIteration: // Record debuginfo for the store before removing it. IncomingLocs[AllocaNo] = SI->getDebugLoc(); - for (DbgInfoIntrinsic *DII : AllocaDbgDeclares[ai->second]) + for (DbgVariableIntrinsic *DII : AllocaDbgDeclares[ai->second]) ConvertDebugDeclareToDebugValue(DII, SI, DIB); BB->getInstList().erase(SI); } diff --git a/lib/Transforms/Utils/SSAUpdater.cpp b/lib/Transforms/Utils/SSAUpdater.cpp index 4a1fd8d571aa..9e5fb0e7172d 100644 --- a/lib/Transforms/Utils/SSAUpdater.cpp +++ b/lib/Transforms/Utils/SSAUpdater.cpp @@ -64,6 +64,11 @@ bool SSAUpdater::HasValueForBlock(BasicBlock *BB) const { return getAvailableVals(AV).count(BB); } +Value *SSAUpdater::FindValueForBlock(BasicBlock *BB) const { + AvailableValsTy::iterator AVI = getAvailableVals(AV).find(BB); + return (AVI != getAvailableVals(AV).end()) ? AVI->second : nullptr; +} + void SSAUpdater::AddAvailableValue(BasicBlock *BB, Value *V) { assert(ProtoType && "Need to initialize SSAUpdater"); assert(ProtoType == V->getType() && diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index c87b5c16ffce..03b73954321d 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -173,14 +173,15 @@ class SimplifyCFGOpt { const DataLayout &DL; SmallPtrSetImpl<BasicBlock *> *LoopHeaders; const SimplifyCFGOptions &Options; + bool Resimplify; - Value *isValueEqualityComparison(TerminatorInst *TI); + Value *isValueEqualityComparison(Instruction *TI); BasicBlock *GetValueEqualityComparisonCases( - TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); - bool SimplifyEqualityComparisonWithOnlyPredecessor(TerminatorInst *TI, + Instruction *TI, std::vector<ValueEqualityComparisonCase> &Cases); + bool SimplifyEqualityComparisonWithOnlyPredecessor(Instruction *TI, BasicBlock *Pred, IRBuilder<> &Builder); - bool FoldValueComparisonIntoPredecessors(TerminatorInst *TI, + bool FoldValueComparisonIntoPredecessors(Instruction *TI, IRBuilder<> &Builder); bool SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder); @@ -194,6 +195,9 @@ class SimplifyCFGOpt { bool SimplifyUncondBranch(BranchInst *BI, IRBuilder<> &Builder); bool SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder); + bool tryToSimplifyUncondBranchWithICmpInIt(ICmpInst *ICI, + IRBuilder<> &Builder); + public: SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout &DL, SmallPtrSetImpl<BasicBlock *> *LoopHeaders, @@ -201,6 +205,13 @@ public: : TTI(TTI), DL(DL), LoopHeaders(LoopHeaders), Options(Opts) {} bool run(BasicBlock *BB); + bool simplifyOnce(BasicBlock *BB); + + // Helper to set Resimplify and return change indication. + bool requestResimplify() { + Resimplify = true; + return true; + } }; } // end anonymous namespace @@ -208,7 +219,7 @@ public: /// Return true if it is safe to merge these two /// terminator instructions together. static bool -SafeToMergeTerminators(TerminatorInst *SI1, TerminatorInst *SI2, +SafeToMergeTerminators(Instruction *SI1, Instruction *SI2, SmallSetVector<BasicBlock *, 4> *FailBlocks = nullptr) { if (SI1 == SI2) return false; // Can't merge with self! 
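
SSAUpdater::FindValueForBlock, added in the SSAUpdater hunk above, is a pure lookup: unlike GetValueAtEndOfBlock it never materializes PHI nodes, it only reports what AddAvailableValue recorded for the block, if anything. A short sketch with an invented caller:

    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Value.h"
    #include "llvm/Transforms/Utils/SSAUpdater.h"
    using namespace llvm;

    static Value *reuseIfAlreadyAvailable(SSAUpdater &Updater, BasicBlock *BB) {
      if (Value *V = Updater.FindValueForBlock(BB))
        return V;     // a value was registered for BB via AddAvailableValue
      return nullptr; // caller has to compute one (or fall back to
                      // GetValueAtEndOfBlock, which may insert PHIs)
    }
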
@@ -315,7 +326,7 @@ static unsigned ComputeSpeculationCost(const User *I, /// V plus its non-dominating operands. If that cost is greater than /// CostRemaining, false is returned and CostRemaining is undefined. static bool DominatesMergePoint(Value *V, BasicBlock *BB, - SmallPtrSetImpl<Instruction *> *AggressiveInsts, + SmallPtrSetImpl<Instruction *> &AggressiveInsts, unsigned &CostRemaining, const TargetTransformInfo &TTI, unsigned Depth = 0) { @@ -349,13 +360,8 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, if (!BI || BI->isConditional() || BI->getSuccessor(0) != BB) return true; - // If we aren't allowing aggressive promotion anymore, then don't consider - // instructions in the 'if region'. - if (!AggressiveInsts) - return false; - // If we have seen this instruction before, don't count it again. - if (AggressiveInsts->count(I)) + if (AggressiveInsts.count(I)) return true; // Okay, it looks like the instruction IS in the "condition". Check to @@ -373,7 +379,7 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, // is expected to be undone in CodeGenPrepare if the speculation has not // enabled further IR optimizations. if (Cost > CostRemaining && - (!SpeculateOneExpensiveInst || !AggressiveInsts->empty() || Depth > 0)) + (!SpeculateOneExpensiveInst || !AggressiveInsts.empty() || Depth > 0)) return false; // Avoid unsigned wrap. @@ -386,7 +392,7 @@ static bool DominatesMergePoint(Value *V, BasicBlock *BB, Depth + 1)) return false; // Okay, it's safe to do this! Remember this instruction. - AggressiveInsts->insert(I); + AggressiveInsts.insert(I); return true; } @@ -664,7 +670,7 @@ private: } // end anonymous namespace -static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { +static void EraseTerminatorAndDCECond(Instruction *TI) { Instruction *Cond = nullptr; if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { Cond = dyn_cast<Instruction>(SI->getCondition()); @@ -682,12 +688,12 @@ static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) { /// Return true if the specified terminator checks /// to see if a value is equal to constant integer value. -Value *SimplifyCFGOpt::isValueEqualityComparison(TerminatorInst *TI) { +Value *SimplifyCFGOpt::isValueEqualityComparison(Instruction *TI) { Value *CV = nullptr; if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { // Do not permit merging of large switch instructions into their // predecessors unless there is only one predecessor. - if (SI->getNumSuccessors() * pred_size(SI->getParent()) <= 128) + if (!SI->getParent()->hasNPredecessorsOrMore(128 / SI->getNumSuccessors())) CV = SI->getCondition(); } else if (BranchInst *BI = dyn_cast<BranchInst>(TI)) if (BI->isConditional() && BI->getCondition()->hasOneUse()) @@ -710,7 +716,7 @@ Value *SimplifyCFGOpt::isValueEqualityComparison(TerminatorInst *TI) { /// Given a value comparison instruction, /// decode all of the 'cases' that it represents and return the 'default' block. BasicBlock *SimplifyCFGOpt::GetValueEqualityComparisonCases( - TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases) { + Instruction *TI, std::vector<ValueEqualityComparisonCase> &Cases) { if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { Cases.reserve(SI->getNumCases()); for (auto Case : SI->cases()) @@ -800,7 +806,7 @@ static void setBranchWeights(Instruction *I, uint32_t TrueWeight, /// determines the outcome of this comparison. If so, simplify TI. This does a /// very limited form of jump threading. 
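
The isValueEqualityComparison change above rewrites "NumSuccessors * pred_size(BB) <= 128" as a hasNPredecessorsOrMore() query, which can stop walking the predecessor list as soon as the bound is reached instead of always counting every predecessor. A tiny sketch of the primitive, with an invented wrapper name:

    #include "llvm/IR/BasicBlock.h"
    using namespace llvm;

    static bool tooManyPreds(BasicBlock *BB, unsigned Limit) {
      // Semantically pred_size(BB) >= Limit, but bails out early once Limit
      // predecessors have been seen.
      return BB->hasNPredecessorsOrMore(Limit);
    }
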
bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( - TerminatorInst *TI, BasicBlock *Pred, IRBuilder<> &Builder) { + Instruction *TI, BasicBlock *Pred, IRBuilder<> &Builder) { Value *PredVal = isValueEqualityComparison(Pred->getTerminator()); if (!PredVal) return false; // Not a value comparison in predecessor. @@ -848,7 +854,7 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( << "Through successor TI: " << *TI << "Leaving: " << *NI << "\n"); - EraseTerminatorInstAndDCECond(TI); + EraseTerminatorAndDCECond(TI); return true; } @@ -930,7 +936,7 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( << "Through successor TI: " << *TI << "Leaving: " << *NI << "\n"); - EraseTerminatorInstAndDCECond(TI); + EraseTerminatorAndDCECond(TI); return true; } @@ -965,10 +971,10 @@ static inline bool HasBranchWeights(const Instruction *I) { return false; } -/// Get Weights of a given TerminatorInst, the default weight is at the front +/// Get Weights of a given terminator, the default weight is at the front /// of the vector. If TI is a conditional eq, we need to swap the branch-weight /// metadata. -static void GetBranchWeights(TerminatorInst *TI, +static void GetBranchWeights(Instruction *TI, SmallVectorImpl<uint64_t> &Weights) { MDNode *MD = TI->getMetadata(LLVMContext::MD_prof); assert(MD); @@ -1002,7 +1008,7 @@ static void FitWeights(MutableArrayRef<uint64_t> Weights) { /// (either a switch or a branch on "X == c"). /// See if any of the predecessors of the terminator block are value comparisons /// on the same value. If so, and if safe to do so, fold them together. -bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, +bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(Instruction *TI, IRBuilder<> &Builder) { BasicBlock *BB = TI->getParent(); Value *CV = isValueEqualityComparison(TI); // CondVal @@ -1014,7 +1020,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, BasicBlock *Pred = Preds.pop_back_val(); // See if the predecessor is a comparison with the same value. - TerminatorInst *PTI = Pred->getTerminator(); + Instruction *PTI = Pred->getTerminator(); Value *PCV = isValueEqualityComparison(PTI); // PredCondVal if (PCV == CV && TI != PTI) { @@ -1191,7 +1197,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, setBranchWeights(NewSI, MDWeights); } - EraseTerminatorInstAndDCECond(PTI); + EraseTerminatorAndDCECond(PTI); // Okay, last check. If BB is still a successor of PSI, then we must // have an infinite loop case. If so, add an infinitely looping block @@ -1270,7 +1276,7 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, do { // If we are hoisting the terminator instruction, don't move one (making a // broken BB), instead clone it, and remove BI. - if (isa<TerminatorInst>(I1)) + if (I1->isTerminator()) goto HoistTerminator; // If we're going to hoist a call, make sure that the two instructions we're @@ -1315,8 +1321,9 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, LLVMContext::MD_align, LLVMContext::MD_dereferenceable, LLVMContext::MD_dereferenceable_or_null, - LLVMContext::MD_mem_parallel_loop_access}; - combineMetadata(I1, I2, KnownIDs); + LLVMContext::MD_mem_parallel_loop_access, + LLVMContext::MD_access_group}; + combineMetadata(I1, I2, KnownIDs, true); // I1 and I2 are being combined into a single instruction. Its debug // location is the merged locations of the original instructions. 
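
The combineMetadata call in the hoisting hunk above now also preserves llvm.access.group and passes a new boolean argument: true states that the surviving instruction is being moved to a new position, so metadata that is only valid at its original program point has to be handled conservatively rather than blindly kept. A sketch of the same idiom through the CSE wrapper; the function and variable names are placeholders:

    #include "llvm/IR/Instruction.h"
    #include "llvm/Transforms/Utils/Local.h"
    using namespace llvm;

    static void foldIntoHoistedCopy(Instruction *Kept, Instruction *Discarded) {
      // Kept replaces Discarded at a different location, so let the helper
      // drop position-dependent metadata instead of preserving it.
      combineMetadataForCSE(Kept, Discarded, true /* the kept instruction moves */);
      Kept->andIRFlags(Discarded);
    }
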
@@ -1375,7 +1382,13 @@ HoistTerminator: NT->takeName(I1); } + // Ensure terminator gets a debug location, even an unknown one, in case + // it involves inlinable calls. + NT->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); + + // PHIs created below will adopt NT's merged DebugLoc. IRBuilder<NoFolder> Builder(NT); + // Hoisting one of the terminators from our successor is a great thing. // Unfortunately, the successors of the if/else blocks may have PHI nodes in // them. If they do, all PHI entries for BB1/BB2 must agree for all PHI @@ -1407,7 +1420,7 @@ HoistTerminator: for (BasicBlock *Succ : successors(BB1)) AddPredecessorToBlock(Succ, BIParent, BB1); - EraseTerminatorInstAndDCECond(BI); + EraseTerminatorAndDCECond(BI); return true; } @@ -1582,7 +1595,7 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { // However, as N-way merge for CallInst is rare, so we use simplified API // instead of using complex API for N-way merge. I0->applyMergedLocation(I0->getDebugLoc(), I->getDebugLoc()); - combineMetadataForCSE(I0, I); + combineMetadataForCSE(I0, I, true); I0->andIRFlags(I); } @@ -1940,11 +1953,11 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, } assert(EndBB == BI->getSuccessor(!Invert) && "No edge from to end block"); - // Keep a count of how many times instructions are used within CondBB when - // they are candidates for sinking into CondBB. Specifically: + // Keep a count of how many times instructions are used within ThenBB when + // they are candidates for sinking into ThenBB. Specifically: // - They are defined in BB, and // - They have no side effects, and - // - All of their uses are in CondBB. + // - All of their uses are in ThenBB. SmallDenseMap<Instruction *, unsigned, 4> SinkCandidateUseCounts; SmallVector<Instruction *, 4> SpeculatedDbgIntrinsics; @@ -1994,14 +2007,14 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, } } - // Consider any sink candidates which are only used in CondBB as costs for + // Consider any sink candidates which are only used in ThenBB as costs for // speculation. Note, while we iterate over a DenseMap here, we are summing // and so iteration order isn't significant. for (SmallDenseMap<Instruction *, unsigned, 4>::iterator I = SinkCandidateUseCounts.begin(), E = SinkCandidateUseCounts.end(); I != E; ++I) - if (I->first->getNumUses() == I->second) { + if (I->first->hasNUses(I->second)) { ++SpeculationCost; if (SpeculationCost > 1) return false; @@ -2241,7 +2254,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL, // Loop over all of the edges from PredBB to BB, changing them to branch // to EdgeBB instead. - TerminatorInst *PredBBTI = PredBB->getTerminator(); + Instruction *PredBBTI = PredBB->getTerminator(); for (unsigned i = 0, e = PredBBTI->getNumSuccessors(); i != e; ++i) if (PredBBTI->getSuccessor(i) == BB) { BB->removePredecessor(PredBB); @@ -2249,7 +2262,7 @@ static bool FoldCondBranchOnPHI(BranchInst *BI, const DataLayout &DL, } // Recurse, simplifying any other constants. 
- return FoldCondBranchOnPHI(BI, DL, AC) | true; + return FoldCondBranchOnPHI(BI, DL, AC) || true; } return false; @@ -2304,9 +2317,9 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, continue; } - if (!DominatesMergePoint(PN->getIncomingValue(0), BB, &AggressiveInsts, + if (!DominatesMergePoint(PN->getIncomingValue(0), BB, AggressiveInsts, MaxCostVal0, TTI) || - !DominatesMergePoint(PN->getIncomingValue(1), BB, &AggressiveInsts, + !DominatesMergePoint(PN->getIncomingValue(1), BB, AggressiveInsts, MaxCostVal1, TTI)) return false; } @@ -2336,8 +2349,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, IfBlock1 = nullptr; } else { DomBlock = *pred_begin(IfBlock1); - for (BasicBlock::iterator I = IfBlock1->begin(); !isa<TerminatorInst>(I); - ++I) + for (BasicBlock::iterator I = IfBlock1->begin(); !I->isTerminator(); ++I) if (!AggressiveInsts.count(&*I) && !isa<DbgInfoIntrinsic>(I)) { // This is not an aggressive instruction that we can promote. // Because of this, we won't be able to get rid of the control flow, so @@ -2350,8 +2362,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, IfBlock2 = nullptr; } else { DomBlock = *pred_begin(IfBlock2); - for (BasicBlock::iterator I = IfBlock2->begin(); !isa<TerminatorInst>(I); - ++I) + for (BasicBlock::iterator I = IfBlock2->begin(); !I->isTerminator(); ++I) if (!AggressiveInsts.count(&*I) && !isa<DbgInfoIntrinsic>(I)) { // This is not an aggressive instruction that we can promote. // Because of this, we won't be able to get rid of the control flow, so @@ -2371,20 +2382,10 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // Move all 'aggressive' instructions, which are defined in the // conditional parts of the if's up to the dominating block. - if (IfBlock1) { - for (auto &I : *IfBlock1) - I.dropUnknownNonDebugMetadata(); - DomBlock->getInstList().splice(InsertPt->getIterator(), - IfBlock1->getInstList(), IfBlock1->begin(), - IfBlock1->getTerminator()->getIterator()); - } - if (IfBlock2) { - for (auto &I : *IfBlock2) - I.dropUnknownNonDebugMetadata(); - DomBlock->getInstList().splice(InsertPt->getIterator(), - IfBlock2->getInstList(), IfBlock2->begin(), - IfBlock2->getTerminator()->getIterator()); - } + if (IfBlock1) + hoistAllInstructionsInto(DomBlock, InsertPt, IfBlock1); + if (IfBlock2) + hoistAllInstructionsInto(DomBlock, InsertPt, IfBlock2); while (PHINode *PN = dyn_cast<PHINode>(BB->begin())) { // Change the PHI node into a select instruction. @@ -2400,7 +2401,7 @@ static bool FoldTwoEntryPHINode(PHINode *PN, const TargetTransformInfo &TTI, // At this point, IfBlock1 and IfBlock2 are both empty, so our if statement // has been flattened. Change DomBlock to jump directly to our new block to // avoid other simplifycfg's kicking in on the diamond. 
- TerminatorInst *OldTI = DomBlock->getTerminator(); + Instruction *OldTI = DomBlock->getTerminator(); Builder.SetInsertPoint(OldTI); Builder.CreateBr(BB); OldTI->eraseFromParent(); @@ -2434,7 +2435,7 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI, TrueSucc->removePredecessor(BI->getParent()); FalseSucc->removePredecessor(BI->getParent()); Builder.CreateRetVoid(); - EraseTerminatorInstAndDCECond(BI); + EraseTerminatorAndDCECond(BI); return true; } @@ -2490,7 +2491,7 @@ static bool SimplifyCondBranchToTwoReturns(BranchInst *BI, << "\n " << *BI << "NewRet = " << *RI << "TRUEBLOCK: " << *TrueSucc << "FALSEBLOCK: " << *FalseSucc); - EraseTerminatorInstAndDCECond(BI); + EraseTerminatorAndDCECond(BI); return true; } @@ -2541,6 +2542,8 @@ static bool extractPredSuccWeights(BranchInst *PBI, BranchInst *BI, bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { BasicBlock *BB = BI->getParent(); + const unsigned PredCount = pred_size(BB); + Instruction *Cond = nullptr; if (BI->isConditional()) Cond = dyn_cast<Instruction>(BI->getCondition()); @@ -2590,7 +2593,8 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // too many instructions and these involved instructions can be executed // unconditionally. We denote all involved instructions except the condition // as "bonus instructions", and only allow this transformation when the - // number of the bonus instructions does not exceed a certain threshold. + // number of the bonus instructions we'll need to create when cloning into + // each predecessor does not exceed a certain threshold. unsigned NumBonusInsts = 0; for (auto I = BB->begin(); Cond != &*I; ++I) { // Ignore dbg intrinsics. @@ -2605,7 +2609,10 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // I is used in the same BB. Since BI uses Cond and doesn't have more slots // to use any other instruction, User must be an instruction between next(I) // and Cond. - ++NumBonusInsts; + + // Account for the cost of duplicating this instruction into each + // predecessor. + NumBonusInsts += PredCount; // Early exits once we reach the limit. if (NumBonusInsts > BonusInstThreshold) return false; @@ -2711,16 +2718,16 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // Clone Cond into the predecessor basic block, and or/and the // two conditions together. 
- Instruction *New = Cond->clone(); - RemapInstruction(New, VMap, + Instruction *CondInPred = Cond->clone(); + RemapInstruction(CondInPred, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - PredBlock->getInstList().insert(PBI->getIterator(), New); - New->takeName(Cond); - Cond->setName(New->getName() + ".old"); + PredBlock->getInstList().insert(PBI->getIterator(), CondInPred); + CondInPred->takeName(Cond); + Cond->setName(CondInPred->getName() + ".old"); if (BI->isConditional()) { Instruction *NewCond = cast<Instruction>( - Builder.CreateBinOp(Opc, PBI->getCondition(), New, "or.cond")); + Builder.CreateBinOp(Opc, PBI->getCondition(), CondInPred, "or.cond")); PBI->setCondition(NewCond); uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight; @@ -2784,7 +2791,8 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { Instruction *NotCond = cast<Instruction>( Builder.CreateNot(PBI->getCondition(), "not.cond")); MergedCond = cast<Instruction>( - Builder.CreateBinOp(Instruction::And, NotCond, New, "and.cond")); + Builder.CreateBinOp(Instruction::And, NotCond, CondInPred, + "and.cond")); if (PBI_C->isOne()) MergedCond = cast<Instruction>(Builder.CreateBinOp( Instruction::Or, PBI->getCondition(), MergedCond, "or.cond")); @@ -2793,7 +2801,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { // PBI_C is true: (PBI_Cond and BI_Value) or (!PBI_Cond) // is false: PBI_Cond and BI_Value MergedCond = cast<Instruction>(Builder.CreateBinOp( - Instruction::And, PBI->getCondition(), New, "and.cond")); + Instruction::And, PBI->getCondition(), CondInPred, "and.cond")); if (PBI_C->isOne()) { Instruction *NotCond = cast<Instruction>( Builder.CreateNot(PBI->getCondition(), "not.cond")); @@ -2807,7 +2815,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { } // Change PBI from Conditional to Unconditional. BranchInst *New_PBI = BranchInst::Create(TrueDest, PBI); - EraseTerminatorInstAndDCECond(PBI); + EraseTerminatorAndDCECond(PBI); PBI = New_PBI; } @@ -2873,7 +2881,7 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, if (!AlternativeV) break; - assert(pred_size(Succ) == 2); + assert(Succ->hasNPredecessors(2)); auto PredI = pred_begin(Succ); BasicBlock *OtherPredBB = *PredI == BB ? *++PredI : *PredI; if (PHI->getIncomingValueForBlock(OtherPredBB) == AlternativeV) @@ -2922,7 +2930,7 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, isa<StoreInst>(I)) ++N; // Free instructions. - else if (isa<TerminatorInst>(I) || IsaBitcastOfPointerType(I)) + else if (I.isTerminator() || IsaBitcastOfPointerType(I)) continue; else return false; @@ -3402,7 +3410,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Takes care of updating the successors and removing the old terminator. // Also makes sure not to introduce new successors by assuming that edges to // non-successor TrueBBs and FalseBBs aren't reachable. -static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond, +static bool SimplifyTerminatorOnSelect(Instruction *OldTerm, Value *Cond, BasicBlock *TrueBB, BasicBlock *FalseBB, uint32_t TrueWeight, uint32_t FalseWeight) { @@ -3414,7 +3422,7 @@ static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond, BasicBlock *KeepEdge2 = TrueBB != FalseBB ? FalseBB : nullptr; // Then remove the rest. 
- for (BasicBlock *Succ : OldTerm->successors()) { + for (BasicBlock *Succ : successors(OldTerm)) { // Make sure only to keep exactly one copy of each edge. if (Succ == KeepEdge1) KeepEdge1 = nullptr; @@ -3457,7 +3465,7 @@ static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond, Builder.CreateBr(FalseBB); } - EraseTerminatorInstAndDCECond(OldTerm); + EraseTerminatorAndDCECond(OldTerm); return true; } @@ -3534,9 +3542,8 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) { /// /// We prefer to split the edge to 'end' so that there is a true/false entry to /// the PHI, merging the third icmp into the switch. -static bool tryToSimplifyUncondBranchWithICmpInIt( - ICmpInst *ICI, IRBuilder<> &Builder, const DataLayout &DL, - const TargetTransformInfo &TTI, const SimplifyCFGOptions &Options) { +bool SimplifyCFGOpt::tryToSimplifyUncondBranchWithICmpInIt( + ICmpInst *ICI, IRBuilder<> &Builder) { BasicBlock *BB = ICI->getParent(); // If the block has any PHIs in it or the icmp has multiple uses, it is too @@ -3571,7 +3578,7 @@ static bool tryToSimplifyUncondBranchWithICmpInIt( ICI->eraseFromParent(); } // BB is now empty, so it is likely to simplify away. - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } // Ok, the block is reachable from the default dest. If the constant we're @@ -3587,7 +3594,7 @@ static bool tryToSimplifyUncondBranchWithICmpInIt( ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); // BB is now empty, so it is likely to simplify away. - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } // The use of the icmp has to be in the 'end' block, by the only PHI node in @@ -3701,7 +3708,7 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, BasicBlock *NewBB = BB->splitBasicBlock(BI->getIterator(), "switch.early.test"); // Remove the uncond branch added to the old block. - TerminatorInst *OldTI = BB->getTerminator(); + Instruction *OldTI = BB->getTerminator(); Builder.SetInsertPoint(OldTI); if (TrueWhenEqual) @@ -3745,7 +3752,7 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, IRBuilder<> &Builder, } // Erase the old branch instruction. - EraseTerminatorInstAndDCECond(BI); + EraseTerminatorAndDCECond(BI); LLVM_DEBUG(dbgs() << " ** 'icmp' chain result is:\n" << *BB << '\n'); return true; @@ -3861,9 +3868,9 @@ bool SimplifyCFGOpt::SimplifySingleResume(ResumeInst *RI) { } // The landingpad is now unreachable. Zap it. - BB->eraseFromParent(); if (LoopHeaders) LoopHeaders->erase(BB); + BB->eraseFromParent(); return true; } @@ -3993,7 +4000,7 @@ static bool removeEmptyCleanup(CleanupReturnInst *RI) { if (UnwindDest == nullptr) { removeUnwindEdge(PredBB); } else { - TerminatorInst *TI = PredBB->getTerminator(); + Instruction *TI = PredBB->getTerminator(); TI->replaceUsesOfWith(BB, UnwindDest); } } @@ -4062,7 +4069,7 @@ bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { SmallVector<BranchInst *, 8> CondBranchPreds; for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { BasicBlock *P = *PI; - TerminatorInst *PTI = P->getTerminator(); + Instruction *PTI = P->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(PTI)) { if (BI->isUnconditional()) UncondBranchPreds.push_back(P); @@ -4083,9 +4090,9 @@ bool SimplifyCFGOpt::SimplifyReturn(ReturnInst *RI, IRBuilder<> &Builder) { // If we eliminated all predecessors of the block, delete the block now. 
if (pred_empty(BB)) { // We know there are no successors, so just nuke the block. - BB->eraseFromParent(); if (LoopHeaders) LoopHeaders->erase(BB); + BB->eraseFromParent(); } return true; @@ -4167,7 +4174,7 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { SmallVector<BasicBlock *, 8> Preds(pred_begin(BB), pred_end(BB)); for (unsigned i = 0, e = Preds.size(); i != e; ++i) { - TerminatorInst *TI = Preds[i]->getTerminator(); + Instruction *TI = Preds[i]->getTerminator(); IRBuilder<> Builder(TI); if (auto *BI = dyn_cast<BranchInst>(TI)) { if (BI->isUnconditional()) { @@ -4179,10 +4186,10 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { } else { if (BI->getSuccessor(0) == BB) { Builder.CreateBr(BI->getSuccessor(1)); - EraseTerminatorInstAndDCECond(BI); + EraseTerminatorAndDCECond(BI); } else if (BI->getSuccessor(1) == BB) { Builder.CreateBr(BI->getSuccessor(0)); - EraseTerminatorInstAndDCECond(BI); + EraseTerminatorAndDCECond(BI); Changed = true; } } @@ -4245,9 +4252,9 @@ bool SimplifyCFGOpt::SimplifyUnreachable(UnreachableInst *UI) { // If this block is now dead, remove it. if (pred_empty(BB) && BB != &BB->getParent()->getEntryBlock()) { // We know there are no successors, so just nuke the block. - BB->eraseFromParent(); if (LoopHeaders) LoopHeaders->erase(BB); + BB->eraseFromParent(); return true; } @@ -4424,7 +4431,7 @@ static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, SplitBlock(&*NewDefault, &NewDefault->front()); auto *OldTI = NewDefault->getTerminator(); new UnreachableInst(SI->getContext(), OldTI); - EraseTerminatorInstAndDCECond(OldTI); + EraseTerminatorAndDCECond(OldTI); return true; } @@ -4635,12 +4642,12 @@ GetCaseResults(SwitchInst *SI, ConstantInt *CaseVal, BasicBlock *CaseDest, SmallDenseMap<Value *, Constant *> ConstantPool; ConstantPool.insert(std::make_pair(SI->getCondition(), CaseVal)); for (Instruction &I :CaseDest->instructionsWithoutDebug()) { - if (TerminatorInst *T = dyn_cast<TerminatorInst>(&I)) { + if (I.isTerminator()) { // If the terminator is a simple branch, continue to the next block. - if (T->getNumSuccessors() != 1 || T->isExceptional()) + if (I.getNumSuccessors() != 1 || I.isExceptionalTerminator()) return false; Pred = CaseDest; - CaseDest = T->getSuccessor(0); + CaseDest = I.getSuccessor(0); } else if (Constant *C = ConstantFold(&I, DL, ConstantPool)) { // Instruction is side-effect free and constant. @@ -5031,6 +5038,9 @@ SwitchLookupTable::SwitchLookupTable( GlobalVariable::PrivateLinkage, Initializer, "switch.table." + FuncName); Array->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); + // Set the alignment to that of an array items. We will be only loading one + // value out of it. + Array->setAlignment(DL.getPrefTypeAlignment(ValueType)); Kind = ArrayKind; } @@ -5257,7 +5267,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Figure out the corresponding result for each case value and phi node in the // common destination, as well as the min and max case values. 
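The SwitchLookupTable hunk above sets the generated array's alignment to that of a single element, since each lookup loads exactly one value. For context, a source-level sketch of the switch-to-lookup-table rewrite itself (illustrative; the pass emits a private constant global plus a bounds check that covers the default case):

static const int Table[4] = {7, 13, 42, 99};

int viaSwitch(unsigned x) {
  switch (x) {
  case 0: return 7;
  case 1: return 13;
  case 2: return 42;
  case 3: return 99;
  default: return 0;
  }
}

int viaTable(unsigned x) {
  // The bounds check replaces the default case; the load needs only the
  // alignment of one element of Table.
  return x < 4 ? Table[x] : 0;
}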
- assert(SI->case_begin() != SI->case_end()); + assert(!empty(SI->cases())); SwitchInst::CaseIt CI = SI->case_begin(); ConstantInt *MinCaseVal = CI->getCaseValue(); ConstantInt *MaxCaseVal = CI->getCaseValue(); @@ -5509,7 +5519,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, SmallVector<int64_t,4> Values; for (auto &C : SI->cases()) Values.push_back(C.getCaseValue()->getValue().getSExtValue()); - llvm::sort(Values.begin(), Values.end()); + llvm::sort(Values); // If the switch is already dense, there's nothing useful to do here. if (isSwitchDense(Values)) @@ -5583,33 +5593,33 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // see if that predecessor totally determines the outcome of this switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); Value *Cond = SI->getCondition(); if (SelectInst *Select = dyn_cast<SelectInst>(Cond)) if (SimplifySwitchOnSelect(SI, Select)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // If the block only contains the switch, see if we can fold the block // away into any preds. if (SI == &*BB->instructionsWithoutDebug().begin()) if (FoldValueComparisonIntoPredecessors(SI, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } // Try to transform the switch into an icmp and a branch. if (TurnSwitchRangeIntoICmp(SI, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // Remove unreachable cases. if (eliminateDeadSwitchCases(SI, Options.AC, DL)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); if (switchToSelect(SI, Builder, DL, TTI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // The conversion from switch to lookup tables results in difficult-to-analyze // code and makes pruning branches much harder. This is a problem if the @@ -5618,10 +5628,10 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // optimisation pipeline. if (Options.ConvertSwitchToLookupTable && SwitchToLookupTable(SI, Builder, DL, TTI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); if (ReduceSwitchRange(SI, Builder, DL, TTI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); return false; } @@ -5646,20 +5656,20 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (IBI->getNumDestinations() == 0) { // If the indirectbr has no successors, change it to unreachable. new UnreachableInst(IBI->getContext(), IBI); - EraseTerminatorInstAndDCECond(IBI); + EraseTerminatorAndDCECond(IBI); return true; } if (IBI->getNumDestinations() == 1) { // If the indirectbr has one successor, change it to a direct branch. BranchInst::Create(IBI->getDestination(0), IBI); - EraseTerminatorInstAndDCECond(IBI); + EraseTerminatorAndDCECond(IBI); return true; } if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (SimplifyIndirectBrOnSelect(IBI, SI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } return Changed; } @@ -5755,7 +5765,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, // backedge, so we can eliminate BB. 
bool NeedCanonicalLoop = Options.NeedCanonicalLoop && - (LoopHeaders && pred_size(BB) > 1 && + (LoopHeaders && BB->hasNPredecessorsOrMore(2) && (LoopHeaders->count(BB) || LoopHeaders->count(Succ))); BasicBlock::iterator I = BB->getFirstNonPHIOrDbg()->getIterator(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && @@ -5769,7 +5779,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, for (++I; isa<DbgInfoIntrinsic>(I); ++I) ; if (I->isTerminator() && - tryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, DL, TTI, Options)) + tryToSimplifyUncondBranchWithICmpInIt(ICI, Builder)) return true; } @@ -5787,7 +5797,7 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. if (FoldBranchToCommonDest(BI, Options.BonusInstThreshold)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); return false; } @@ -5815,18 +5825,18 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. auto I = BB->instructionsWithoutDebug().begin(); if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } else if (&*I == cast<Instruction>(BI->getCondition())) { ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } } @@ -5834,35 +5844,24 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (SimplifyBranchOnICmpChain(BI, Builder, DL)) return true; - // If this basic block has a single dominating predecessor block and the - // dominating block's condition implies BI's condition, we know the direction - // of the BI branch. - if (BasicBlock *Dom = BB->getSinglePredecessor()) { - auto *PBI = dyn_cast_or_null<BranchInst>(Dom->getTerminator()); - if (PBI && PBI->isConditional() && - PBI->getSuccessor(0) != PBI->getSuccessor(1)) { - assert(PBI->getSuccessor(0) == BB || PBI->getSuccessor(1) == BB); - bool CondIsTrue = PBI->getSuccessor(0) == BB; - Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), BI->getCondition(), DL, CondIsTrue); - if (Implication) { - // Turn this into a branch on constant. - auto *OldCond = BI->getCondition(); - ConstantInt *CI = *Implication - ? ConstantInt::getTrue(BB->getContext()) - : ConstantInt::getFalse(BB->getContext()); - BI->setCondition(CI); - RecursivelyDeleteTriviallyDeadInstructions(OldCond); - return simplifyCFG(BB, TTI, Options) | true; - } - } + // If this basic block has dominating predecessor blocks and the dominating + // blocks' conditions imply BI's condition, we know the direction of BI. + Optional<bool> Imp = isImpliedByDomCondition(BI->getCondition(), BI, DL); + if (Imp) { + // Turn this into a branch on constant. + auto *OldCond = BI->getCondition(); + ConstantInt *TorF = *Imp ? 
ConstantInt::getTrue(BB->getContext()) + : ConstantInt::getFalse(BB->getContext()); + BI->setCondition(TorF); + RecursivelyDeleteTriviallyDeadInstructions(OldCond); + return requestResimplify(); } // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. if (FoldBranchToCommonDest(BI, Options.BonusInstThreshold)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // We have a conditional branch to two blocks that are only reachable // from BI. We know that the condbr dominates the two blocks, so see if @@ -5871,24 +5870,24 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { if (HoistThenElseCodeToIf(BI, TTI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } else { // If Successor #1 has multiple preds, we may be able to conditionally // execute Successor #0 if it branches to Successor #1. - TerminatorInst *Succ0TI = BI->getSuccessor(0)->getTerminator(); + Instruction *Succ0TI = BI->getSuccessor(0)->getTerminator(); if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), TTI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { // If Successor #0 has multiple preds, we may be able to conditionally // execute Successor #1 if it branches to Successor #0. - TerminatorInst *Succ1TI = BI->getSuccessor(1)->getTerminator(); + Instruction *Succ1TI = BI->getSuccessor(1)->getTerminator(); if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), TTI)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); } // If this is a branch on a phi node in the current block, thread control @@ -5896,14 +5895,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) if (FoldCondBranchOnPHI(BI, DL, Options.AC)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // Scan predecessor blocks for conditional branches. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) if (SimplifyCondBranchToCondBranch(PBI, BI, DL)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); // Look for diamond patterns. 
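The rewrite above swaps the hand-rolled single-predecessor check for isImpliedByDomCondition(): when a dominating branch already decides BI's condition, BI becomes a branch on a constant and the dead arm is cleaned up on the next round via requestResimplify(). A source-level illustration:

int f(int x) {
  if (x > 10) {
    if (x > 5)       // implied true by the dominating (x > 10) test
      return 1;
    return 2;        // unreachable once the inner branch is constant-folded
  }
  return 0;
}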
if (MergeCondStores) @@ -5911,7 +5910,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BranchInst *PBI = dyn_cast<BranchInst>(PrevBB->getTerminator())) if (PBI != BI && PBI->isConditional()) if (mergeConditionalStores(PBI, BI, DL)) - return simplifyCFG(BB, TTI, Options) | true; + return requestResimplify(); return false; } @@ -5974,7 +5973,7 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB) { for (PHINode &PHI : BB->phis()) for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) if (passingValueIsAlwaysUndefined(PHI.getIncomingValue(i), &PHI)) { - TerminatorInst *T = PHI.getIncomingBlock(i)->getTerminator(); + Instruction *T = PHI.getIncomingBlock(i)->getTerminator(); IRBuilder<> Builder(T); if (BranchInst *BI = dyn_cast<BranchInst>(T)) { BB->removePredecessor(PHI.getIncomingBlock(i)); @@ -5994,7 +5993,7 @@ static bool removeUndefIntroducingPredecessor(BasicBlock *BB) { return false; } -bool SimplifyCFGOpt::run(BasicBlock *BB) { +bool SimplifyCFGOpt::simplifyOnce(BasicBlock *BB) { bool Changed = false; assert(BB && BB->getParent() && "Block not embedded in function!"); @@ -6068,6 +6067,21 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { return Changed; } +bool SimplifyCFGOpt::run(BasicBlock *BB) { + bool Changed = false; + + // Repeated simplify BB as long as resimplification is requested. + do { + Resimplify = false; + + // Perform one round of simplifcation. Resimplify flag will be set if + // another iteration is requested. + Changed |= simplifyOnce(BB); + } while (Resimplify); + + return Changed; +} + bool llvm::simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, const SimplifyCFGOptions &Options, SmallPtrSetImpl<BasicBlock *> *LoopHeaders) { diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index 65b23f4d94a1..7faf291e73d9 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -106,8 +106,9 @@ namespace { /// Otherwise return null. Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) { Value *IVSrc = nullptr; - unsigned OperIdx = 0; + const unsigned OperIdx = 0; const SCEV *FoldedExpr = nullptr; + bool MustDropExactFlag = false; switch (UseInst->getOpcode()) { default: return nullptr; @@ -140,6 +141,11 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) APInt::getOneBitSet(BitWidth, D->getZExtValue())); } FoldedExpr = SE->getUDivExpr(SE->getSCEV(IVSrc), SE->getSCEV(D)); + // We might have 'exact' flag set at this point which will no longer be + // correct after we make the replacement. + if (UseInst->isExact() && + SE->getSCEV(IVSrc) != SE->getMulExpr(FoldedExpr, SE->getSCEV(D))) + MustDropExactFlag = true; } // We have something that might fold it's operand. Compare SCEVs. 
if (!SE->isSCEVable(UseInst->getType())) @@ -155,6 +161,9 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) UseInst->setOperand(OperIdx, IVSrc); assert(SE->getSCEV(UseInst) == FoldedExpr && "bad SCEV with folded oper"); + if (MustDropExactFlag) + UseInst->dropPoisonGeneratingFlags(); + ++NumElimOperand; Changed = true; if (IVOperand->use_empty()) diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 15e035874002..1bb26caa2af2 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/SimplifyLibCalls.h" +#include "llvm/ADT/APSInt.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/Triple.h" @@ -22,6 +23,7 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/Loads.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" @@ -150,6 +152,32 @@ static bool isLocallyOpenedFile(Value *File, CallInst *CI, IRBuilder<> &B, return true; } +static bool isOnlyUsedInComparisonWithZero(Value *V) { + for (User *U : V->users()) { + if (ICmpInst *IC = dyn_cast<ICmpInst>(U)) + if (Constant *C = dyn_cast<Constant>(IC->getOperand(1))) + if (C->isNullValue()) + continue; + // Unknown instruction. + return false; + } + return true; +} + +static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len, + const DataLayout &DL) { + if (!isOnlyUsedInComparisonWithZero(CI)) + return false; + + if (!isDereferenceableAndAlignedPointer(Str, 1, APInt(64, Len), DL)) + return false; + + if (CI->getFunction()->hasFnAttribute(Attribute::SanitizeMemory)) + return false; + + return true; +} + //===----------------------------------------------------------------------===// // String and Memory Library Call Optimizations //===----------------------------------------------------------------------===// @@ -322,6 +350,21 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilder<> &B) { B, DL, TLI); } + // strcmp to memcmp + if (!HasStr1 && HasStr2) { + if (canTransformToMemCmp(CI, Str1P, Len2, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2), B, DL, + TLI); + } else if (HasStr1 && !HasStr2) { + if (canTransformToMemCmp(CI, Str2P, Len1, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1), B, DL, + TLI); + } + return nullptr; } @@ -361,6 +404,26 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilder<> &B) { if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x return B.CreateZExt(B.CreateLoad(Str1P, "strcmpload"), CI->getType()); + uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len2 = GetStringLength(Str2P); + + // strncmp to memcmp + if (!HasStr1 && HasStr2) { + Len2 = std::min(Len2, Length); + if (canTransformToMemCmp(CI, Str1P, Len2, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2), B, DL, + TLI); + } else if (HasStr1 && !HasStr2) { + Len1 = std::min(Len1, Length); + if (canTransformToMemCmp(CI, Str2P, Len1, DL)) + return emitMemCmp( + Str1P, Str2P, + ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1), B, DL, + TLI); + } + return nullptr; } @@ -735,8 +798,11 @@ Value 
*LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilder<> &B) { Bitfield.setBit((unsigned char)C); Value *BitfieldC = B.getInt(Bitfield); - // First check that the bit field access is within bounds. + // Adjust width of "C" to the bitfield width, then mask off the high bits. Value *C = B.CreateZExtOrTrunc(CI->getArgOperand(1), BitfieldC->getType()); + C = B.CreateAnd(C, B.getIntN(Width, 0xFF)); + + // First check that the bit field access is within bounds. Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width), "memchr.bounds"); @@ -860,8 +926,7 @@ Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) { } /// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n). -static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, - const TargetLibraryInfo &TLI) { +Value *LibCallSimplifier::foldMallocMemset(CallInst *Memset, IRBuilder<> &B) { // This has to be a memset of zeros (bzero). auto *FillValue = dyn_cast<ConstantInt>(Memset->getArgOperand(1)); if (!FillValue || FillValue->getZExtValue() != 0) @@ -881,7 +946,7 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, return nullptr; LibFunc Func; - if (!TLI.getLibFunc(*InnerCallee, Func) || !TLI.has(Func) || + if (!TLI->getLibFunc(*InnerCallee, Func) || !TLI->has(Func) || Func != LibFunc_malloc) return nullptr; @@ -896,18 +961,18 @@ static Value *foldMallocMemset(CallInst *Memset, IRBuilder<> &B, IntegerType *SizeType = DL.getIntPtrType(B.GetInsertBlock()->getContext()); Value *Calloc = emitCalloc(ConstantInt::get(SizeType, 1), Malloc->getArgOperand(0), Malloc->getAttributes(), - B, TLI); + B, *TLI); if (!Calloc) return nullptr; Malloc->replaceAllUsesWith(Calloc); - Malloc->eraseFromParent(); + eraseFromParent(Malloc); return Calloc; } Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) { - if (auto *Calloc = foldMallocMemset(CI, B, *TLI)) + if (auto *Calloc = foldMallocMemset(CI, B)) return Calloc; // memset(p, v, n) -> llvm.memset(align 1 p, v, n) @@ -927,6 +992,20 @@ Value *LibCallSimplifier::optimizeRealloc(CallInst *CI, IRBuilder<> &B) { // Math Library Optimizations //===----------------------------------------------------------------------===// +// Replace a libcall \p CI with a call to intrinsic \p IID +static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { + // Propagate fast-math flags from the existing call to the new call. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + Module *M = CI->getModule(); + Value *V = CI->getArgOperand(0); + Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); + CallInst *NewCall = B.CreateCall(F, V); + NewCall->takeName(CI); + return NewCall; +} + /// Return a variant of Val with float type. /// Currently this works in two cases: If Val is an FPExtension of a float /// value to something bigger, simply return the operand. @@ -949,104 +1028,75 @@ static Value *valueHasFloatPrecision(Value *Val) { return nullptr; } -/// Shrink double -> float for unary functions like 'floor'. -static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, - bool CheckRetType) { - Function *Callee = CI->getCalledFunction(); - // We know this libcall has a valid prototype, but we don't know which. +/// Shrink double -> float functions. 
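optimizeDoubleFP (next hunk) merges the unary and binary double-to-float shrinking paths. The source-level effect, sketched here with fmin as an example: the widen/call/narrow dance collapses into a direct call to the float variant. The callee-name check it retains guards against headers that define the float function in terms of the double one (the MinGW-w64 expf case quoted in the code), where rewriting inside that very function would recurse forever.

#include <cmath>

float before(float x, float y) {
  return (float)fmin((double)x, (double)y);  // widen, call double version, narrow
}

float after(float x, float y) {
  return fminf(x, y);                        // call the float variant directly
}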
+static Value *optimizeDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isBinary, bool isPrecise = false) { if (!CI->getType()->isDoubleTy()) return nullptr; - if (CheckRetType) { - // Check if all the uses for function like 'sin' are converted to float. + // If not all the uses of the function are converted to float, then bail out. + // This matters if the precision of the result is more important than the + // precision of the arguments. + if (isPrecise) for (User *U : CI->users()) { FPTruncInst *Cast = dyn_cast<FPTruncInst>(U); if (!Cast || !Cast->getType()->isFloatTy()) return nullptr; } - } - // If this is something like 'floor((double)floatval)', convert to floorf. - Value *V = valueHasFloatPrecision(CI->getArgOperand(0)); - if (V == nullptr) + // If this is something like 'g((double) float)', convert to 'gf(float)'. + Value *V[2]; + V[0] = valueHasFloatPrecision(CI->getArgOperand(0)); + V[1] = isBinary ? valueHasFloatPrecision(CI->getArgOperand(1)) : nullptr; + if (!V[0] || (isBinary && !V[1])) return nullptr; // If call isn't an intrinsic, check that it isn't within a function with the - // same name as the float version of this call. + // same name as the float version of this call, otherwise the result is an + // infinite loop. For example, from MinGW-w64: // - // e.g. inline float expf(float val) { return (float) exp((double) val); } - // - // A similar such definition exists in the MinGW-w64 math.h header file which - // when compiled with -O2 -ffast-math causes the generation of infinite loops - // where expf is called. - if (!Callee->isIntrinsic()) { - const Function *F = CI->getFunction(); - StringRef FName = F->getName(); - StringRef CalleeName = Callee->getName(); - if ((FName.size() == (CalleeName.size() + 1)) && - (FName.back() == 'f') && - FName.startswith(CalleeName)) + // float expf(float val) { return (float) exp((double) val); } + Function *CalleeFn = CI->getCalledFunction(); + StringRef CalleeNm = CalleeFn->getName(); + AttributeList CalleeAt = CalleeFn->getAttributes(); + if (CalleeFn && !CalleeFn->isIntrinsic()) { + const Function *Fn = CI->getFunction(); + StringRef FnName = Fn->getName(); + if (FnName.back() == 'f' && + FnName.size() == (CalleeNm.size() + 1) && + FnName.startswith(CalleeNm)) return nullptr; } - // Propagate fast-math flags from the existing call to the new call. + // Propagate the math semantics from the current function to the new function. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(CI->getFastMathFlags()); - // floor((double)floatval) -> (double)floorf(floatval) - if (Callee->isIntrinsic()) { + // g((double) float) -> (double) gf(float) + Value *R; + if (CalleeFn->isIntrinsic()) { Module *M = CI->getModule(); - Intrinsic::ID IID = Callee->getIntrinsicID(); - Function *F = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); - V = B.CreateCall(F, V); - } else { - // The call is a library call rather than an intrinsic. - V = emitUnaryFloatFnCall(V, Callee->getName(), B, Callee->getAttributes()); + Intrinsic::ID IID = CalleeFn->getIntrinsicID(); + Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy()); + R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]); } + else + R = isBinary ? 
emitBinaryFloatFnCall(V[0], V[1], CalleeNm, B, CalleeAt) + : emitUnaryFloatFnCall(V[0], CalleeNm, B, CalleeAt); - return B.CreateFPExt(V, B.getDoubleTy()); + return B.CreateFPExt(R, B.getDoubleTy()); } -// Replace a libcall \p CI with a call to intrinsic \p IID -static Value *replaceUnaryCall(CallInst *CI, IRBuilder<> &B, Intrinsic::ID IID) { - // Propagate fast-math flags from the existing call to the new call. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - Module *M = CI->getModule(); - Value *V = CI->getArgOperand(0); - Function *F = Intrinsic::getDeclaration(M, IID, CI->getType()); - CallInst *NewCall = B.CreateCall(F, V); - NewCall->takeName(CI); - return NewCall; +/// Shrink double -> float for unary functions. +static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isPrecise = false) { + return optimizeDoubleFP(CI, B, false, isPrecise); } -/// Shrink double -> float for binary functions like 'fmin/fmax'. -static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - // We know this libcall has a valid prototype, but we don't know which. - if (!CI->getType()->isDoubleTy()) - return nullptr; - - // If this is something like 'fmin((double)floatval1, (double)floatval2)', - // or fmin(1.0, (double)floatval), then we convert it to fminf. - Value *V1 = valueHasFloatPrecision(CI->getArgOperand(0)); - if (V1 == nullptr) - return nullptr; - Value *V2 = valueHasFloatPrecision(CI->getArgOperand(1)); - if (V2 == nullptr) - return nullptr; - - // Propagate fast-math flags from the existing call to the new call. - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - // fmin((double)floatval1, (double)floatval2) - // -> (double)fminf(floatval1, floatval2) - // TODO: Handle intrinsics in the same way as in optimizeUnaryDoubleFP(). - Value *V = emitBinaryFloatFnCall(V1, V2, Callee->getName(), B, - Callee->getAttributes()); - return B.CreateFPExt(V, B.getDoubleTy()); +/// Shrink double -> float for binary functions. +static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B, + bool isPrecise = false) { + return optimizeDoubleFP(CI, B, true, isPrecise); } // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) @@ -1078,20 +1128,39 @@ Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilder<> &B) { return B.CreateCall(FSqrt, B.CreateFAdd(RealReal, ImagImag), "cabs"); } -Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { - Function *Callee = CI->getCalledFunction(); - Value *Ret = nullptr; - StringRef Name = Callee->getName(); - if (UnsafeFPShrink && Name == "cos" && hasFloatVersion(Name)) - Ret = optimizeUnaryDoubleFP(CI, B, true); - - // cos(-x) -> cos(x) - Value *Op1 = CI->getArgOperand(0); - if (BinaryOperator::isFNeg(Op1)) { - BinaryOperator *BinExpr = cast<BinaryOperator>(Op1); - return B.CreateCall(Callee, BinExpr->getOperand(1), "cos"); +static Value *optimizeTrigReflections(CallInst *Call, LibFunc Func, + IRBuilder<> &B) { + if (!isa<FPMathOperator>(Call)) + return nullptr; + + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Call->getFastMathFlags()); + + // TODO: Can this be shared to also handle LLVM intrinsics? 
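The switch that follows implements reflections based on symmetry: sin and tan are odd functions, cos is even; for sin/tan the negation must have a single use so the fneg can be folded away, and the original call's fast-math flags are propagated to the new instructions. Source-level sketch:

#include <cmath>

double sinNeg(double x) { return sin(-x); }  // --> -sin(x)  (odd function)
double tanNeg(double x) { return tan(-x); }  // --> -tan(x)  (odd function)
double cosNeg(double x) { return cos(-x); }  // --> cos(x)   (even function)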
+ Value *X; + switch (Func) { + case LibFunc_sin: + case LibFunc_sinf: + case LibFunc_sinl: + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanl: + // sin(-X) --> -sin(X) + // tan(-X) --> -tan(X) + if (match(Call->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) + return B.CreateFNeg(B.CreateCall(Call->getCalledFunction(), X)); + break; + case LibFunc_cos: + case LibFunc_cosf: + case LibFunc_cosl: + // cos(-X) --> cos(X) + if (match(Call->getArgOperand(0), m_FNeg(m_Value(X)))) + return B.CreateCall(Call->getCalledFunction(), X, "cos"); + break; + default: + break; } - return Ret; + return nullptr; } static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { @@ -1119,37 +1188,175 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { return InnerChain[Exp]; } -/// Use square root in place of pow(x, +/-0.5). -Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { - // TODO: There is some subset of 'fast' under which these transforms should - // be allowed. - if (!Pow->isFast()) - return nullptr; - - Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); +/// Use exp{,2}(x * y) for pow(exp{,2}(x), y); +/// exp2(n * x) for pow(2.0 ** n, x); exp10(x) for pow(10.0, x). +Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilder<> &B) { + Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); + Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); + bool Ignored; - const APFloat *ExpoF; - if (!match(Expo, m_APFloat(ExpoF)) || - (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) + // Evaluate special cases related to a nested function as the base. + + // pow(exp(x), y) -> exp(x * y) + // pow(exp2(x), y) -> exp2(x * y) + // If exp{,2}() is used only once, it is better to fold two transcendental + // math functions into one. If used again, exp{,2}() would still have to be + // called with the original argument, then keep both original transcendental + // functions. However, this transformation is only safe with fully relaxed + // math semantics, since, besides rounding differences, it changes overflow + // and underflow behavior quite dramatically. For example: + // pow(exp(1000), 0.001) = pow(inf, 0.001) = inf + // Whereas: + // exp(1000 * 0.001) = exp(1) + // TODO: Loosen the requirement for fully relaxed math semantics. + // TODO: Handle exp10() when more targets have it available. + CallInst *BaseFn = dyn_cast<CallInst>(Base); + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()) { + LibFunc LibFn; + + Function *CalleeFn = BaseFn->getCalledFunction(); + if (CalleeFn && + TLI->getLibFunc(CalleeFn->getName(), LibFn) && TLI->has(LibFn)) { + StringRef ExpName; + Intrinsic::ID ID; + Value *ExpFn; + LibFunc LibFnFloat; + LibFunc LibFnDouble; + LibFunc LibFnLongDouble; + + switch (LibFn) { + default: + return nullptr; + case LibFunc_expf: case LibFunc_exp: case LibFunc_expl: + ExpName = TLI->getName(LibFunc_exp); + ID = Intrinsic::exp; + LibFnFloat = LibFunc_expf; + LibFnDouble = LibFunc_exp; + LibFnLongDouble = LibFunc_expl; + break; + case LibFunc_exp2f: case LibFunc_exp2: case LibFunc_exp2l: + ExpName = TLI->getName(LibFunc_exp2); + ID = Intrinsic::exp2; + LibFnFloat = LibFunc_exp2f; + LibFnDouble = LibFunc_exp2; + LibFnLongDouble = LibFunc_exp2l; + break; + } + + // Create new exp{,2}() with the product as its argument. 
+ Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); + ExpFn = BaseFn->doesNotAccessMemory() + ? B.CreateCall(Intrinsic::getDeclaration(Mod, ID, Ty), + FMul, ExpName) + : emitUnaryFloatFnCall(FMul, TLI, LibFnDouble, LibFnFloat, + LibFnLongDouble, B, + BaseFn->getAttributes()); + + // Since the new exp{,2}() is different from the original one, dead code + // elimination cannot be trusted to remove it, since it may have side + // effects (e.g., errno). When the only consumer for the original + // exp{,2}() is pow(), then it has to be explicitly erased. + BaseFn->replaceAllUsesWith(ExpFn); + eraseFromParent(BaseFn); + + return ExpFn; + } + } + + // Evaluate special cases related to a constant base. + + const APFloat *BaseF; + if (!match(Pow->getArgOperand(0), m_APFloat(BaseF))) return nullptr; + // pow(2.0 ** n, x) -> exp2(n * x) + if (hasUnaryFloatFn(TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) { + APFloat BaseR = APFloat(1.0); + BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored); + BaseR = BaseR / *BaseF; + bool IsInteger = BaseF->isInteger(), + IsReciprocal = BaseR.isInteger(); + const APFloat *NF = IsReciprocal ? &BaseR : BaseF; + APSInt NI(64, false); + if ((IsInteger || IsReciprocal) && + !NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) && + NI > 1 && NI.isPowerOf2()) { + double N = NI.logBase2() * (IsReciprocal ? -1.0 : 1.0); + Value *FMul = B.CreateFMul(Expo, ConstantFP::get(Ty, N), "mul"); + if (Pow->doesNotAccessMemory()) + return B.CreateCall(Intrinsic::getDeclaration(Mod, Intrinsic::exp2, Ty), + FMul, "exp2"); + else + return emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l, B, Attrs); + } + } + + // pow(10.0, x) -> exp10(x) + // TODO: There is no exp10() intrinsic yet, but some day there shall be one. + if (match(Base, m_SpecificFP(10.0)) && + hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) + return emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10, LibFunc_exp10f, + LibFunc_exp10l, B, Attrs); + + return nullptr; +} + +static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno, + Module *M, IRBuilder<> &B, + const TargetLibraryInfo *TLI) { // If errno is never set, then use the intrinsic for sqrt(). - if (Pow->hasFnAttr(Attribute::ReadNone)) { - Function *SqrtFn = Intrinsic::getDeclaration(Pow->getModule(), - Intrinsic::sqrt, Ty); - Sqrt = B.CreateCall(SqrtFn, Base); + if (NoErrno) { + Function *SqrtFn = + Intrinsic::getDeclaration(M, Intrinsic::sqrt, V->getType()); + return B.CreateCall(SqrtFn, V, "sqrt"); } + // Otherwise, use the libcall for sqrt(). - else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) + if (hasUnaryFloatFn(TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) // TODO: We also should check that the target can in fact lower the sqrt() // libcall. We currently have no way to ask this question, so we ask if // the target has a sqrt() libcall, which is not exactly the same. - Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), B, - Pow->getCalledFunction()->getAttributes()); - else + return emitUnaryFloatFnCall(V, TLI, LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl, B, Attrs); + + return nullptr; +} + +/// Use square root in place of pow(x, +/-0.5). 
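Before the sqrt replacement below, the constant-base cases handled just above can be summarized at the source level as follows. Whether each rewrite actually fires additionally depends on the call's FP flags and on which libcalls (exp2/exp10) the target provides, so treat these as illustrative equivalences only:

#include <cmath>

double p8(double x)  { return pow(8.0, x); }   // 8 == 2^3       --> exp2(3.0 * x)
double pq(double x)  { return pow(0.25, x); }  // 1/0.25 == 2^2  --> exp2(-2.0 * x)
double p10(double x) { return pow(10.0, x); }  // --> exp10(x), where available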
+Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { + Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); + AttributeList Attrs = Pow->getCalledFunction()->getAttributes(); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + + const APFloat *ExpoF; + if (!match(Expo, m_APFloat(ExpoF)) || + (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5))) return nullptr; + Sqrt = getSqrtCall(Base, Attrs, Pow->doesNotAccessMemory(), Mod, B, TLI); + if (!Sqrt) + return nullptr; + + // Handle signed zero base by expanding to fabs(sqrt(x)). + if (!Pow->hasNoSignedZeros()) { + Function *FAbsFn = Intrinsic::getDeclaration(Mod, Intrinsic::fabs, Ty); + Sqrt = B.CreateCall(FAbsFn, Sqrt, "abs"); + } + + // Handle non finite base by expanding to + // (x == -infinity ? +infinity : sqrt(x)). + if (!Pow->hasNoInfs()) { + Value *PosInf = ConstantFP::getInfinity(Ty), + *NegInf = ConstantFP::getInfinity(Ty, true); + Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); + Sqrt = B.CreateSelect(FCmp, PosInf, Sqrt); + } + // If the exponent is negative, then get the reciprocal. if (ExpoF->isNegative()) Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt, "reciprocal"); @@ -1160,134 +1367,109 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilder<> &B) { Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1); Function *Callee = Pow->getCalledFunction(); - AttributeList Attrs = Callee->getAttributes(); StringRef Name = Callee->getName(); - Module *Module = Pow->getModule(); Type *Ty = Pow->getType(); Value *Shrunk = nullptr; bool Ignored; - if (UnsafeFPShrink && - Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) - Shrunk = optimizeUnaryDoubleFP(Pow, B, true); + // Bail out if simplifying libcalls to pow() is disabled. + if (!hasUnaryFloatFn(TLI, Ty, LibFunc_pow, LibFunc_powf, LibFunc_powl)) + return nullptr; // Propagate the math semantics from the call to any created instructions. IRBuilder<>::FastMathFlagGuard Guard(B); B.setFastMathFlags(Pow->getFastMathFlags()); + // Shrink pow() to powf() if the arguments are single precision, + // unless the result is expected to be double precision. + if (UnsafeFPShrink && + Name == TLI->getName(LibFunc_pow) && hasFloatVersion(Name)) + Shrunk = optimizeBinaryDoubleFP(Pow, B, true); + // Evaluate special cases related to the base. // pow(1.0, x) -> 1.0 - if (match(Base, m_SpecificFP(1.0))) + if (match(Base, m_FPOne())) return Base; - // pow(2.0, x) -> exp2(x) - if (match(Base, m_SpecificFP(2.0))) { - Value *Exp2 = Intrinsic::getDeclaration(Module, Intrinsic::exp2, Ty); - return B.CreateCall(Exp2, Expo, "exp2"); - } - - // pow(10.0, x) -> exp10(x) - if (ConstantFP *BaseC = dyn_cast<ConstantFP>(Base)) - // There's no exp10 intrinsic yet, but, maybe, some day there shall be one. - if (BaseC->isExactlyValue(10.0) && - hasUnaryFloatFn(TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) - return emitUnaryFloatFnCall(Expo, TLI->getName(LibFunc_exp10), B, Attrs); - - // pow(exp(x), y) -> exp(x * y) - // pow(exp2(x), y) -> exp2(x * y) - // We enable these only with fast-math. Besides rounding differences, the - // transformation changes overflow and underflow behavior quite dramatically. - // Example: x = 1000, y = 0.001. - // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1). 
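replacePowWithSqrt above now compensates for two corner cases where a bare sqrt() does not match pow() semantics: a negative-zero base (hence the fabs) and a negative-infinity base (hence the select against +infinity). The divergence is easy to see at the source level:

#include <cmath>
#include <cstdio>

int main() {
  std::printf("%g vs %g\n", std::pow(-0.0, 0.5), std::sqrt(-0.0));            // +0 vs -0
  std::printf("%g vs %g\n", std::pow(-INFINITY, 0.5), std::sqrt(-INFINITY));  // +inf vs nan
  return 0;
}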
- auto *BaseFn = dyn_cast<CallInst>(Base); - if (BaseFn && BaseFn->isFast() && Pow->isFast()) { - LibFunc LibFn; - Function *CalleeFn = BaseFn->getCalledFunction(); - if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && - (LibFn == LibFunc_exp || LibFn == LibFunc_exp2) && TLI->has(LibFn)) { - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(Pow->getFastMathFlags()); - - Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul"); - return emitUnaryFloatFnCall(FMul, CalleeFn->getName(), B, - CalleeFn->getAttributes()); - } - } + if (Value *Exp = replacePowWithExp(Pow, B)) + return Exp; // Evaluate special cases related to the exponent. - if (Value *Sqrt = replacePowWithSqrt(Pow, B)) - return Sqrt; - - ConstantFP *ExpoC = dyn_cast<ConstantFP>(Expo); - if (!ExpoC) - return Shrunk; - // pow(x, -1.0) -> 1.0 / x - if (ExpoC->isExactlyValue(-1.0)) + if (match(Expo, m_SpecificFP(-1.0))) return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal"); // pow(x, 0.0) -> 1.0 - if (ExpoC->getValueAPF().isZero()) - return ConstantFP::get(Ty, 1.0); + if (match(Expo, m_SpecificFP(0.0))) + return ConstantFP::get(Ty, 1.0); // pow(x, 1.0) -> x - if (ExpoC->isExactlyValue(1.0)) + if (match(Expo, m_FPOne())) return Base; // pow(x, 2.0) -> x * x - if (ExpoC->isExactlyValue(2.0)) + if (match(Expo, m_SpecificFP(2.0))) return B.CreateFMul(Base, Base, "square"); - // FIXME: Correct the transforms and pull this into replacePowWithSqrt(). - if (ExpoC->isExactlyValue(0.5) && - hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) { - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). - // This is faster than calling pow(), and still handles -0.0 and - // negative infinity correctly. - // TODO: In finite-only mode, this could be just fabs(sqrt(x)). - Value *PosInf = ConstantFP::getInfinity(Ty); - Value *NegInf = ConstantFP::getInfinity(Ty, true); - - // TODO: As above, we should lower to the sqrt() intrinsic if the pow() is - // an intrinsic, to match errno semantics. - Value *Sqrt = emitUnaryFloatFnCall(Base, TLI->getName(LibFunc_sqrt), - B, Attrs); - Function *FAbsFn = Intrinsic::getDeclaration(Module, Intrinsic::fabs, Ty); - Value *FAbs = B.CreateCall(FAbsFn, Sqrt, "abs"); - Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf"); - Sqrt = B.CreateSelect(FCmp, PosInf, FAbs); + if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; - } - // pow(x, n) -> x * x * x * .... - if (Pow->isFast()) { - APFloat ExpoA = abs(ExpoC->getValueAPF()); - // We limit to a max of 7 fmul(s). Thus the maximum exponent is 32. - // This transformation applies to integer exponents only. - if (!ExpoA.isInteger() || - ExpoA.compare - (APFloat(ExpoA.getSemantics(), 32.0)) == APFloat::cmpGreaterThan) - return nullptr; + // pow(x, n) -> x * x * x * ... + const APFloat *ExpoF; + if (Pow->isFast() && match(Expo, m_APFloat(ExpoF))) { + // We limit to a max of 7 multiplications, thus the maximum exponent is 32. + // If the exponent is an integer+0.5 we generate a call to sqrt and an + // additional fmul. + // TODO: This whole transformation should be backend specific (e.g. some + // backends might prefer libcalls or the limit for the exponent might + // be different) and it should also consider optimizing for size. + APFloat LimF(ExpoF->getSemantics(), 33.0), + ExpoA(abs(*ExpoF)); + if (ExpoA.compare(LimF) == APFloat::cmpLessThan) { + // This transformation applies to integer or integer+0.5 exponents only. + // For integer+0.5, we create a sqrt(Base) call. 
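The rewritten expansion above covers the simple constant exponents (-1, 0, 1, 2) and, under fast-math, small constant exponents with magnitude below 33, now including values of the form n + 0.5 by multiplying in a sqrt(Base) call; negative exponents take a final reciprocal. Roughly, at the source level:

#include <cmath>

double p35(double x) {             // pow(x, 3.5)
  double x2 = x * x;
  return (x2 * x) * std::sqrt(x);  // x^3 * sqrt(x)
}

double pm2(double x) {             // pow(x, -2.0)
  double x2 = x * x;
  return 1.0 / x2;                 // reciprocal of the expanded power
}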
+ Value *Sqrt = nullptr; + if (!ExpoA.isInteger()) { + APFloat Expo2 = ExpoA; + // To check if ExpoA is an integer + 0.5, we add it to itself. If there + // is no floating point exception and the result is an integer, then + // ExpoA == integer + 0.5 + if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK) + return nullptr; + + if (!Expo2.isInteger()) + return nullptr; + + Sqrt = + getSqrtCall(Base, Pow->getCalledFunction()->getAttributes(), + Pow->doesNotAccessMemory(), Pow->getModule(), B, TLI); + } - // We will memoize intermediate products of the Addition Chain. - Value *InnerChain[33] = {nullptr}; - InnerChain[1] = Base; - InnerChain[2] = B.CreateFMul(Base, Base, "square"); + // We will memoize intermediate products of the Addition Chain. + Value *InnerChain[33] = {nullptr}; + InnerChain[1] = Base; + InnerChain[2] = B.CreateFMul(Base, Base, "square"); - // We cannot readily convert a non-double type (like float) to a double. - // So we first convert it to something which could be converted to double. - ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); - Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); + // We cannot readily convert a non-double type (like float) to a double. + // So we first convert it to something which could be converted to double. + ExpoA.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); + Value *FMul = getPow(InnerChain, ExpoA.convertToDouble(), B); - // If the exponent is negative, then get the reciprocal. - if (ExpoC->isNegative()) - FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); - return FMul; + // Expand pow(x, y+0.5) to pow(x, y) * sqrt(x). + if (Sqrt) + FMul = B.CreateFMul(FMul, Sqrt); + + // If the exponent is negative, then get the reciprocal. + if (ExpoF->isNegative()) + FMul = B.CreateFDiv(ConstantFP::get(Ty, 1.0), FMul, "reciprocal"); + + return FMul; + } } - return nullptr; + return Shrunk; } Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilder<> &B) { @@ -2285,11 +2467,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, if (CI->isStrictFP()) return nullptr; + if (Value *V = optimizeTrigReflections(CI, Func, Builder)) + return V; + switch (Func) { - case LibFunc_cosf: - case LibFunc_cos: - case LibFunc_cosl: - return optimizeCos(CI, Builder); case LibFunc_sinpif: case LibFunc_sinpi: case LibFunc_cospif: @@ -2344,6 +2525,7 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_exp: case LibFunc_exp10: case LibFunc_expm1: + case LibFunc_cos: case LibFunc_sin: case LibFunc_sinh: case LibFunc_tanh: @@ -2425,7 +2607,7 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { if (Value *V = optimizeStringMemoryLibCall(SimplifiedCI, TmpBuilder)) { // If we were able to further simplify, remove the now redundant call. 
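The getPow()/InnerChain memoization above keeps intermediate powers around so that even an exponent like 13 costs only a handful of multiplies. One possible chain, written out by hand for illustration:

double pow13(double x) {
  double x2 = x * x;    // x^2 (kept in InnerChain[2] by the real code)
  double x4 = x2 * x2;  // x^4
  double x8 = x4 * x4;  // x^8
  return x8 * x4 * x;   // x^13 = x^8 * x^4 * x^1 -- five multiplies in total
}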
SimplifiedCI->replaceAllUsesWith(V); - SimplifiedCI->eraseFromParent(); + eraseFromParent(SimplifiedCI); return V; } } @@ -2504,15 +2686,20 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { LibCallSimplifier::LibCallSimplifier( const DataLayout &DL, const TargetLibraryInfo *TLI, OptimizationRemarkEmitter &ORE, - function_ref<void(Instruction *, Value *)> Replacer) + function_ref<void(Instruction *, Value *)> Replacer, + function_ref<void(Instruction *)> Eraser) : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), - UnsafeFPShrink(false), Replacer(Replacer) {} + UnsafeFPShrink(false), Replacer(Replacer), Eraser(Eraser) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // Indirect through the replacer used in this instance. Replacer(I, With); } +void LibCallSimplifier::eraseFromParent(Instruction *I) { + Eraser(I); +} + // TODO: // Additional cases that we need to add to this file: // diff --git a/lib/Transforms/Utils/SplitModule.cpp b/lib/Transforms/Utils/SplitModule.cpp index f8d758c54983..5db4d2e4df9d 100644 --- a/lib/Transforms/Utils/SplitModule.cpp +++ b/lib/Transforms/Utils/SplitModule.cpp @@ -181,14 +181,12 @@ static void findPartitions(Module *M, ClusterIDMapType &ClusterIDMap, std::make_pair(std::distance(GVtoClusterMap.member_begin(I), GVtoClusterMap.member_end()), I)); - llvm::sort(Sets.begin(), Sets.end(), - [](const SortType &a, const SortType &b) { - if (a.first == b.first) - return a.second->getData()->getName() > - b.second->getData()->getName(); - else - return a.first > b.first; - }); + llvm::sort(Sets, [](const SortType &a, const SortType &b) { + if (a.first == b.first) + return a.second->getData()->getName() > b.second->getData()->getName(); + else + return a.first > b.first; + }); for (auto &I : Sets) { unsigned CurrentClusterID = BalancinQueue.top().first; diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp index afd842f59911..95416de07439 100644 --- a/lib/Transforms/Utils/Utils.cpp +++ b/lib/Transforms/Utils/Utils.cpp @@ -26,6 +26,7 @@ using namespace llvm; void llvm::initializeTransformUtils(PassRegistry &Registry) { initializeAddDiscriminatorsLegacyPassPass(Registry); initializeBreakCriticalEdgesPass(Registry); + initializeCanonicalizeAliasesLegacyPassPass(Registry); initializeInstNamerPass(Registry); initializeLCSSAWrapperPassPass(Registry); initializeLibCallsShrinkWrapLegacyPassPass(Registry); diff --git a/lib/Transforms/Vectorize/CMakeLists.txt b/lib/Transforms/Vectorize/CMakeLists.txt index 27a4d241b320..06eaadf58c3f 100644 --- a/lib/Transforms/Vectorize/CMakeLists.txt +++ b/lib/Transforms/Vectorize/CMakeLists.txt @@ -7,6 +7,7 @@ add_llvm_library(LLVMVectorize VPlan.cpp VPlanHCFGBuilder.cpp VPlanHCFGTransforms.cpp + VPlanSLP.cpp VPlanVerifier.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 5f3d127202ad..9ff18328c219 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -79,6 +79,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Vectorize.h" +#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h" #include <algorithm> #include <cassert> #include <cstdlib> @@ -205,12 +206,12 @@ private: unsigned Alignment); }; -class LoadStoreVectorizer : public FunctionPass { +class LoadStoreVectorizerLegacyPass : public FunctionPass { public: static char ID; - LoadStoreVectorizer() : 
FunctionPass(ID) { - initializeLoadStoreVectorizerPass(*PassRegistry::getPassRegistry()); + LoadStoreVectorizerLegacyPass() : FunctionPass(ID) { + initializeLoadStoreVectorizerLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; @@ -230,30 +231,23 @@ public: } // end anonymous namespace -char LoadStoreVectorizer::ID = 0; +char LoadStoreVectorizerLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoadStoreVectorizer, DEBUG_TYPE, +INITIALIZE_PASS_BEGIN(LoadStoreVectorizerLegacyPass, DEBUG_TYPE, "Vectorize load and Store instructions", false, false) INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(LoadStoreVectorizer, DEBUG_TYPE, +INITIALIZE_PASS_END(LoadStoreVectorizerLegacyPass, DEBUG_TYPE, "Vectorize load and store instructions", false, false) Pass *llvm::createLoadStoreVectorizerPass() { - return new LoadStoreVectorizer(); + return new LoadStoreVectorizerLegacyPass(); } -// The real propagateMetadata expects a SmallVector<Value*>, but we deal in -// vectors of Instructions. -static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) { - SmallVector<Value *, 8> VL(IL.begin(), IL.end()); - propagateMetadata(I, VL); -} - -bool LoadStoreVectorizer::runOnFunction(Function &F) { +bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) { // Don't vectorize when the attribute NoImplicitFloat is used. if (skipFunction(F) || F.hasFnAttribute(Attribute::NoImplicitFloat)) return false; @@ -268,6 +262,30 @@ bool LoadStoreVectorizer::runOnFunction(Function &F) { return V.run(); } +PreservedAnalyses LoadStoreVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { + // Don't vectorize when the attribute NoImplicitFloat is used. + if (F.hasFnAttribute(Attribute::NoImplicitFloat)) + return PreservedAnalyses::all(); + + AliasAnalysis &AA = AM.getResult<AAManager>(F); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); + + Vectorizer V(F, AA, DT, SE, TTI); + bool Changed = V.run(); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return Changed ? PA : PreservedAnalyses::all(); +} + +// The real propagateMetadata expects a SmallVector<Value*>, but we deal in +// vectors of Instructions. +static void propagateMetadata(Instruction *I, ArrayRef<Instruction *> IL) { + SmallVector<Value *, 8> VL(IL.begin(), IL.end()); + propagateMetadata(I, VL); +} + // Vectorizer Implementation bool Vectorizer::run() { bool Changed = false; @@ -954,11 +972,6 @@ bool Vectorizer::vectorizeStoreChain( // try again. unsigned EltSzInBytes = Sz / 8; unsigned SzInBytes = EltSzInBytes * ChainSize; - if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) { - auto Chains = splitOddVectorElts(Chain, Sz); - return vectorizeStoreChain(Chains.first, InstructionsProcessed) | - vectorizeStoreChain(Chains.second, InstructionsProcessed); - } VectorType *VecTy; VectorType *VecStoreTy = dyn_cast<VectorType>(StoreTy); @@ -991,14 +1004,23 @@ bool Vectorizer::vectorizeStoreChain( // If the store is going to be misaligned, don't vectorize it. 
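Besides the legacy-pass rename, the vectorizer gains a new-pass-manager entry point (LoadStoreVectorizerPass::run in this hunk). The general shape of such a port, sketched with a placeholder pass name and assuming the usual LLVM headers are available:

#include "llvm/IR/Function.h"
#include "llvm/IR/PassManager.h"

namespace sketch {
struct MyVectorizerPass : llvm::PassInfoMixin<MyVectorizerPass> {
  llvm::PreservedAnalyses run(llvm::Function &F,
                              llvm::FunctionAnalysisManager &AM) {
    bool Changed = false;
    // ... fetch analyses via AM.getResult<>() and run the transformation ...
    if (!Changed)
      return llvm::PreservedAnalyses::all();
    llvm::PreservedAnalyses PA;
    PA.preserveSet<llvm::CFGAnalyses>();  // the CFG itself is left intact
    return PA;
  }
};
} // namespace sketch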
if (accessIsMisaligned(SzInBytes, AS, Alignment)) { - if (S0->getPointerAddressSpace() != 0) - return false; + if (S0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { + auto Chains = splitOddVectorElts(Chain, Sz); + return vectorizeStoreChain(Chains.first, InstructionsProcessed) | + vectorizeStoreChain(Chains.second, InstructionsProcessed); + } unsigned NewAlign = getOrEnforceKnownAlignment(S0->getPointerOperand(), StackAdjustedAlignment, DL, S0, nullptr, &DT); - if (NewAlign < StackAdjustedAlignment) - return false; + if (NewAlign != 0) + Alignment = NewAlign; + } + + if (!TTI.isLegalToVectorizeStoreChain(SzInBytes, Alignment, AS)) { + auto Chains = splitOddVectorElts(Chain, Sz); + return vectorizeStoreChain(Chains.first, InstructionsProcessed) | + vectorizeStoreChain(Chains.second, InstructionsProcessed); } BasicBlock::iterator First, Last; @@ -1037,13 +1059,11 @@ bool Vectorizer::vectorizeStoreChain( } } - // This cast is safe because Builder.CreateStore() always creates a bona fide - // StoreInst. - StoreInst *SI = cast<StoreInst>( - Builder.CreateStore(Vec, Builder.CreateBitCast(S0->getPointerOperand(), - VecTy->getPointerTo(AS)))); + StoreInst *SI = Builder.CreateAlignedStore( + Vec, + Builder.CreateBitCast(S0->getPointerOperand(), VecTy->getPointerTo(AS)), + Alignment); propagateMetadata(SI, Chain); - SI->setAlignment(Alignment); eraseInstructions(Chain); ++NumVectorInstructions; @@ -1102,12 +1122,6 @@ bool Vectorizer::vectorizeLoadChain( // try again. unsigned EltSzInBytes = Sz / 8; unsigned SzInBytes = EltSzInBytes * ChainSize; - if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { - auto Chains = splitOddVectorElts(Chain, Sz); - return vectorizeLoadChain(Chains.first, InstructionsProcessed) | - vectorizeLoadChain(Chains.second, InstructionsProcessed); - } - VectorType *VecTy; VectorType *VecLoadTy = dyn_cast<VectorType>(LoadTy); if (VecLoadTy) @@ -1132,18 +1146,27 @@ bool Vectorizer::vectorizeLoadChain( // If the load is going to be misaligned, don't vectorize it. if (accessIsMisaligned(SzInBytes, AS, Alignment)) { - if (L0->getPointerAddressSpace() != 0) - return false; + if (L0->getPointerAddressSpace() != DL.getAllocaAddrSpace()) { + auto Chains = splitOddVectorElts(Chain, Sz); + return vectorizeLoadChain(Chains.first, InstructionsProcessed) | + vectorizeLoadChain(Chains.second, InstructionsProcessed); + } unsigned NewAlign = getOrEnforceKnownAlignment(L0->getPointerOperand(), StackAdjustedAlignment, DL, L0, nullptr, &DT); - if (NewAlign < StackAdjustedAlignment) - return false; + if (NewAlign != 0) + Alignment = NewAlign; Alignment = NewAlign; } + if (!TTI.isLegalToVectorizeLoadChain(SzInBytes, Alignment, AS)) { + auto Chains = splitOddVectorElts(Chain, Sz); + return vectorizeLoadChain(Chains.first, InstructionsProcessed) | + vectorizeLoadChain(Chains.second, InstructionsProcessed); + } + LLVM_DEBUG({ dbgs() << "LSV: Loads to vectorize:\n"; for (Instruction *I : Chain) @@ -1159,11 +1182,8 @@ bool Vectorizer::vectorizeLoadChain( Value *Bitcast = Builder.CreateBitCast(L0->getPointerOperand(), VecTy->getPointerTo(AS)); - // This cast is safe because Builder.CreateLoad always creates a bona fide - // LoadInst. 
- LoadInst *LI = cast<LoadInst>(Builder.CreateLoad(Bitcast)); + LoadInst *LI = Builder.CreateAlignedLoad(Bitcast, Alignment); propagateMetadata(LI, Chain); - LI->setAlignment(Alignment); if (VecLoadTy) { SmallVector<Instruction *, 16> InstrsToErase; diff --git a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 697bc1b448d7..b44fe5a52a2f 100644 --- a/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -80,10 +80,11 @@ bool LoopVectorizeHints::Hint::validate(unsigned Val) { return false; } -LoopVectorizeHints::LoopVectorizeHints(const Loop *L, bool DisableInterleaving, +LoopVectorizeHints::LoopVectorizeHints(const Loop *L, + bool InterleaveOnlyWhenForced, OptimizationRemarkEmitter &ORE) : Width("vectorize.width", VectorizerParams::VectorizationFactor, HK_WIDTH), - Interleave("interleave.count", DisableInterleaving, HK_UNROLL), + Interleave("interleave.count", InterleaveOnlyWhenForced, HK_UNROLL), Force("vectorize.enable", FK_Undefined, HK_FORCE), IsVectorized("isvectorized", 0, HK_ISVECTORIZED), TheLoop(L), ORE(ORE) { // Populate values with existing loop metadata. @@ -98,19 +99,19 @@ LoopVectorizeHints::LoopVectorizeHints(const Loop *L, bool DisableInterleaving, // consider the loop to have been already vectorized because there's // nothing more that we can do. IsVectorized.Value = Width.Value == 1 && Interleave.Value == 1; - LLVM_DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() + LLVM_DEBUG(if (InterleaveOnlyWhenForced && Interleave.Value == 1) dbgs() << "LV: Interleaving disabled by the pass manager\n"); } -bool LoopVectorizeHints::allowVectorization(Function *F, Loop *L, - bool AlwaysVectorize) const { +bool LoopVectorizeHints::allowVectorization( + Function *F, Loop *L, bool VectorizeOnlyWhenForced) const { if (getForce() == LoopVectorizeHints::FK_Disabled) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: #pragma vectorize disable.\n"); emitRemarkWithHints(); return false; } - if (!AlwaysVectorize && getForce() != LoopVectorizeHints::FK_Enabled) { + if (VectorizeOnlyWhenForced && getForce() != LoopVectorizeHints::FK_Enabled) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No #pragma vectorize enable.\n"); emitRemarkWithHints(); return false; @@ -434,7 +435,7 @@ static Type *getWiderType(const DataLayout &DL, Type *Ty0, Type *Ty1) { /// identified reduction variable. static bool hasOutsideLoopUser(const Loop *TheLoop, Instruction *Inst, SmallPtrSetImpl<Value *> &AllowedExit) { - // Reduction and Induction instructions are allowed to have exit users. All + // Reductions, Inductions and non-header phis are allowed to have exit users. All // other instructions must not have external users. if (!AllowedExit.count(Inst)) // Check that all of the users of the loop are inside the BB. @@ -516,6 +517,18 @@ bool LoopVectorizationLegality::canVectorizeOuterLoop() { return false; } + // Check whether we are able to set up outer loop induction. + if (!setupOuterLoopInductions()) { + LLVM_DEBUG( + dbgs() << "LV: Not vectorizing: Unsupported outer loop Phi(s).\n"); + ORE->emit(createMissedAnalysis("UnsupportedPhi") + << "Unsupported outer loop Phi(s)"); + if (DoExtraAnalysis) + Result = false; + else + return false; + } + return Result; } @@ -561,7 +574,8 @@ void LoopVectorizationLegality::addInductionPhi( // back into the PHI node may have external users. 
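On the LoopVectorize side, the hint parameters above are renamed to the "OnlyWhenForced" convention: with VectorizeOnlyWhenForced / InterleaveOnlyWhenForced set, the pass only touches loops whose metadata explicitly requests it. At the source level that request typically comes from a pragma; a hypothetical example (the pragma lowers to llvm.loop.* metadata, which LoopVectorizeHints reads through the "vectorize.enable" and related hints):

void scale(float *a, int n) {
#pragma clang loop vectorize(enable)
  for (int i = 0; i < n; ++i)
    a[i] *= 2.0f;
}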
// We can allow those uses, except if the SCEVs we have for them rely // on predicates that only hold within the loop, since allowing the exit - // currently means re-using this SCEV outside the loop. + // currently means re-using this SCEV outside the loop (see PR33706 for more + // details). if (PSE.getUnionPredicate().isAlwaysTrue()) { AllowedExit.insert(Phi); AllowedExit.insert(Phi->getIncomingValueForBlock(TheLoop->getLoopLatch())); @@ -570,6 +584,32 @@ void LoopVectorizationLegality::addInductionPhi( LLVM_DEBUG(dbgs() << "LV: Found an induction variable.\n"); } +bool LoopVectorizationLegality::setupOuterLoopInductions() { + BasicBlock *Header = TheLoop->getHeader(); + + // Returns true if a given Phi is a supported induction. + auto isSupportedPhi = [&](PHINode &Phi) -> bool { + InductionDescriptor ID; + if (InductionDescriptor::isInductionPHI(&Phi, TheLoop, PSE, ID) && + ID.getKind() == InductionDescriptor::IK_IntInduction) { + addInductionPhi(&Phi, ID, AllowedExit); + return true; + } else { + // Bail out for any Phi in the outer loop header that is not a supported + // induction. + LLVM_DEBUG( + dbgs() + << "LV: Found unsupported PHI for outer loop vectorization.\n"); + return false; + } + }; + + if (llvm::all_of(Header->phis(), isSupportedPhi)) + return true; + else + return false; +} + bool LoopVectorizationLegality::canVectorizeInstrs() { BasicBlock *Header = TheLoop->getHeader(); @@ -597,14 +637,12 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // can convert it to select during if-conversion. No need to check if // the PHIs in this block are induction or reduction variables. if (BB != Header) { - // Check that this instruction has no outside users or is an - // identified reduction value with an outside user. - if (!hasOutsideLoopUser(TheLoop, Phi, AllowedExit)) - continue; - ORE->emit(createMissedAnalysis("NeitherInductionNorReduction", Phi) - << "value could not be identified as " - "an induction or reduction variable"); - return false; + // Non-header phi nodes that have outside uses can be vectorized. Add + // them to the list of allowed exits. + // Unsafe cyclic dependencies with header phis are identified during + // legalization for reduction, induction and first order + // recurrences. + continue; } // We only allow if-converted PHIs with exactly two incoming values. @@ -625,6 +663,20 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { continue; } + // TODO: Instead of recording the AllowedExit, it would be good to record the + // complementary set: NotAllowedExit. These include (but may not be + // limited to): + // 1. Reduction phis as they represent the one-before-last value, which + // is not available when vectorized + // 2. Induction phis and increment when SCEV predicates cannot be used + // outside the loop - see addInductionPhi + // 3. Non-Phis with outside uses when SCEV predicates cannot be used + // outside the loop - see call to hasOutsideLoopUser in the non-phi + // handling below + // 4. FirstOrderRecurrence phis that can possibly be handled by + // extraction. + // By recording these, we can then reason about ways to vectorize each + // of these NotAllowedExit. 
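As a side note on the setupOuterLoopInductions() hunk above, its trailing if/else pair is equivalent to returning the llvm::all_of result directly. A self-contained sketch of that check, where IsSupportedPhi stands in for the lambda in the hunk and the include set reflects LLVM of roughly this vintage:

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Sketch only: outer-loop vectorization is viable with respect to header
// phis only if every phi in the loop header is a recognized integer
// induction.
static bool allHeaderPhisSupported(BasicBlock *Header,
                                   function_ref<bool(PHINode &)> IsSupportedPhi) {
  return llvm::all_of(Header->phis(), IsSupportedPhi);
}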
InductionDescriptor ID; if (InductionDescriptor::isInductionPHI(Phi, TheLoop, PSE, ID)) { addInductionPhi(Phi, ID, AllowedExit); @@ -662,10 +714,30 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { !isa<DbgInfoIntrinsic>(CI) && !(CI->getCalledFunction() && TLI && TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) { - ORE->emit(createMissedAnalysis("CantVectorizeCall", CI) - << "call instruction cannot be vectorized"); + // If the call is a recognized math libary call, it is likely that + // we can vectorize it given loosened floating-point constraints. + LibFunc Func; + bool IsMathLibCall = + TLI && CI->getCalledFunction() && + CI->getType()->isFloatingPointTy() && + TLI->getLibFunc(CI->getCalledFunction()->getName(), Func) && + TLI->hasOptimizedCodeGen(Func); + + if (IsMathLibCall) { + // TODO: Ideally, we should not use clang-specific language here, + // but it's hard to provide meaningful yet generic advice. + // Also, should this be guarded by allowExtraAnalysis() and/or be part + // of the returned info from isFunctionVectorizable()? + ORE->emit(createMissedAnalysis("CantVectorizeLibcall", CI) + << "library call cannot be vectorized. " + "Try compiling with -fno-math-errno, -ffast-math, " + "or similar flags"); + } else { + ORE->emit(createMissedAnalysis("CantVectorizeCall", CI) + << "call instruction cannot be vectorized"); + } LLVM_DEBUG( - dbgs() << "LV: Found a non-intrinsic, non-libfunc callsite.\n"); + dbgs() << "LV: Found a non-intrinsic callsite.\n"); return false; } @@ -717,6 +789,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // Reduction instructions are allowed to have exit users. // All other instructions must not have external users. if (hasOutsideLoopUser(TheLoop, &I, AllowedExit)) { + // We can safely vectorize loops where instructions within the loop are + // used outside the loop only if the SCEV predicates within the loop is + // same as outside the loop. Allowing the exit means reusing the SCEV + // outside the loop. 
+ if (PSE.getUnionPredicate().isAlwaysTrue()) { + AllowedExit.insert(&I); + continue; + } ORE->emit(createMissedAnalysis("ValueUsedOutsideLoop", &I) << "value cannot be used outside the loop"); return false; @@ -730,6 +810,10 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { ORE->emit(createMissedAnalysis("NoInductionVariable") << "loop induction variable could not be identified"); return false; + } else if (!WidestIndTy) { + ORE->emit(createMissedAnalysis("NoIntegerInductionVariable") + << "integer loop induction variable could not be identified"); + return false; } } @@ -754,13 +838,14 @@ bool LoopVectorizationLegality::canVectorizeMemory() { if (!LAI->canVectorizeMemory()) return false; - if (LAI->hasStoreToLoopInvariantAddress()) { + if (LAI->hasDependenceInvolvingLoopInvariantAddress()) { ORE->emit(createMissedAnalysis("CantVectorizeStoreToLoopInvariantAddress") - << "write to a loop invariant address could not be vectorized"); - LLVM_DEBUG(dbgs() << "LV: We don't allow storing to uniform addresses\n"); + << "write to a loop invariant address could not " + "be vectorized"); + LLVM_DEBUG( + dbgs() << "LV: Non vectorizable stores to a uniform address\n"); return false; } - Requirements->addRuntimePointerChecks(LAI->getNumRuntimePointerChecks()); PSE.addPredicate(LAI->getPSE().getUnionPredicate()); @@ -1069,4 +1154,59 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return Result; } +bool LoopVectorizationLegality::canFoldTailByMasking() { + + LLVM_DEBUG(dbgs() << "LV: checking if tail can be folded by masking.\n"); + + if (!PrimaryInduction) { + ORE->emit(createMissedAnalysis("NoPrimaryInduction") + << "Missing a primary induction variable in the loop, which is " + << "needed in order to fold tail by masking as required."); + LLVM_DEBUG(dbgs() << "LV: No primary induction, cannot fold tail by " + << "masking.\n"); + return false; + } + + // TODO: handle reductions when tail is folded by masking. + if (!Reductions.empty()) { + ORE->emit(createMissedAnalysis("ReductionFoldingTailByMasking") + << "Cannot fold tail by masking in the presence of reductions."); + LLVM_DEBUG(dbgs() << "LV: Loop has reductions, cannot fold tail by " + << "masking.\n"); + return false; + } + + // TODO: handle outside users when tail is folded by masking. + for (auto *AE : AllowedExit) { + // Check that all users of allowed exit values are inside the loop. + for (User *U : AE->users()) { + Instruction *UI = cast<Instruction>(U); + if (TheLoop->contains(UI)) + continue; + ORE->emit(createMissedAnalysis("LiveOutFoldingTailByMasking") + << "Cannot fold tail by masking in the presence of live outs."); + LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking, loop has an " + << "outside user for : " << *UI << '\n'); + return false; + } + } + + // The list of pointers that we can safely read and write to remains empty. + SmallPtrSet<Value *, 8> SafePointers; + + // Check and mark all blocks for predication, including those that ordinarily + // do not need predication such as the header block. 
+ for (BasicBlock *BB : TheLoop->blocks()) { + if (!blockCanBePredicated(BB, SafePointers)) { + ORE->emit(createMissedAnalysis("NoCFGForSelect", BB->getTerminator()) + << "control flow cannot be substituted for a select"); + LLVM_DEBUG(dbgs() << "LV: Cannot fold tail by masking as required.\n"); + return false; + } + } + + LLVM_DEBUG(dbgs() << "LV: can fold tail by masking.\n"); + return true; +} + } // namespace llvm diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 859d0c92ca5a..c45dee590b84 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -58,6 +58,7 @@ #include "LoopVectorizationPlanner.h" #include "VPRecipeBuilder.h" #include "VPlanHCFGBuilder.h" +#include "VPlanHCFGTransforms.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" @@ -151,6 +152,16 @@ using namespace llvm; #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME +/// @{ +/// Metadata attribute names +static const char *const LLVMLoopVectorizeFollowupAll = + "llvm.loop.vectorize.followup_all"; +static const char *const LLVMLoopVectorizeFollowupVectorized = + "llvm.loop.vectorize.followup_vectorized"; +static const char *const LLVMLoopVectorizeFollowupEpilogue = + "llvm.loop.vectorize.followup_epilogue"; +/// @} + STATISTIC(LoopsVectorized, "Number of loops vectorized"); STATISTIC(LoopsAnalyzed, "Number of loops analyzed for vectorization"); @@ -171,11 +182,11 @@ static cl::opt<bool> EnableInterleavedMemAccesses( "enable-interleaved-mem-accesses", cl::init(false), cl::Hidden, cl::desc("Enable vectorization on interleaved memory accesses in a loop")); -/// Maximum factor for an interleaved memory access. -static cl::opt<unsigned> MaxInterleaveGroupFactor( - "max-interleave-group-factor", cl::Hidden, - cl::desc("Maximum factor for an interleaved access group (default = 8)"), - cl::init(8)); +/// An interleave-group may need masking if it resides in a block that needs +/// predication, or in order to mask away gaps. +static cl::opt<bool> EnableMaskedInterleavedMemAccesses( + "enable-masked-interleaved-mem-accesses", cl::init(false), cl::Hidden, + cl::desc("Enable vectorization on masked interleaved memory accesses in a loop")); /// We don't interleave loops with a known constant trip count below this /// number. @@ -240,7 +251,7 @@ static cl::opt<unsigned> MaxNestedScalarReductionIC( cl::desc("The maximum interleave count to use when interleaving a scalar " "reduction in a nested loop.")); -static cl::opt<bool> EnableVPlanNativePath( +cl::opt<bool> EnableVPlanNativePath( "enable-vplan-native-path", cl::init(false), cl::Hidden, cl::desc("Enable VPlan-native vectorization path with " "support for outer loop vectorization.")); @@ -265,10 +276,6 @@ static Type *ToVectorTy(Type *Scalar, unsigned VF) { return VectorType::get(Scalar, VF); } -// FIXME: The following helper functions have multiple implementations -// in the project. They can be effectively organized in a common Load/Store -// utilities unit. - /// A helper function that returns the type of loaded or stored value. static Type *getMemInstValueType(Value *I) { assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && @@ -278,25 +285,6 @@ static Type *getMemInstValueType(Value *I) { return cast<StoreInst>(I)->getValueOperand()->getType(); } -/// A helper function that returns the alignment of load or store instruction. 
-static unsigned getMemInstAlignment(Value *I) { - assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && - "Expected Load or Store instruction"); - if (auto *LI = dyn_cast<LoadInst>(I)) - return LI->getAlignment(); - return cast<StoreInst>(I)->getAlignment(); -} - -/// A helper function that returns the address space of the pointer operand of -/// load or store instruction. -static unsigned getMemInstAddressSpace(Value *I) { - assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && - "Expected Load or Store instruction"); - if (auto *LI = dyn_cast<LoadInst>(I)) - return LI->getPointerAddressSpace(); - return cast<StoreInst>(I)->getPointerAddressSpace(); -} - /// A helper function that returns true if the given type is irregular. The /// type is irregular if its allocated size doesn't equal the store size of an /// element of the corresponding vector type at the given vectorization factor. @@ -436,8 +424,10 @@ public: /// Construct the vector value of a scalarized value \p V one lane at a time. void packScalarIntoVectorValue(Value *V, const VPIteration &Instance); - /// Try to vectorize the interleaved access group that \p Instr belongs to. - void vectorizeInterleaveGroup(Instruction *Instr); + /// Try to vectorize the interleaved access group that \p Instr belongs to, + /// optionally masking the vector operations if \p BlockInMask is non-null. + void vectorizeInterleaveGroup(Instruction *Instr, + VectorParts *BlockInMask = nullptr); /// Vectorize Load and Store instructions, optionally masking the vector /// operations if \p BlockInMask is non-null. @@ -448,6 +438,9 @@ public: /// the instruction. void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr); + /// Fix the non-induction PHIs in the OrigPHIsToFix vector. + void fixNonInductionPHIs(void); + protected: friend class LoopVectorizationPlanner; @@ -584,6 +577,16 @@ protected: /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); + /// Compute the transformed value of Index at offset StartValue using step + /// StepValue. + /// For integer induction, returns StartValue + Index * StepValue. + /// For pointer induction, returns StartValue[Index * StepValue]. + /// FIXME: The newly created binary instructions should contain nsw/nuw + /// flags, which can be found from the original scalar operations. + Value *emitTransformedIndex(IRBuilder<> &B, Value *Index, ScalarEvolution *SE, + const DataLayout &DL, + const InductionDescriptor &ID) const; + /// Add additional metadata to \p To that was not present on \p Orig. /// /// Currently this is used to add the noalias annotations based on the @@ -705,6 +708,10 @@ protected: // Holds the end values for each induction variable. We save the end values // so we can later fix-up the external users of the induction variables. DenseMap<PHINode *, Value *> IVEndValues; + + // Vector of original scalar PHIs whose corresponding widened PHIs need to be + // fixed up at the end of vector code generation. 
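To make the emitTransformedIndex() contract introduced above concrete, here is a small scalar analogue of the two common cases. It mirrors only the documented semantics (StartValue + Index * StepValue for integer inductions, the address of element Index * StepValue for pointer inductions) and is not the LLVM implementation:

#include <cstdint>

// Integer induction: with StartValue = 5, Step = 3 and Index = 4 this
// returns 17, i.e. the value the induction variable holds after 4 steps.
static int64_t transformedIntIndex(int64_t StartValue, int64_t Step,
                                   int64_t Index) {
  return StartValue + Index * Step;
}

// Pointer induction: the transformed index is a GEP off the start pointer,
// i.e. the address of element Index * Step.
static const int *transformedPtrIndex(const int *StartValue, int64_t Step,
                                      int64_t Index) {
  return StartValue + Index * Step;
}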
+ SmallVector<PHINode *, 8> OrigPHIsToFix; }; class InnerLoopUnroller : public InnerLoopVectorizer { @@ -752,8 +759,15 @@ void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) { const DILocation *DIL = Inst->getDebugLoc(); if (DIL && Inst->getFunction()->isDebugInfoForProfiling() && - !isa<DbgInfoIntrinsic>(Inst)) - B.SetCurrentDebugLocation(DIL->cloneWithDuplicationFactor(UF * VF)); + !isa<DbgInfoIntrinsic>(Inst)) { + auto NewDIL = DIL->cloneWithDuplicationFactor(UF * VF); + if (NewDIL) + B.SetCurrentDebugLocation(NewDIL.getValue()); + else + LLVM_DEBUG(dbgs() + << "Failed to create new discriminator: " + << DIL->getFilename() << " Line: " << DIL->getLine()); + } else B.SetCurrentDebugLocation(DIL); } else @@ -801,367 +815,6 @@ void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To, namespace llvm { -/// The group of interleaved loads/stores sharing the same stride and -/// close to each other. -/// -/// Each member in this group has an index starting from 0, and the largest -/// index should be less than interleaved factor, which is equal to the absolute -/// value of the access's stride. -/// -/// E.g. An interleaved load group of factor 4: -/// for (unsigned i = 0; i < 1024; i+=4) { -/// a = A[i]; // Member of index 0 -/// b = A[i+1]; // Member of index 1 -/// d = A[i+3]; // Member of index 3 -/// ... -/// } -/// -/// An interleaved store group of factor 4: -/// for (unsigned i = 0; i < 1024; i+=4) { -/// ... -/// A[i] = a; // Member of index 0 -/// A[i+1] = b; // Member of index 1 -/// A[i+2] = c; // Member of index 2 -/// A[i+3] = d; // Member of index 3 -/// } -/// -/// Note: the interleaved load group could have gaps (missing members), but -/// the interleaved store group doesn't allow gaps. -class InterleaveGroup { -public: - InterleaveGroup(Instruction *Instr, int Stride, unsigned Align) - : Align(Align), InsertPos(Instr) { - assert(Align && "The alignment should be non-zero"); - - Factor = std::abs(Stride); - assert(Factor > 1 && "Invalid interleave factor"); - - Reverse = Stride < 0; - Members[0] = Instr; - } - - bool isReverse() const { return Reverse; } - unsigned getFactor() const { return Factor; } - unsigned getAlignment() const { return Align; } - unsigned getNumMembers() const { return Members.size(); } - - /// Try to insert a new member \p Instr with index \p Index and - /// alignment \p NewAlign. The index is related to the leader and it could be - /// negative if it is the new leader. - /// - /// \returns false if the instruction doesn't belong to the group. - bool insertMember(Instruction *Instr, int Index, unsigned NewAlign) { - assert(NewAlign && "The new member's alignment should be non-zero"); - - int Key = Index + SmallestKey; - - // Skip if there is already a member with the same index. - if (Members.count(Key)) - return false; - - if (Key > LargestKey) { - // The largest index is always less than the interleave factor. - if (Index >= static_cast<int>(Factor)) - return false; - - LargestKey = Key; - } else if (Key < SmallestKey) { - // The largest index is always less than the interleave factor. - if (LargestKey - Key >= static_cast<int>(Factor)) - return false; - - SmallestKey = Key; - } - - // It's always safe to select the minimum alignment. - Align = std::min(Align, NewAlign); - Members[Key] = Instr; - return true; - } - - /// Get the member with the given index \p Index - /// - /// \returns nullptr if contains no such member. 
- Instruction *getMember(unsigned Index) const { - int Key = SmallestKey + Index; - if (!Members.count(Key)) - return nullptr; - - return Members.find(Key)->second; - } - - /// Get the index for the given member. Unlike the key in the member - /// map, the index starts from 0. - unsigned getIndex(Instruction *Instr) const { - for (auto I : Members) - if (I.second == Instr) - return I.first - SmallestKey; - - llvm_unreachable("InterleaveGroup contains no such member"); - } - - Instruction *getInsertPos() const { return InsertPos; } - void setInsertPos(Instruction *Inst) { InsertPos = Inst; } - - /// Add metadata (e.g. alias info) from the instructions in this group to \p - /// NewInst. - /// - /// FIXME: this function currently does not add noalias metadata a'la - /// addNewMedata. To do that we need to compute the intersection of the - /// noalias info from all members. - void addMetadata(Instruction *NewInst) const { - SmallVector<Value *, 4> VL; - std::transform(Members.begin(), Members.end(), std::back_inserter(VL), - [](std::pair<int, Instruction *> p) { return p.second; }); - propagateMetadata(NewInst, VL); - } - -private: - unsigned Factor; // Interleave Factor. - bool Reverse; - unsigned Align; - DenseMap<int, Instruction *> Members; - int SmallestKey = 0; - int LargestKey = 0; - - // To avoid breaking dependences, vectorized instructions of an interleave - // group should be inserted at either the first load or the last store in - // program order. - // - // E.g. %even = load i32 // Insert Position - // %add = add i32 %even // Use of %even - // %odd = load i32 - // - // store i32 %even - // %odd = add i32 // Def of %odd - // store i32 %odd // Insert Position - Instruction *InsertPos; -}; -} // end namespace llvm - -namespace { - -/// Drive the analysis of interleaved memory accesses in the loop. -/// -/// Use this class to analyze interleaved accesses only when we can vectorize -/// a loop. Otherwise it's meaningless to do analysis as the vectorization -/// on interleaved accesses is unsafe. -/// -/// The analysis collects interleave groups and records the relationships -/// between the member and the group in a map. -class InterleavedAccessInfo { -public: - InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, - DominatorTree *DT, LoopInfo *LI, - const LoopAccessInfo *LAI) - : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {} - - ~InterleavedAccessInfo() { - SmallPtrSet<InterleaveGroup *, 4> DelSet; - // Avoid releasing a pointer twice. - for (auto &I : InterleaveGroupMap) - DelSet.insert(I.second); - for (auto *Ptr : DelSet) - delete Ptr; - } - - /// Analyze the interleaved accesses and collect them in interleave - /// groups. Substitute symbolic strides using \p Strides. - void analyzeInterleaving(); - - /// Check if \p Instr belongs to any interleave group. - bool isInterleaved(Instruction *Instr) const { - return InterleaveGroupMap.count(Instr); - } - - /// Get the interleave group that \p Instr belongs to. - /// - /// \returns nullptr if doesn't have such group. - InterleaveGroup *getInterleaveGroup(Instruction *Instr) const { - if (InterleaveGroupMap.count(Instr)) - return InterleaveGroupMap.find(Instr)->second; - return nullptr; - } - - /// Returns true if an interleaved group that may access memory - /// out-of-bounds requires a scalar epilogue iteration for correctness. - bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; } - -private: - /// A wrapper around ScalarEvolution, used to add runtime SCEV checks. 
- /// Simplifies SCEV expressions in the context of existing SCEV assumptions. - /// The interleaved access analysis can also add new predicates (for example - /// by versioning strides of pointers). - PredicatedScalarEvolution &PSE; - - Loop *TheLoop; - DominatorTree *DT; - LoopInfo *LI; - const LoopAccessInfo *LAI; - - /// True if the loop may contain non-reversed interleaved groups with - /// out-of-bounds accesses. We ensure we don't speculatively access memory - /// out-of-bounds by executing at least one scalar epilogue iteration. - bool RequiresScalarEpilogue = false; - - /// Holds the relationships between the members and the interleave group. - DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap; - - /// Holds dependences among the memory accesses in the loop. It maps a source - /// access to a set of dependent sink accesses. - DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences; - - /// The descriptor for a strided memory access. - struct StrideDescriptor { - StrideDescriptor() = default; - StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, - unsigned Align) - : Stride(Stride), Scev(Scev), Size(Size), Align(Align) {} - - // The access's stride. It is negative for a reverse access. - int64_t Stride = 0; - - // The scalar expression of this access. - const SCEV *Scev = nullptr; - - // The size of the memory object. - uint64_t Size = 0; - - // The alignment of this access. - unsigned Align = 0; - }; - - /// A type for holding instructions and their stride descriptors. - using StrideEntry = std::pair<Instruction *, StrideDescriptor>; - - /// Create a new interleave group with the given instruction \p Instr, - /// stride \p Stride and alignment \p Align. - /// - /// \returns the newly created interleave group. - InterleaveGroup *createInterleaveGroup(Instruction *Instr, int Stride, - unsigned Align) { - assert(!InterleaveGroupMap.count(Instr) && - "Already in an interleaved access group"); - InterleaveGroupMap[Instr] = new InterleaveGroup(Instr, Stride, Align); - return InterleaveGroupMap[Instr]; - } - - /// Release the group and remove all the relationships. - void releaseGroup(InterleaveGroup *Group) { - for (unsigned i = 0; i < Group->getFactor(); i++) - if (Instruction *Member = Group->getMember(i)) - InterleaveGroupMap.erase(Member); - - delete Group; - } - - /// Collect all the accesses with a constant stride in program order. - void collectConstStrideAccesses( - MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, - const ValueToValueMap &Strides); - - /// Returns true if \p Stride is allowed in an interleaved group. - static bool isStrided(int Stride) { - unsigned Factor = std::abs(Stride); - return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; - } - - /// Returns true if \p BB is a predicated block. - bool isPredicated(BasicBlock *BB) const { - return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT); - } - - /// Returns true if LoopAccessInfo can be used for dependence queries. - bool areDependencesValid() const { - return LAI && LAI->getDepChecker().getDependences(); - } - - /// Returns true if memory accesses \p A and \p B can be reordered, if - /// necessary, when constructing interleaved groups. - /// - /// \p A must precede \p B in program order. We return false if reordering is - /// not necessary or is prevented because \p A and \p B may be dependent. 
- bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, - StrideEntry *B) const { - // Code motion for interleaved accesses can potentially hoist strided loads - // and sink strided stores. The code below checks the legality of the - // following two conditions: - // - // 1. Potentially moving a strided load (B) before any store (A) that - // precedes B, or - // - // 2. Potentially moving a strided store (A) after any load or store (B) - // that A precedes. - // - // It's legal to reorder A and B if we know there isn't a dependence from A - // to B. Note that this determination is conservative since some - // dependences could potentially be reordered safely. - - // A is potentially the source of a dependence. - auto *Src = A->first; - auto SrcDes = A->second; - - // B is potentially the sink of a dependence. - auto *Sink = B->first; - auto SinkDes = B->second; - - // Code motion for interleaved accesses can't violate WAR dependences. - // Thus, reordering is legal if the source isn't a write. - if (!Src->mayWriteToMemory()) - return true; - - // At least one of the accesses must be strided. - if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride)) - return true; - - // If dependence information is not available from LoopAccessInfo, - // conservatively assume the instructions can't be reordered. - if (!areDependencesValid()) - return false; - - // If we know there is a dependence from source to sink, assume the - // instructions can't be reordered. Otherwise, reordering is legal. - return !Dependences.count(Src) || !Dependences.lookup(Src).count(Sink); - } - - /// Collect the dependences from LoopAccessInfo. - /// - /// We process the dependences once during the interleaved access analysis to - /// enable constant-time dependence queries. - void collectDependences() { - if (!areDependencesValid()) - return; - auto *Deps = LAI->getDepChecker().getDependences(); - for (auto Dep : *Deps) - Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI)); - } -}; - -} // end anonymous namespace - -static void emitMissedWarning(Function *F, Loop *L, - const LoopVectorizeHints &LH, - OptimizationRemarkEmitter *ORE) { - LH.emitRemarkWithHints(); - - if (LH.getForce() == LoopVectorizeHints::FK_Enabled) { - if (LH.getWidth() != 1) - ORE->emit(DiagnosticInfoOptimizationFailure( - DEBUG_TYPE, "FailedRequestedVectorization", - L->getStartLoc(), L->getHeader()) - << "loop not vectorized: " - << "failed explicitly specified loop vectorization"); - else if (LH.getInterleave() != 1) - ORE->emit(DiagnosticInfoOptimizationFailure( - DEBUG_TYPE, "FailedRequestedInterleaving", L->getStartLoc(), - L->getHeader()) - << "loop not interleaved: " - << "failed explicitly specified loop interleaving"); - } -} - -namespace llvm { - /// LoopVectorizationCostModel - estimates the expected speedups due to /// vectorization. /// In many cases vectorization is not profitable. This can happen because of @@ -1247,34 +900,55 @@ public: /// vectorization factor \p VF. bool isProfitableToScalarize(Instruction *I, unsigned VF) const { assert(VF > 1 && "Profitable to scalarize relevant only for VF > 1."); + + // Cost model is not run in the VPlan-native path - return conservative + // result until this changes. 
+ if (EnableVPlanNativePath) + return false; + auto Scalars = InstsToScalarize.find(VF); assert(Scalars != InstsToScalarize.end() && "VF not yet analyzed for scalarization profitability"); - return Scalars->second.count(I); + return Scalars->second.find(I) != Scalars->second.end(); } /// Returns true if \p I is known to be uniform after vectorization. bool isUniformAfterVectorization(Instruction *I, unsigned VF) const { if (VF == 1) return true; - assert(Uniforms.count(VF) && "VF not yet analyzed for uniformity"); + + // Cost model is not run in the VPlan-native path - return conservative + // result until this changes. + if (EnableVPlanNativePath) + return false; + auto UniformsPerVF = Uniforms.find(VF); - return UniformsPerVF->second.count(I); + assert(UniformsPerVF != Uniforms.end() && + "VF not yet analyzed for uniformity"); + return UniformsPerVF->second.find(I) != UniformsPerVF->second.end(); } /// Returns true if \p I is known to be scalar after vectorization. bool isScalarAfterVectorization(Instruction *I, unsigned VF) const { if (VF == 1) return true; - assert(Scalars.count(VF) && "Scalar values are not calculated for VF"); + + // Cost model is not run in the VPlan-native path - return conservative + // result until this changes. + if (EnableVPlanNativePath) + return false; + auto ScalarsPerVF = Scalars.find(VF); - return ScalarsPerVF->second.count(I); + assert(ScalarsPerVF != Scalars.end() && + "Scalar values are not calculated for VF"); + return ScalarsPerVF->second.find(I) != ScalarsPerVF->second.end(); } /// \returns True if instruction \p I can be truncated to a smaller bitwidth /// for vectorization factor \p VF. bool canTruncateToMinimalBitwidth(Instruction *I, unsigned VF) const { - return VF > 1 && MinBWs.count(I) && !isProfitableToScalarize(I, VF) && + return VF > 1 && MinBWs.find(I) != MinBWs.end() && + !isProfitableToScalarize(I, VF) && !isScalarAfterVectorization(I, VF); } @@ -1298,7 +972,7 @@ public: /// Save vectorization decision \p W and \p Cost taken by the cost model for /// interleaving group \p Grp and vector width \p VF. - void setWideningDecision(const InterleaveGroup *Grp, unsigned VF, + void setWideningDecision(const InterleaveGroup<Instruction> *Grp, unsigned VF, InstWidening W, unsigned Cost) { assert(VF >= 2 && "Expected VF >=2"); /// Broadcast this decicion to all instructions inside the group. @@ -1318,6 +992,12 @@ public: /// through the cost modeling. InstWidening getWideningDecision(Instruction *I, unsigned VF) { assert(VF >= 2 && "Expected VF >=2"); + + // Cost model is not run in the VPlan-native path - return conservative + // result until this changes. + if (EnableVPlanNativePath) + return CM_GatherScatter; + std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF); auto Itr = WideningDecisions.find(InstOnVF); if (Itr == WideningDecisions.end()) @@ -1330,7 +1010,8 @@ public: unsigned getWideningCost(Instruction *I, unsigned VF) { assert(VF >= 2 && "Expected VF >=2"); std::pair<Instruction *, unsigned> InstOnVF = std::make_pair(I, VF); - assert(WideningDecisions.count(InstOnVF) && "The cost is not calculated"); + assert(WideningDecisions.find(InstOnVF) != WideningDecisions.end() && + "The cost is not calculated"); return WideningDecisions[InstOnVF].second; } @@ -1369,7 +1050,7 @@ public: /// that may be vectorized as interleave, gather-scatter or scalarized. void collectUniformsAndScalars(unsigned VF) { // Do the analysis once. 
- if (VF == 1 || Uniforms.count(VF)) + if (VF == 1 || Uniforms.find(VF) != Uniforms.end()) return; setCostBasedWideningDecision(VF); collectLoopUniforms(VF); @@ -1414,26 +1095,58 @@ public: /// Returns true if \p I is an instruction that will be scalarized with /// predication. Such instructions include conditional stores and /// instructions that may divide by zero. - bool isScalarWithPredication(Instruction *I); + /// If a non-zero VF has been calculated, we check if I will be scalarized + /// predication for that VF. + bool isScalarWithPredication(Instruction *I, unsigned VF = 1); + + // Returns true if \p I is an instruction that will be predicated either + // through scalar predication or masked load/store or masked gather/scatter. + // Superset of instructions that return true for isScalarWithPredication. + bool isPredicatedInst(Instruction *I) { + if (!blockNeedsPredication(I->getParent())) + return false; + // Loads and stores that need some form of masked operation are predicated + // instructions. + if (isa<LoadInst>(I) || isa<StoreInst>(I)) + return Legal->isMaskRequired(I); + return isScalarWithPredication(I); + } /// Returns true if \p I is a memory instruction with consecutive memory /// access that can be widened. bool memoryInstructionCanBeWidened(Instruction *I, unsigned VF = 1); + /// Returns true if \p I is a memory instruction in an interleaved-group + /// of memory accesses that can be vectorized with wide vector loads/stores + /// and shuffles. + bool interleavedAccessCanBeWidened(Instruction *I, unsigned VF = 1); + /// Check if \p Instr belongs to any interleaved access group. bool isAccessInterleaved(Instruction *Instr) { return InterleaveInfo.isInterleaved(Instr); } /// Get the interleaved access group that \p Instr belongs to. - const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { + const InterleaveGroup<Instruction> * + getInterleavedAccessGroup(Instruction *Instr) { return InterleaveInfo.getInterleaveGroup(Instr); } /// Returns true if an interleaved group requires a scalar iteration - /// to handle accesses with gaps. + /// to handle accesses with gaps, and there is nothing preventing us from + /// creating a scalar epilogue. bool requiresScalarEpilogue() const { - return InterleaveInfo.requiresScalarEpilogue(); + return IsScalarEpilogueAllowed && InterleaveInfo.requiresScalarEpilogue(); + } + + /// Returns true if a scalar epilogue is not allowed due to optsize. + bool isScalarEpilogueAllowed() const { return IsScalarEpilogueAllowed; } + + /// Returns true if all loop blocks should be masked to fold tail loop. + bool foldTailByMasking() const { return FoldTailByMasking; } + + bool blockNeedsPredication(BasicBlock *BB) { + return foldTailByMasking() || Legal->blockNeedsPredication(BB); } private: @@ -1482,8 +1195,10 @@ private: /// memory access. unsigned getConsecutiveMemOpCost(Instruction *I, unsigned VF); - /// The cost calculation for Load instruction \p I with uniform pointer - - /// scalar load + broadcast. + /// The cost calculation for Load/Store instruction \p I with uniform pointer - + /// Load: scalar load + broadcast. + /// Store: scalar store + (loop invariant value stored? 0 : extract of last + /// element) unsigned getUniformMemOpCost(Instruction *I, unsigned VF); /// Returns whether the instruction is a load or store and will be a emitted @@ -1517,6 +1232,18 @@ private: /// vectorization as a predicated block. 
SmallPtrSet<BasicBlock *, 4> PredicatedBBsAfterVectorization; + /// Records whether it is allowed to have the original scalar loop execute at + /// least once. This may be needed as a fallback loop in case runtime + /// aliasing/dependence checks fail, or to handle the tail/remainder + /// iterations when the trip count is unknown or doesn't divide by the VF, + /// or as a peel-loop to handle gaps in interleave-groups. + /// Under optsize and when the trip count is very small we don't allow any + /// iterations to execute in the scalar loop. + bool IsScalarEpilogueAllowed = true; + + /// All blocks of loop are to be masked to fold tail of scalar iterations. + bool FoldTailByMasking = false; + /// A map holding scalar costs for different vectorization factors. The /// presence of a cost for an instruction in the mapping indicates that the /// instruction will be scalarized when vectorizing with the associated @@ -1639,14 +1366,15 @@ static bool isExplicitVecOuterLoop(Loop *OuterLp, return false; Function *Fn = OuterLp->getHeader()->getParent(); - if (!Hints.allowVectorization(Fn, OuterLp, false /*AlwaysVectorize*/)) { + if (!Hints.allowVectorization(Fn, OuterLp, + true /*VectorizeOnlyWhenForced*/)) { LLVM_DEBUG(dbgs() << "LV: Loop hints prevent outer loop vectorization.\n"); return false; } if (!Hints.getWidth()) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: No user vector width.\n"); - emitMissedWarning(Fn, OuterLp, Hints, ORE); + Hints.emitRemarkWithHints(); return false; } @@ -1654,7 +1382,7 @@ static bool isExplicitVecOuterLoop(Loop *OuterLp, // TODO: Interleave support is future work. LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Interleave is not supported for " "outer loops.\n"); - emitMissedWarning(Fn, OuterLp, Hints, ORE); + Hints.emitRemarkWithHints(); return false; } @@ -1695,10 +1423,11 @@ struct LoopVectorize : public FunctionPass { LoopVectorizePass Impl; - explicit LoopVectorize(bool NoUnrolling = false, bool AlwaysVectorize = true) + explicit LoopVectorize(bool InterleaveOnlyWhenForced = false, + bool VectorizeOnlyWhenForced = false) : FunctionPass(ID) { - Impl.DisableUnrolling = NoUnrolling; - Impl.AlwaysVectorize = AlwaysVectorize; + Impl.InterleaveOnlyWhenForced = InterleaveOnlyWhenForced; + Impl.VectorizeOnlyWhenForced = VectorizeOnlyWhenForced; initializeLoopVectorizePass(*PassRegistry::getPassRegistry()); } @@ -1737,8 +1466,16 @@ struct LoopVectorize : public FunctionPass { AU.addRequired<LoopAccessLegacyAnalysis>(); AU.addRequired<DemandedBitsWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); - AU.addPreserved<LoopInfoWrapperPass>(); - AU.addPreserved<DominatorTreeWrapperPass>(); + + // We currently do not preserve loopinfo/dominator analyses with outer loop + // vectorization. Until this is addressed, mark these analyses as preserved + // only for non-VPlan-native path. + // TODO: Preserve Loop and Dominator analyses for VPlan-native path. + if (!EnableVPlanNativePath) { + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + } + AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } @@ -1950,7 +1687,7 @@ void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { ? 
Builder.CreateSExtOrTrunc(Induction, IV->getType()) : Builder.CreateCast(Instruction::SIToFP, Induction, IV->getType()); - ScalarIV = ID.transform(Builder, ScalarIV, PSE.getSE(), DL); + ScalarIV = emitTransformedIndex(Builder, ScalarIV, PSE.getSE(), DL, ID); ScalarIV->setName("offset.idx"); } if (Trunc) { @@ -2089,8 +1826,9 @@ Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { assert(!V->getType()->isVectorTy() && "Can't widen a vector"); assert(!V->getType()->isVoidTy() && "Type does not produce a value"); - // If we have a stride that is replaced by one, do it here. - if (Legal->hasStride(V)) + // If we have a stride that is replaced by one, do it here. Defer this for + // the VPlan-native path until we start running Legal checks in that path. + if (!EnableVPlanNativePath && Legal->hasStride(V)) V = ConstantInt::get(V->getType(), 1); // If we have a vector mapped to this value, return it. @@ -2214,6 +1952,17 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) { "reverse"); } +// Return whether we allow using masked interleave-groups (for dealing with +// strided loads/stores that reside in predicated blocks, or for dealing +// with gaps). +static bool useMaskedInterleavedAccesses(const TargetTransformInfo &TTI) { + // If an override option has been passed in for interleaved accesses, use it. + if (EnableMaskedInterleavedMemAccesses.getNumOccurrences() > 0) + return EnableMaskedInterleavedMemAccesses; + + return TTI.enableMaskedInterleavedAccessVectorization(); +} + // Try to vectorize the interleave group that \p Instr belongs to. // // E.g. Translate following interleaved load group (factor = 3): @@ -2242,8 +1991,10 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) { // %interleaved.vec = shuffle %R_G.vec, %B_U.vec, // <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11> ; Interleave R,G,B elements // store <12 x i32> %interleaved.vec ; Write 4 tuples of R,G,B -void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { - const InterleaveGroup *Group = Cost->getInterleavedAccessGroup(Instr); +void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr, + VectorParts *BlockInMask) { + const InterleaveGroup<Instruction> *Group = + Cost->getInterleavedAccessGroup(Instr); assert(Group && "Fail to get an interleaved access group."); // Skip if current instruction is not the insert position. @@ -2257,13 +2008,22 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Type *ScalarTy = getMemInstValueType(Instr); unsigned InterleaveFactor = Group->getFactor(); Type *VecTy = VectorType::get(ScalarTy, InterleaveFactor * VF); - Type *PtrTy = VecTy->getPointerTo(getMemInstAddressSpace(Instr)); + Type *PtrTy = VecTy->getPointerTo(getLoadStoreAddressSpace(Instr)); // Prepare for the new pointers. setDebugLocFromInst(Builder, Ptr); SmallVector<Value *, 2> NewPtrs; unsigned Index = Group->getIndex(Instr); + VectorParts Mask; + bool IsMaskForCondRequired = BlockInMask; + if (IsMaskForCondRequired) { + Mask = *BlockInMask; + // TODO: extend the masked interleaved-group support to reversed access. + assert(!Group->isReverse() && "Reversed masked interleave-group " + "not supported."); + } + // If the group is reverse, adjust the index to refer to the last vector lane // instead of the first. 
We adjust the index from the first vector lane, // rather than directly getting the pointer for lane VF - 1, because the @@ -2302,13 +2062,39 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { setDebugLocFromInst(Builder, Instr); Value *UndefVec = UndefValue::get(VecTy); + Value *MaskForGaps = nullptr; + if (Group->requiresScalarEpilogue() && !Cost->isScalarEpilogueAllowed()) { + MaskForGaps = createBitMaskForGaps(Builder, VF, *Group); + assert(MaskForGaps && "Mask for Gaps is required but it is null"); + } + // Vectorize the interleaved load group. if (isa<LoadInst>(Instr)) { // For each unroll part, create a wide load for the group. SmallVector<Value *, 2> NewLoads; for (unsigned Part = 0; Part < UF; Part++) { - auto *NewLoad = Builder.CreateAlignedLoad( - NewPtrs[Part], Group->getAlignment(), "wide.vec"); + Instruction *NewLoad; + if (IsMaskForCondRequired || MaskForGaps) { + assert(useMaskedInterleavedAccesses(*TTI) && + "masked interleaved groups are not allowed."); + Value *GroupMask = MaskForGaps; + if (IsMaskForCondRequired) { + auto *Undefs = UndefValue::get(Mask[Part]->getType()); + auto *RepMask = createReplicatedMask(Builder, InterleaveFactor, VF); + Value *ShuffledMask = Builder.CreateShuffleVector( + Mask[Part], Undefs, RepMask, "interleaved.mask"); + GroupMask = MaskForGaps + ? Builder.CreateBinOp(Instruction::And, ShuffledMask, + MaskForGaps) + : ShuffledMask; + } + NewLoad = + Builder.CreateMaskedLoad(NewPtrs[Part], Group->getAlignment(), + GroupMask, UndefVec, "wide.masked.vec"); + } + else + NewLoad = Builder.CreateAlignedLoad(NewPtrs[Part], + Group->getAlignment(), "wide.vec"); Group->addMetadata(NewLoad); NewLoads.push_back(NewLoad); } @@ -2375,8 +2161,18 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Value *IVec = Builder.CreateShuffleVector(WideVec, UndefVec, IMask, "interleaved.vec"); - Instruction *NewStoreInstr = - Builder.CreateAlignedStore(IVec, NewPtrs[Part], Group->getAlignment()); + Instruction *NewStoreInstr; + if (IsMaskForCondRequired) { + auto *Undefs = UndefValue::get(Mask[Part]->getType()); + auto *RepMask = createReplicatedMask(Builder, InterleaveFactor, VF); + Value *ShuffledMask = Builder.CreateShuffleVector( + Mask[Part], Undefs, RepMask, "interleaved.mask"); + NewStoreInstr = Builder.CreateMaskedStore( + IVec, NewPtrs[Part], Group->getAlignment(), ShuffledMask); + } + else + NewStoreInstr = Builder.CreateAlignedStore(IVec, NewPtrs[Part], + Group->getAlignment()); Group->addMetadata(NewStoreInstr); } @@ -2400,13 +2196,13 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, Type *ScalarDataTy = getMemInstValueType(Instr); Type *DataTy = VectorType::get(ScalarDataTy, VF); Value *Ptr = getLoadStorePointerOperand(Instr); - unsigned Alignment = getMemInstAlignment(Instr); + unsigned Alignment = getLoadStoreAlignment(Instr); // An alignment of 0 means target abi alignment. We need to use the scalar's // target abi alignment in such a case. const DataLayout &DL = Instr->getModule()->getDataLayout(); if (!Alignment) Alignment = DL.getABITypeAlignment(ScalarDataTy); - unsigned AddressSpace = getMemInstAddressSpace(Instr); + unsigned AddressSpace = getLoadStoreAddressSpace(Instr); // Determine if the pointer operand of the access is either consecutive or // reverse consecutive. 
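The masked interleave-group path above replicates the per-iteration block mask across the group members before issuing the wide masked load or store. A minimal sketch of that widening, reusing the same createReplicatedMask helper the hunk calls; the function and variable names here are illustrative only:

#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Sketch only: widen a <VF x i1> block mask to <VF*Factor x i1> so it lines
// up with the wide interleaved vector. With VF = 4 and Factor = 3, the mask
// <m0,m1,m2,m3> becomes <m0,m0,m0, m1,m1,m1, m2,m2,m2, m3,m3,m3>. If a mask
// for gaps (missing group members) is present, it is ANDed in so gap lanes
// stay disabled even when the block executes.
static Value *widenBlockMaskForGroup(IRBuilder<> &Builder, Value *BlockMask,
                                     unsigned Factor, unsigned VF,
                                     Value *MaskForGaps /* may be null */) {
  Value *Undefs = UndefValue::get(BlockMask->getType());
  auto *RepMask = createReplicatedMask(Builder, Factor, VF);
  Value *Widened = Builder.CreateShuffleVector(BlockMask, Undefs, RepMask,
                                               "interleaved.mask");
  return MaskForGaps ? Builder.CreateAnd(Widened, MaskForGaps) : Widened;
}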
@@ -2594,6 +2390,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { if (TripCount) return TripCount; + assert(L && "Create Trip Count for null loop."); IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); // Find the loop boundaries. ScalarEvolution *SE = PSE.getSE(); @@ -2602,6 +2399,7 @@ Value *InnerLoopVectorizer::getOrCreateTripCount(Loop *L) { "Invalid loop count"); Type *IdxTy = Legal->getWidestInductionType(); + assert(IdxTy && "No type for induction"); // The exit count might have the type of i64 while the phi is i32. This can // happen if we have an induction variable that is sign extended before the @@ -2642,12 +2440,26 @@ Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { Value *TC = getOrCreateTripCount(L); IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + Type *Ty = TC->getType(); + Constant *Step = ConstantInt::get(Ty, VF * UF); + + // If the tail is to be folded by masking, round the number of iterations N + // up to a multiple of Step instead of rounding down. This is done by first + // adding Step-1 and then rounding down. Note that it's ok if this addition + // overflows: the vector induction variable will eventually wrap to zero given + // that it starts at zero and its Step is a power of two; the loop will then + // exit, with the last early-exit vector comparison also producing all-true. + if (Cost->foldTailByMasking()) { + assert(isPowerOf2_32(VF * UF) && + "VF*UF must be a power of 2 when folding tail by masking"); + TC = Builder.CreateAdd(TC, ConstantInt::get(Ty, VF * UF - 1), "n.rnd.up"); + } + // Now we need to generate the expression for the part of the loop that the // vectorized body will execute. This is equal to N - (N % Step) if scalar // iterations are not required for correctness, or N - Step, otherwise. Step // is equal to the vectorization factor (number of SIMD elements) times the // unroll factor (number of SIMD instructions). - Constant *Step = ConstantInt::get(TC->getType(), VF * UF); Value *R = Builder.CreateURem(TC, Step, "n.mod.vf"); // If there is a non-reversed interleaved group that may speculatively access @@ -2710,8 +2522,13 @@ void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, // of zero. In this case we will also jump to the scalar loop. auto P = Cost->requiresScalarEpilogue() ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_ULT; - Value *CheckMinIters = Builder.CreateICmp( - P, Count, ConstantInt::get(Count->getType(), VF * UF), "min.iters.check"); + + // If tail is to be folded, vector loop takes care of all iterations. + Value *CheckMinIters = Builder.getFalse(); + if (!Cost->foldTailByMasking()) + CheckMinIters = Builder.CreateICmp( + P, Count, ConstantInt::get(Count->getType(), VF * UF), + "min.iters.check"); BasicBlock *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); // Update dominator tree immediately if the generated block is a @@ -2740,6 +2557,8 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { if (C->isZero()) return; + assert(!Cost->foldTailByMasking() && + "Cannot SCEV check stride or overflow when folding tail"); // Create a new block containing the stride check. BB->setName("vector.scevcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); @@ -2756,6 +2575,10 @@ void InnerLoopVectorizer::emitSCEVChecks(Loop *L, BasicBlock *Bypass) { } void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { + // VPlan-native path does not do any analysis for runtime checks currently. 
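For the tail-folding logic in the getOrCreateVectorTripCount() hunk above, a small scalar illustration of the round-up may help. It only mirrors the arithmetic described in the hunk's comments (bump the trip count by Step - 1, then round down with the existing urem), with Step standing for VF * UF:

#include <cstdint>

// Sketch: with a trip count of 10 and Step = VF * UF = 4, the bump yields 13
// and the round-down yields 12, so two vector iterations run in full and a
// third handles the remaining 2 iterations under a mask. Without folding,
// the result would be 8 plus a scalar remainder loop of 2 iterations.
static uint64_t foldedVectorTripCount(uint64_t TripCount, uint64_t Step) {
  uint64_t RoundedUp = TripCount + (Step - 1); // wrap-around is benign here
  return RoundedUp - RoundedUp % Step;
}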
+ if (EnableVPlanNativePath) + return; + BasicBlock *BB = L->getLoopPreheader(); // Generate the code that checks in runtime if arrays overlap. We put the @@ -2768,6 +2591,7 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { if (!MemRuntimeCheck) return; + assert(!Cost->foldTailByMasking() && "Cannot check memory when folding tail"); // Create a new block containing the memory check. BB->setName("vector.memcheck"); auto *NewBB = BB->splitBasicBlock(BB->getTerminator(), "vector.ph"); @@ -2789,6 +2613,94 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { LVer->prepareNoAliasMetadata(); } +Value *InnerLoopVectorizer::emitTransformedIndex( + IRBuilder<> &B, Value *Index, ScalarEvolution *SE, const DataLayout &DL, + const InductionDescriptor &ID) const { + + SCEVExpander Exp(*SE, DL, "induction"); + auto Step = ID.getStep(); + auto StartValue = ID.getStartValue(); + assert(Index->getType() == Step->getType() && + "Index type does not match StepValue type"); + + // Note: the IR at this point is broken. We cannot use SE to create any new + // SCEV and then expand it, hoping that SCEV's simplification will give us + // a more optimal code. Unfortunately, attempt of doing so on invalid IR may + // lead to various SCEV crashes. So all we can do is to use builder and rely + // on InstCombine for future simplifications. Here we handle some trivial + // cases only. + auto CreateAdd = [&B](Value *X, Value *Y) { + assert(X->getType() == Y->getType() && "Types don't match!"); + if (auto *CX = dyn_cast<ConstantInt>(X)) + if (CX->isZero()) + return Y; + if (auto *CY = dyn_cast<ConstantInt>(Y)) + if (CY->isZero()) + return X; + return B.CreateAdd(X, Y); + }; + + auto CreateMul = [&B](Value *X, Value *Y) { + assert(X->getType() == Y->getType() && "Types don't match!"); + if (auto *CX = dyn_cast<ConstantInt>(X)) + if (CX->isOne()) + return Y; + if (auto *CY = dyn_cast<ConstantInt>(Y)) + if (CY->isOne()) + return X; + return B.CreateMul(X, Y); + }; + + switch (ID.getKind()) { + case InductionDescriptor::IK_IntInduction: { + assert(Index->getType() == StartValue->getType() && + "Index type does not match StartValue type"); + if (ID.getConstIntStepValue() && ID.getConstIntStepValue()->isMinusOne()) + return B.CreateSub(StartValue, Index); + auto *Offset = CreateMul( + Index, Exp.expandCodeFor(Step, Index->getType(), &*B.GetInsertPoint())); + return CreateAdd(StartValue, Offset); + } + case InductionDescriptor::IK_PtrInduction: { + assert(isa<SCEVConstant>(Step) && + "Expected constant step for pointer induction"); + return B.CreateGEP( + nullptr, StartValue, + CreateMul(Index, Exp.expandCodeFor(Step, Index->getType(), + &*B.GetInsertPoint()))); + } + case InductionDescriptor::IK_FpInduction: { + assert(Step->getType()->isFloatingPointTy() && "Expected FP Step value"); + auto InductionBinOp = ID.getInductionBinOp(); + assert(InductionBinOp && + (InductionBinOp->getOpcode() == Instruction::FAdd || + InductionBinOp->getOpcode() == Instruction::FSub) && + "Original bin op should be defined for FP induction"); + + Value *StepValue = cast<SCEVUnknown>(Step)->getValue(); + + // Floating point operations had to be 'fast' to enable the induction. + FastMathFlags Flags; + Flags.setFast(); + + Value *MulExp = B.CreateFMul(StepValue, Index); + if (isa<Instruction>(MulExp)) + // We have to check, the MulExp may be a constant. 
+ cast<Instruction>(MulExp)->setFastMathFlags(Flags); + + Value *BOp = B.CreateBinOp(InductionBinOp->getOpcode(), StartValue, MulExp, + "induction"); + if (isa<Instruction>(BOp)) + cast<Instruction>(BOp)->setFastMathFlags(Flags); + + return BOp; + } + case InductionDescriptor::IK_NoInduction: + return nullptr; + } + llvm_unreachable("invalid enum"); +} + BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { /* In this function we generate a new loop. The new loop will contain @@ -2825,6 +2737,7 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { BasicBlock *OldBasicBlock = OrigLoop->getHeader(); BasicBlock *VectorPH = OrigLoop->getLoopPreheader(); BasicBlock *ExitBlock = OrigLoop->getExitBlock(); + MDNode *OrigLoopID = OrigLoop->getLoopID(); assert(VectorPH && "Invalid loop structure"); assert(ExitBlock && "Must have an exit block"); @@ -2927,7 +2840,7 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { CastInst::getCastOpcode(CountRoundDown, true, StepType, true); Value *CRD = B.CreateCast(CastOp, CountRoundDown, StepType, "cast.crd"); const DataLayout &DL = OrigLoop->getHeader()->getModule()->getDataLayout(); - EndValue = II.transform(B, CRD, PSE.getSE(), DL); + EndValue = emitTransformedIndex(B, CRD, PSE.getSE(), DL, II); EndValue->setName("ind.end"); } @@ -2948,9 +2861,12 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { // Add a check in the middle block to see if we have completed // all of the iterations in the first vector loop. // If (N - N%VF) == N, then we *don't* need to run the remainder. - Value *CmpN = - CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, - CountRoundDown, "cmp.n", MiddleBlock->getTerminator()); + // If tail is to be folded, we know we don't need to run the remainder. + Value *CmpN = Builder.getTrue(); + if (!Cost->foldTailByMasking()) + CmpN = + CmpInst::Create(Instruction::ICmp, CmpInst::ICMP_EQ, Count, + CountRoundDown, "cmp.n", MiddleBlock->getTerminator()); ReplaceInstWithInst(MiddleBlock->getTerminator(), BranchInst::Create(ExitBlock, ScalarPH, CmpN)); @@ -2965,6 +2881,17 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { LoopVectorBody = VecBody; LoopScalarBody = OldBasicBlock; + Optional<MDNode *> VectorizedLoopID = + makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, + LLVMLoopVectorizeFollowupVectorized}); + if (VectorizedLoopID.hasValue()) { + Lp->setLoopID(VectorizedLoopID.getValue()); + + // Do not setAlreadyVectorized if loop attributes have been defined + // explicitly. + return LoopVectorPreHeader; + } + // Keep all loop hints from the original loop on the vector loop (we'll // replace the vectorizer-specific hints below). if (MDNode *LID = OrigLoop->getLoopID()) @@ -3023,7 +2950,7 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, II.getStep()->getType()) : B.CreateSExtOrTrunc(CountMinusOne, II.getStep()->getType()); CMO->setName("cast.cmo"); - Value *Escape = II.transform(B, CMO, PSE.getSE(), DL); + Value *Escape = emitTransformedIndex(B, CMO, PSE.getSE(), DL, II); Escape->setName("ind.escape"); MissingVals[UI] = Escape; } @@ -3109,6 +3036,10 @@ static unsigned getScalarizationOverhead(Instruction *I, unsigned VF, !TTI.supportsEfficientVectorElementLoadStore())) Cost += TTI.getScalarizationOverhead(RetTy, true, false); + // Some targets keep addresses scalar. 
+ if (isa<LoadInst>(I) && !TTI.prefersVectorizedAddressing()) + return Cost; + if (CallInst *CI = dyn_cast<CallInst>(I)) { SmallVector<const Value *, 4> Operands(CI->arg_operands()); Cost += TTI.getOperandsScalarizationOverhead(Operands, VF); @@ -3212,7 +3143,8 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { continue; for (unsigned Part = 0; Part < UF; ++Part) { Value *I = getOrCreateVectorValue(KV.first, Part); - if (Erased.count(I) || I->use_empty() || !isa<Instruction>(I)) + if (Erased.find(I) != Erased.end() || I->use_empty() || + !isa<Instruction>(I)) continue; Type *OriginalTy = I->getType(); Type *ScalarTruncatedTy = @@ -3330,6 +3262,13 @@ void InnerLoopVectorizer::fixVectorizedLoop() { if (VF > 1) truncateToMinimalBitwidths(); + // Fix widened non-induction PHIs by setting up the PHI operands. + if (OrigPHIsToFix.size()) { + assert(EnableVPlanNativePath && + "Unexpected non-induction PHIs for fixup in non VPlan-native path"); + fixNonInductionPHIs(); + } + // At this point every instruction in the original loop is widened to a // vector form. Now we need to fix the recurrences in the loop. These PHI // nodes are currently empty because we did not want to introduce cycles. @@ -3666,8 +3605,8 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxPart, ReducedPartRdx, "bin.rdx")); else - ReducedPartRdx = RecurrenceDescriptor::createMinMaxOp( - Builder, MinMaxKind, ReducedPartRdx, RdxPart); + ReducedPartRdx = createMinMaxOp(Builder, MinMaxKind, ReducedPartRdx, + RdxPart); } if (VF > 1) { @@ -3720,9 +3659,20 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { void InnerLoopVectorizer::fixLCSSAPHIs() { for (PHINode &LCSSAPhi : LoopExitBlock->phis()) { if (LCSSAPhi.getNumIncomingValues() == 1) { - assert(OrigLoop->isLoopInvariant(LCSSAPhi.getIncomingValue(0)) && - "Incoming value isn't loop invariant"); - LCSSAPhi.addIncoming(LCSSAPhi.getIncomingValue(0), LoopMiddleBlock); + auto *IncomingValue = LCSSAPhi.getIncomingValue(0); + // Non-instruction incoming values will have only one value. + unsigned LastLane = 0; + if (isa<Instruction>(IncomingValue)) + LastLane = Cost->isUniformAfterVectorization( + cast<Instruction>(IncomingValue), VF) + ? 0 + : VF - 1; + // Can be a loop invariant incoming value or the last scalar value to be + // extracted from the vectorized loop. + Builder.SetInsertPoint(LoopMiddleBlock->getTerminator()); + Value *lastIncomingValue = + getOrCreateScalarValue(IncomingValue, { UF - 1, LastLane }); + LCSSAPhi.addIncoming(lastIncomingValue, LoopMiddleBlock); } } } @@ -3791,12 +3741,62 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { } while (Changed); } +void InnerLoopVectorizer::fixNonInductionPHIs() { + for (PHINode *OrigPhi : OrigPHIsToFix) { + PHINode *NewPhi = + cast<PHINode>(VectorLoopValueMap.getVectorValue(OrigPhi, 0)); + unsigned NumIncomingValues = OrigPhi->getNumIncomingValues(); + + SmallVector<BasicBlock *, 2> ScalarBBPredecessors( + predecessors(OrigPhi->getParent())); + SmallVector<BasicBlock *, 2> VectorBBPredecessors( + predecessors(NewPhi->getParent())); + assert(ScalarBBPredecessors.size() == VectorBBPredecessors.size() && + "Scalar and Vector BB should have the same number of predecessors"); + + // The insertion point in Builder may be invalidated by the time we get + // here. Force the Builder insertion point to something valid so that we do + // not run into issues during insertion point restore in + // getOrCreateVectorValue calls below. 
+ Builder.SetInsertPoint(NewPhi); + + // The predecessor order is preserved and we can rely on mapping between + // scalar and vector block predecessors. + for (unsigned i = 0; i < NumIncomingValues; ++i) { + BasicBlock *NewPredBB = VectorBBPredecessors[i]; + + // When looking up the new scalar/vector values to fix up, use incoming + // values from original phi. + Value *ScIncV = + OrigPhi->getIncomingValueForBlock(ScalarBBPredecessors[i]); + + // Scalar incoming value may need a broadcast + Value *NewIncV = getOrCreateVectorValue(ScIncV, 0); + NewPhi->addIncoming(NewIncV, NewPredBB); + } + } +} + void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF) { + PHINode *P = cast<PHINode>(PN); + if (EnableVPlanNativePath) { + // Currently we enter here in the VPlan-native path for non-induction + // PHIs where all control flow is uniform. We simply widen these PHIs. + // Create a vector phi with no operands - the vector phi operands will be + // set at the end of vector code generation. + Type *VecTy = + (VF == 1) ? PN->getType() : VectorType::get(PN->getType(), VF); + Value *VecPhi = Builder.CreatePHI(VecTy, PN->getNumOperands(), "vec.phi"); + VectorLoopValueMap.setVectorValue(P, 0, VecPhi); + OrigPHIsToFix.push_back(P); + + return; + } + assert(PN->getParent() == OrigLoop->getHeader() && "Non-header phis should have been handled elsewhere"); - PHINode *P = cast<PHINode>(PN); // In order to support recurrences we need to be able to vectorize Phi nodes. // Phi nodes have cycles, so we need to vectorize them in two stages. This is // stage #1: We create a new vector PHI node with no incoming edges. We'll use @@ -3846,7 +3846,8 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, for (unsigned Lane = 0; Lane < Lanes; ++Lane) { Constant *Idx = ConstantInt::get(PtrInd->getType(), Lane + Part * VF); Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); - Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); + Value *SclrGep = + emitTransformedIndex(Builder, GlobalIdx, PSE.getSE(), DL, II); SclrGep->setName("next.gep"); VectorLoopValueMap.setScalarValue(P, {Part, Lane}, SclrGep); } @@ -4151,6 +4152,10 @@ void InnerLoopVectorizer::updateAnalysis() { // Forget the original basic block. PSE.getSE()->forgetLoop(OrigLoop); + // DT is not kept up-to-date for outer loop vectorization + if (EnableVPlanNativePath) + return; + // Update the dominator tree information. assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && "Entry does not dominate exit."); @@ -4167,7 +4172,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // We should not collect Scalars more than once per VF. Right now, this // function is called from collectUniformsAndScalars(), which already does // this check. Collecting Scalars for VF=1 does not make any sense. - assert(VF >= 2 && !Scalars.count(VF) && + assert(VF >= 2 && Scalars.find(VF) == Scalars.end() && "This function should not be visited twice for the same VF"); SmallSetVector<Instruction *, 8> Worklist; @@ -4253,7 +4258,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { } } for (auto *I : ScalarPtrs) - if (!PossibleNonScalarPtrs.count(I)) { + if (PossibleNonScalarPtrs.find(I) == PossibleNonScalarPtrs.end()) { LLVM_DEBUG(dbgs() << "LV: Found scalar instruction: " << *I << "\n"); Worklist.insert(I); } @@ -4279,8 +4284,9 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // Insert the forced scalars. 
// FIXME: Currently widenPHIInstruction() often creates a dead vector // induction variable when the PHI user is scalarized. - if (ForcedScalars.count(VF)) - for (auto *I : ForcedScalars.find(VF)->second) + auto ForcedScalar = ForcedScalars.find(VF); + if (ForcedScalar != ForcedScalars.end()) + for (auto *I : ForcedScalar->second) Worklist.insert(I); // Expand the worklist by looking through any bitcasts and getelementptr @@ -4348,8 +4354,8 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { Scalars[VF].insert(Worklist.begin(), Worklist.end()); } -bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I) { - if (!Legal->blockNeedsPredication(I->getParent())) +bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I, unsigned VF) { + if (!blockNeedsPredication(I->getParent())) return false; switch(I->getOpcode()) { default: @@ -4360,6 +4366,14 @@ bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I) { return false; auto *Ptr = getLoadStorePointerOperand(I); auto *Ty = getMemInstValueType(I); + // We have already decided how to vectorize this instruction, get that + // result. + if (VF > 1) { + InstWidening WideningDecision = getWideningDecision(I, VF); + assert(WideningDecision != CM_Unknown && + "Widening decision should be ready at this moment"); + return WideningDecision == CM_Scalarize; + } return isa<LoadInst>(I) ? !(isLegalMaskedLoad(Ty, Ptr) || isLegalMaskedGather(Ty)) : !(isLegalMaskedStore(Ty, Ptr) || isLegalMaskedScatter(Ty)); @@ -4373,6 +4387,35 @@ bool LoopVectorizationCostModel::isScalarWithPredication(Instruction *I) { return false; } +bool LoopVectorizationCostModel::interleavedAccessCanBeWidened(Instruction *I, + unsigned VF) { + assert(isAccessInterleaved(I) && "Expecting interleaved access."); + assert(getWideningDecision(I, VF) == CM_Unknown && + "Decision should not be set yet."); + auto *Group = getInterleavedAccessGroup(I); + assert(Group && "Must have a group."); + + // Check if masking is required. + // A Group may need masking for one of two reasons: it resides in a block that + // needs predication, or it was decided to use masking to deal with gaps. + bool PredicatedAccessRequiresMasking = + Legal->blockNeedsPredication(I->getParent()) && Legal->isMaskRequired(I); + bool AccessWithGapsRequiresMasking = + Group->requiresScalarEpilogue() && !IsScalarEpilogueAllowed; + if (!PredicatedAccessRequiresMasking && !AccessWithGapsRequiresMasking) + return true; + + // If masked interleaving is required, we expect that the user/target had + // enabled it, because otherwise it either wouldn't have been created or + // it should have been invalidated by the CostModel. + assert(useMaskedInterleavedAccesses(TTI) && + "Masked interleave-groups for predicated accesses are not enabled."); + + auto *Ty = getMemInstValueType(I); + return isa<LoadInst>(I) ? TTI.isLegalMaskedLoad(Ty) + : TTI.isLegalMaskedStore(Ty); +} + bool LoopVectorizationCostModel::memoryInstructionCanBeWidened(Instruction *I, unsigned VF) { // Get and ensure we have a valid memory instruction. @@ -4407,7 +4450,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // already does this check. Collecting Uniforms for VF=1 does not make any // sense. - assert(VF >= 2 && !Uniforms.count(VF) && + assert(VF >= 2 && Uniforms.find(VF) == Uniforms.end() && "This function should not be visited twice for the same VF"); // Visit the list of Uniforms. 
If we'll not find any uniform value, we'll @@ -4494,26 +4537,33 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // Add to the Worklist all consecutive and consecutive-like pointers that // aren't also identified as possibly non-uniform. for (auto *V : ConsecutiveLikePtrs) - if (!PossibleNonUniformPtrs.count(V)) { + if (PossibleNonUniformPtrs.find(V) == PossibleNonUniformPtrs.end()) { LLVM_DEBUG(dbgs() << "LV: Found uniform instruction: " << *V << "\n"); Worklist.insert(V); } // Expand Worklist in topological order: whenever a new instruction - // is added , its users should be either already inside Worklist, or - // out of scope. It ensures a uniform instruction will only be used - // by uniform instructions or out of scope instructions. + // is added , its users should be already inside Worklist. It ensures + // a uniform instruction will only be used by uniform instructions. unsigned idx = 0; while (idx != Worklist.size()) { Instruction *I = Worklist[idx++]; for (auto OV : I->operand_values()) { + // isOutOfScope operands cannot be uniform instructions. if (isOutOfScope(OV)) continue; + // First order recurrence Phi's should typically be considered + // non-uniform. + auto *OP = dyn_cast<PHINode>(OV); + if (OP && Legal->isFirstOrderRecurrence(OP)) + continue; + // If all the users of the operand are uniform, then add the + // operand into the uniform worklist. auto *OI = cast<Instruction>(OV); if (llvm::all_of(OI->users(), [&](User *U) -> bool { auto *J = cast<Instruction>(U); - return !TheLoop->contains(J) || Worklist.count(J) || + return Worklist.count(J) || (OI == getLoadStorePointerOperand(J) && isUniformDecision(J, VF)); })) { @@ -4571,318 +4621,6 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { Uniforms[VF].insert(Worklist.begin(), Worklist.end()); } -void InterleavedAccessInfo::collectConstStrideAccesses( - MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, - const ValueToValueMap &Strides) { - auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); - - // Since it's desired that the load/store instructions be maintained in - // "program order" for the interleaved access analysis, we have to visit the - // blocks in the loop in reverse postorder (i.e., in a topological order). - // Such an ordering will ensure that any load/store that may be executed - // before a second load/store will precede the second load/store in - // AccessStrideInfo. - LoopBlocksDFS DFS(TheLoop); - DFS.perform(LI); - for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) - for (auto &I : *BB) { - auto *LI = dyn_cast<LoadInst>(&I); - auto *SI = dyn_cast<StoreInst>(&I); - if (!LI && !SI) - continue; - - Value *Ptr = getLoadStorePointerOperand(&I); - // We don't check wrapping here because we don't know yet if Ptr will be - // part of a full group or a group with gaps. Checking wrapping for all - // pointers (even those that end up in groups with no gaps) will be overly - // conservative. For full groups, wrapping should be ok since if we would - // wrap around the address space we would do a memory access at nullptr - // even without the transformation. The wrapping checks are therefore - // deferred until after we've formed the interleaved groups. 
- int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, - /*Assume=*/true, /*ShouldCheckWrap=*/false); - - const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); - PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType()); - uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); - - // An alignment of 0 means target ABI alignment. - unsigned Align = getMemInstAlignment(&I); - if (!Align) - Align = DL.getABITypeAlignment(PtrTy->getElementType()); - - AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, Align); - } -} - -// Analyze interleaved accesses and collect them into interleaved load and -// store groups. -// -// When generating code for an interleaved load group, we effectively hoist all -// loads in the group to the location of the first load in program order. When -// generating code for an interleaved store group, we sink all stores to the -// location of the last store. This code motion can change the order of load -// and store instructions and may break dependences. -// -// The code generation strategy mentioned above ensures that we won't violate -// any write-after-read (WAR) dependences. -// -// E.g., for the WAR dependence: a = A[i]; // (1) -// A[i] = b; // (2) -// -// The store group of (2) is always inserted at or below (2), and the load -// group of (1) is always inserted at or above (1). Thus, the instructions will -// never be reordered. All other dependences are checked to ensure the -// correctness of the instruction reordering. -// -// The algorithm visits all memory accesses in the loop in bottom-up program -// order. Program order is established by traversing the blocks in the loop in -// reverse postorder when collecting the accesses. -// -// We visit the memory accesses in bottom-up order because it can simplify the -// construction of store groups in the presence of write-after-write (WAW) -// dependences. -// -// E.g., for the WAW dependence: A[i] = a; // (1) -// A[i] = b; // (2) -// A[i + 1] = c; // (3) -// -// We will first create a store group with (3) and (2). (1) can't be added to -// this group because it and (2) are dependent. However, (1) can be grouped -// with other accesses that may precede it in program order. Note that a -// bottom-up order does not imply that WAW dependences should not be checked. -void InterleavedAccessInfo::analyzeInterleaving() { - LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); - const ValueToValueMap &Strides = LAI->getSymbolicStrides(); - - // Holds all accesses with a constant stride. - MapVector<Instruction *, StrideDescriptor> AccessStrideInfo; - collectConstStrideAccesses(AccessStrideInfo, Strides); - - if (AccessStrideInfo.empty()) - return; - - // Collect the dependences in the loop. - collectDependences(); - - // Holds all interleaved store groups temporarily. - SmallSetVector<InterleaveGroup *, 4> StoreGroups; - // Holds all interleaved load groups temporarily. - SmallSetVector<InterleaveGroup *, 4> LoadGroups; - - // Search in bottom-up program order for pairs of accesses (A and B) that can - // form interleaved load or store groups. In the algorithm below, access A - // precedes access B in program order. We initialize a group for B in the - // outer loop of the algorithm, and then in the inner loop, we attempt to - // insert each A into B's group if: - // - // 1. A and B have the same stride, - // 2. A and B have the same memory object size, and - // 3. A belongs in B's group according to its distance from B. 
- // - // Special care is taken to ensure group formation will not break any - // dependences. - for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend(); - BI != E; ++BI) { - Instruction *B = BI->first; - StrideDescriptor DesB = BI->second; - - // Initialize a group for B if it has an allowable stride. Even if we don't - // create a group for B, we continue with the bottom-up algorithm to ensure - // we don't break any of B's dependences. - InterleaveGroup *Group = nullptr; - if (isStrided(DesB.Stride)) { - Group = getInterleaveGroup(B); - if (!Group) { - LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B - << '\n'); - Group = createInterleaveGroup(B, DesB.Stride, DesB.Align); - } - if (B->mayWriteToMemory()) - StoreGroups.insert(Group); - else - LoadGroups.insert(Group); - } - - for (auto AI = std::next(BI); AI != E; ++AI) { - Instruction *A = AI->first; - StrideDescriptor DesA = AI->second; - - // Our code motion strategy implies that we can't have dependences - // between accesses in an interleaved group and other accesses located - // between the first and last member of the group. Note that this also - // means that a group can't have more than one member at a given offset. - // The accesses in a group can have dependences with other accesses, but - // we must ensure we don't extend the boundaries of the group such that - // we encompass those dependent accesses. - // - // For example, assume we have the sequence of accesses shown below in a - // stride-2 loop: - // - // (1, 2) is a group | A[i] = a; // (1) - // | A[i-1] = b; // (2) | - // A[i-3] = c; // (3) - // A[i] = d; // (4) | (2, 4) is not a group - // - // Because accesses (2) and (3) are dependent, we can group (2) with (1) - // but not with (4). If we did, the dependent access (3) would be within - // the boundaries of the (2, 4) group. - if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) { - // If a dependence exists and A is already in a group, we know that A - // must be a store since A precedes B and WAR dependences are allowed. - // Thus, A would be sunk below B. We release A's group to prevent this - // illegal code motion. A will then be free to form another group with - // instructions that precede it. - if (isInterleaved(A)) { - InterleaveGroup *StoreGroup = getInterleaveGroup(A); - StoreGroups.remove(StoreGroup); - releaseGroup(StoreGroup); - } - - // If a dependence exists and A is not already in a group (or it was - // and we just released it), B might be hoisted above A (if B is a - // load) or another store might be sunk below A (if B is a store). In - // either case, we can't add additional instructions to B's group. B - // will only form a group with instructions that it precedes. - break; - } - - // At this point, we've checked for illegal code motion. If either A or B - // isn't strided, there's nothing left to do. - if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride)) - continue; - - // Ignore A if it's already in a group or isn't the same kind of memory - // operation as B. - // Note that mayReadFromMemory() isn't mutually exclusive to mayWriteToMemory - // in the case of atomic loads. We shouldn't see those here, canVectorizeMemory() - // should have returned false - except for the case we asked for optimization - // remarks. - if (isInterleaved(A) || (A->mayReadFromMemory() != B->mayReadFromMemory()) - || (A->mayWriteToMemory() != B->mayWriteToMemory())) - continue; - - // Check rules 1 and 2. Ignore A if its stride or size is different from - // that of B. 
- if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size) - continue; - - // Ignore A if the memory object of A and B don't belong to the same - // address space - if (getMemInstAddressSpace(A) != getMemInstAddressSpace(B)) - continue; - - // Calculate the distance from A to B. - const SCEVConstant *DistToB = dyn_cast<SCEVConstant>( - PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); - if (!DistToB) - continue; - int64_t DistanceToB = DistToB->getAPInt().getSExtValue(); - - // Check rule 3. Ignore A if its distance to B is not a multiple of the - // size. - if (DistanceToB % static_cast<int64_t>(DesB.Size)) - continue; - - // Ignore A if either A or B is in a predicated block. Although we - // currently prevent group formation for predicated accesses, we may be - // able to relax this limitation in the future once we handle more - // complicated blocks. - if (isPredicated(A->getParent()) || isPredicated(B->getParent())) - continue; - - // The index of A is the index of B plus A's distance to B in multiples - // of the size. - int IndexA = - Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size); - - // Try to insert A into B's group. - if (Group->insertMember(A, IndexA, DesA.Align)) { - LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' - << " into the interleave group with" << *B - << '\n'); - InterleaveGroupMap[A] = Group; - - // Set the first load in program order as the insert position. - if (A->mayReadFromMemory()) - Group->setInsertPos(A); - } - } // Iteration over A accesses. - } // Iteration over B accesses. - - // Remove interleaved store groups with gaps. - for (InterleaveGroup *Group : StoreGroups) - if (Group->getNumMembers() != Group->getFactor()) { - LLVM_DEBUG( - dbgs() << "LV: Invalidate candidate interleaved store group due " - "to gaps.\n"); - releaseGroup(Group); - } - // Remove interleaved groups with gaps (currently only loads) whose memory - // accesses may wrap around. We have to revisit the getPtrStride analysis, - // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does - // not check wrapping (see documentation there). - // FORNOW we use Assume=false; - // TODO: Change to Assume=true but making sure we don't exceed the threshold - // of runtime SCEV assumptions checks (thereby potentially failing to - // vectorize altogether). - // Additional optional optimizations: - // TODO: If we are peeling the loop and we know that the first pointer doesn't - // wrap then we can deduce that all pointers in the group don't wrap. - // This means that we can forcefully peel the loop in order to only have to - // check the first pointer for no-wrap. When we'll change to use Assume=true - // we'll only need at most one runtime check per interleaved group. - for (InterleaveGroup *Group : LoadGroups) { - // Case 1: A full group. Can Skip the checks; For full groups, if the wide - // load would wrap around the address space we would do a memory access at - // nullptr even without the transformation. - if (Group->getNumMembers() == Group->getFactor()) - continue; - - // Case 2: If first and last members of the group don't wrap this implies - // that all the pointers in the group don't wrap. - // So we check only group member 0 (which is always guaranteed to exist), - // and group member Factor - 1; If the latter doesn't exist we rely on - // peeling (if it is a non-reveresed accsess -- see Case 3). 
- Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0)); - if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false, - /*ShouldCheckWrap=*/true)) { - LLVM_DEBUG( - dbgs() << "LV: Invalidate candidate interleaved group due to " - "first group member potentially pointer-wrapping.\n"); - releaseGroup(Group); - continue; - } - Instruction *LastMember = Group->getMember(Group->getFactor() - 1); - if (LastMember) { - Value *LastMemberPtr = getLoadStorePointerOperand(LastMember); - if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false, - /*ShouldCheckWrap=*/true)) { - LLVM_DEBUG( - dbgs() << "LV: Invalidate candidate interleaved group due to " - "last group member potentially pointer-wrapping.\n"); - releaseGroup(Group); - } - } else { - // Case 3: A non-reversed interleaved load group with gaps: We need - // to execute at least one scalar epilogue iteration. This will ensure - // we don't speculatively access memory out-of-bounds. We only need - // to look for a member at index factor - 1, since every group must have - // a member at index zero. - if (Group->isReverse()) { - LLVM_DEBUG( - dbgs() << "LV: Invalidate candidate interleaved group due to " - "a reverse access with gaps.\n"); - releaseGroup(Group); - continue; - } - LLVM_DEBUG( - dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); - RequiresScalarEpilogue = true; - } - } -} - Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) { // TODO: It may by useful to do since it's still likely to be dynamically @@ -4912,39 +4650,78 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { return None; } + if (!PSE.getUnionPredicate().getPredicates().empty()) { + ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") + << "runtime SCEV checks needed. Enable vectorization of this " + "loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz"); + LLVM_DEBUG( + dbgs() + << "LV: Aborting. Runtime SCEV check is required with -Os/-Oz.\n"); + return None; + } + + // FIXME: Avoid specializing for stride==1 instead of bailing out. + if (!Legal->getLAI()->getSymbolicStrides().empty()) { + ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") + << "runtime stride == 1 checks needed. Enable vectorization of " + "this loop with '#pragma clang loop vectorize(enable)' when " + "compiling with -Os/-Oz"); + LLVM_DEBUG( + dbgs() + << "LV: Aborting. Runtime stride check is required with -Os/-Oz.\n"); + return None; + } + // If we optimize the program for size, avoid creating the tail loop. LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); - // If we don't know the precise trip count, don't try to vectorize. - if (TC < 2) { - ORE->emit( - createMissedAnalysis("UnknownLoopCountComplexCFG") - << "unable to calculate the loop count due to complex control flow"); - LLVM_DEBUG( - dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + if (TC == 1) { + ORE->emit(createMissedAnalysis("SingleIterationLoop") + << "loop trip count is one, irrelevant for vectorization"); + LLVM_DEBUG(dbgs() << "LV: Aborting, single iteration (non) loop.\n"); return None; } + // Record that scalar epilogue is not allowed. + LLVM_DEBUG(dbgs() << "LV: Not allowing scalar epilogue due to -Os/-Oz.\n"); + + IsScalarEpilogueAllowed = !OptForSize; + + // We don't create an epilogue when optimizing for size. 
+ // Invalidate interleave groups that require an epilogue if we can't mask + // the interleave-group. + if (!useMaskedInterleavedAccesses(TTI)) + InterleaveInfo.invalidateGroupsRequiringScalarEpilogue(); + unsigned MaxVF = computeFeasibleMaxVF(OptForSize, TC); - if (TC % MaxVF != 0) { - // If the trip count that we found modulo the vectorization factor is not - // zero then we require a tail. - // FIXME: look for a smaller MaxVF that does divide TC rather than give up. - // FIXME: return None if loop requiresScalarEpilog(<MaxVF>), or look for a - // smaller MaxVF that does not require a scalar epilog. - - ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") - << "cannot optimize for size and vectorize at the " - "same time. Enable vectorization of this loop " - "with '#pragma clang loop vectorize(enable)' " - "when compiling with -Os/-Oz"); - LLVM_DEBUG( - dbgs() << "LV: Aborting. A tail loop is required with -Os/-Oz.\n"); + if (TC > 0 && TC % MaxVF == 0) { + LLVM_DEBUG(dbgs() << "LV: No tail will remain for any chosen VF.\n"); + return MaxVF; + } + + // If we don't know the precise trip count, or if the trip count that we + // found modulo the vectorization factor is not zero, try to fold the tail + // by masking. + // FIXME: look for a smaller MaxVF that does divide TC rather than masking. + if (Legal->canFoldTailByMasking()) { + FoldTailByMasking = true; + return MaxVF; + } + + if (TC == 0) { + ORE->emit( + createMissedAnalysis("UnknownLoopCountComplexCFG") + << "unable to calculate the loop count due to complex control flow"); return None; } - return MaxVF; + ORE->emit(createMissedAnalysis("NoTailLoopWithOptForSize") + << "cannot optimize for size and vectorize at the same time. " + "Enable vectorization of this loop with '#pragma clang loop " + "vectorize(enable)' when compiling with -Os/-Oz"); + return None; } unsigned @@ -5080,11 +4857,11 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { // For each block. for (BasicBlock *BB : TheLoop->blocks()) { // For each instruction in the loop. - for (Instruction &I : *BB) { + for (Instruction &I : BB->instructionsWithoutDebug()) { Type *T = I.getType(); // Skip ignored values. - if (ValuesToIgnore.count(&I)) + if (ValuesToIgnore.find(&I) != ValuesToIgnore.end()) continue; // Only examine Loads, Stores and PHINodes. @@ -5182,6 +4959,9 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // fit without causing spills. All of this is rounded down if necessary to be // a power of two. We want power of two interleave count to simplify any // addressing operations or alignment considerations. + // We also want power of two interleave counts to ensure that the induction + // variable of the vector loop wraps to zero, when tail is folded by masking; + // this currently happens when OptForSize, in which case IC is set to 1 above. unsigned IC = PowerOf2Floor((TargetNumRegisters - R.LoopInvariantRegs) / R.MaxLocalUsers); @@ -5307,7 +5087,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { using IntervalMap = DenseMap<Instruction *, unsigned>; // Maps instruction to its index. - DenseMap<unsigned, Instruction *> IdxToInstr; + SmallVector<Instruction *, 64> IdxToInstr; // Marks the end of each interval. IntervalMap EndPoint; // Saves the list of instruction indices that are used in the loop. @@ -5316,10 +5096,9 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { // defined outside the loop, such as arguments and constants. 
SmallPtrSet<Value *, 8> LoopInvariants; - unsigned Index = 0; for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { - for (Instruction &I : *BB) { - IdxToInstr[Index++] = &I; + for (Instruction &I : BB->instructionsWithoutDebug()) { + IdxToInstr.push_back(&I); // Save the end location of each USE. for (Value *U : I.operands()) { @@ -5336,7 +5115,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { } // Overwrite previous end points. - EndPoint[Instr] = Index; + EndPoint[Instr] = IdxToInstr.size(); Ends.insert(Instr); } } @@ -5373,7 +5152,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { return std::max<unsigned>(1, VF * TypeSize / WidestRegister); }; - for (unsigned int i = 0; i < Index; ++i) { + for (unsigned int i = 0, s = IdxToInstr.size(); i < s; ++i) { Instruction *I = IdxToInstr[i]; // Remove all of the instructions that end at this location. @@ -5382,11 +5161,11 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { OpenIntervals.erase(ToRemove); // Ignore instructions that are never used within the loop. - if (!Ends.count(I)) + if (Ends.find(I) == Ends.end()) continue; // Skip ignored values. - if (ValuesToIgnore.count(I)) + if (ValuesToIgnore.find(I) != ValuesToIgnore.end()) continue; // For each VF find the maximum usage of registers. @@ -5400,7 +5179,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { unsigned RegUsage = 0; for (auto Inst : OpenIntervals) { // Skip ignored values for VF > 1. - if (VecValuesToIgnore.count(Inst) || + if (VecValuesToIgnore.find(Inst) != VecValuesToIgnore.end() || isScalarAfterVectorization(Inst, VFs[j])) continue; RegUsage += GetRegUsage(Inst->getType(), VFs[j]); @@ -5446,8 +5225,7 @@ bool LoopVectorizationCostModel::useEmulatedMaskMemRefHack(Instruction *I){ // from moving "masked load/store" check from legality to cost model. // Masked Load/Gather emulation was previously never allowed. // Limited number of Masked Store/Scatter emulation was allowed. - assert(isScalarWithPredication(I) && - "Expecting a scalar emulated instruction"); + assert(isPredicatedInst(I) && "Expecting a scalar emulated instruction"); return isa<LoadInst>(I) || (isa<StoreInst>(I) && NumPredStores > NumberOfStoresToPredicate); @@ -5458,7 +5236,7 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { // instructions to scalarize, there's nothing to do. Collection may already // have occurred if we have a user-selected VF and are now computing the // expected cost for interleaving. - if (VF < 2 || InstsToScalarize.count(VF)) + if (VF < 2 || InstsToScalarize.find(VF) != InstsToScalarize.end()) return; // Initialize a mapping for VF in InstsToScalalarize. If we find that it's @@ -5470,7 +5248,7 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { // determine if it would be better to not if-convert the blocks they are in. // If so, we also record the instructions to scalarize. for (BasicBlock *BB : TheLoop->blocks()) { - if (!Legal->blockNeedsPredication(BB)) + if (!blockNeedsPredication(BB)) continue; for (Instruction &I : *BB) if (isScalarWithPredication(&I)) { @@ -5553,7 +5331,7 @@ int LoopVectorizationCostModel::computePredInstDiscount( Instruction *I = Worklist.pop_back_val(); // If we've already analyzed the instruction, there's nothing to do. - if (ScalarCosts.count(I)) + if (ScalarCosts.find(I) != ScalarCosts.end()) continue; // Compute the cost of the vector instruction. 
Note that this cost already @@ -5612,8 +5390,8 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { // For each instruction in the old loop. for (Instruction &I : BB->instructionsWithoutDebug()) { // Skip ignored values. - if (ValuesToIgnore.count(&I) || - (VF > 1 && VecValuesToIgnore.count(&I))) + if (ValuesToIgnore.find(&I) != ValuesToIgnore.end() || + (VF > 1 && VecValuesToIgnore.find(&I) != VecValuesToIgnore.end())) continue; VectorizationCostTy C = getInstructionCost(&I, VF); @@ -5635,7 +5413,7 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { // unconditionally executed. For the scalar case, we may not always execute // the predicated block. Thus, scale the block's cost by the probability of // executing it. - if (VF == 1 && Legal->blockNeedsPredication(BB)) + if (VF == 1 && blockNeedsPredication(BB)) BlockCost.first /= getReciprocalPredBlockProb(); Cost.first += BlockCost.first; @@ -5682,11 +5460,12 @@ static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, unsigned VF) { + assert(VF > 1 && "Scalarization cost of instruction implies vectorization."); Type *ValTy = getMemInstValueType(I); auto SE = PSE.getSE(); - unsigned Alignment = getMemInstAlignment(I); - unsigned AS = getMemInstAddressSpace(I); + unsigned Alignment = getLoadStoreAlignment(I); + unsigned AS = getLoadStoreAddressSpace(I); Value *Ptr = getLoadStorePointerOperand(I); Type *PtrTy = ToVectorTy(Ptr->getType(), VF); @@ -5697,9 +5476,11 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // Get the cost of the scalar memory instruction and address computation. unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); + // Don't pass *I here, since it is scalar but will actually be part of a + // vectorized loop where the user of it is a vectorized instruction. Cost += VF * TTI.getMemoryOpCost(I->getOpcode(), ValTy->getScalarType(), Alignment, - AS, I); + AS); // Get the overhead of the extractelement and insertelement instructions // we might create due to scalarization. @@ -5708,7 +5489,7 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // If we have a predicated store, it may not be executed for each vector // lane. Scale the cost by the probability of executing the predicated // block. 
- if (isScalarWithPredication(I)) { + if (isPredicatedInst(I)) { Cost /= getReciprocalPredBlockProb(); if (useEmulatedMaskMemRefHack(I)) @@ -5724,9 +5505,9 @@ unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, unsigned VF) { Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned Alignment = getMemInstAlignment(I); + unsigned Alignment = getLoadStoreAlignment(I); Value *Ptr = getLoadStorePointerOperand(I); - unsigned AS = getMemInstAddressSpace(I); + unsigned AS = getLoadStoreAddressSpace(I); int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && @@ -5745,22 +5526,30 @@ unsigned LoopVectorizationCostModel::getConsecutiveMemOpCost(Instruction *I, unsigned LoopVectorizationCostModel::getUniformMemOpCost(Instruction *I, unsigned VF) { - LoadInst *LI = cast<LoadInst>(I); - Type *ValTy = LI->getType(); + Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned Alignment = LI->getAlignment(); - unsigned AS = LI->getPointerAddressSpace(); + unsigned Alignment = getLoadStoreAlignment(I); + unsigned AS = getLoadStoreAddressSpace(I); + if (isa<LoadInst>(I)) { + return TTI.getAddressComputationCost(ValTy) + + TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) + + TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy); + } + StoreInst *SI = cast<StoreInst>(I); + bool isLoopInvariantStoreValue = Legal->isUniform(SI->getValueOperand()); return TTI.getAddressComputationCost(ValTy) + - TTI.getMemoryOpCost(Instruction::Load, ValTy, Alignment, AS) + - TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VectorTy); + TTI.getMemoryOpCost(Instruction::Store, ValTy, Alignment, AS) + + (isLoopInvariantStoreValue ? 0 : TTI.getVectorInstrCost( + Instruction::ExtractElement, + VectorTy, VF - 1)); } unsigned LoopVectorizationCostModel::getGatherScatterCost(Instruction *I, unsigned VF) { Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned Alignment = getMemInstAlignment(I); + unsigned Alignment = getLoadStoreAlignment(I); Value *Ptr = getLoadStorePointerOperand(I); return TTI.getAddressComputationCost(VectorTy) + @@ -5772,7 +5561,7 @@ unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, unsigned VF) { Type *ValTy = getMemInstValueType(I); Type *VectorTy = ToVectorTy(ValTy, VF); - unsigned AS = getMemInstAddressSpace(I); + unsigned AS = getLoadStoreAddressSpace(I); auto Group = getInterleavedAccessGroup(I); assert(Group && "Fail to get an interleaved access group."); @@ -5790,13 +5579,19 @@ unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, } // Calculate the cost of the whole interleaved group. - unsigned Cost = TTI.getInterleavedMemoryOpCost(I->getOpcode(), WideVecTy, - Group->getFactor(), Indices, - Group->getAlignment(), AS); - - if (Group->isReverse()) + bool UseMaskForGaps = + Group->requiresScalarEpilogue() && !IsScalarEpilogueAllowed; + unsigned Cost = TTI.getInterleavedMemoryOpCost( + I->getOpcode(), WideVecTy, Group->getFactor(), Indices, + Group->getAlignment(), AS, Legal->isMaskRequired(I), UseMaskForGaps); + + if (Group->isReverse()) { + // TODO: Add support for reversed masked interleaved access. 
+ assert(!Legal->isMaskRequired(I) && + "Reverse masked interleaved access not supported."); Cost += Group->getNumMembers() * TTI.getShuffleCost(TargetTransformInfo::SK_Reverse, VectorTy, 0); + } return Cost; } @@ -5806,8 +5601,8 @@ unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, // moment. if (VF == 1) { Type *ValTy = getMemInstValueType(I); - unsigned Alignment = getMemInstAlignment(I); - unsigned AS = getMemInstAddressSpace(I); + unsigned Alignment = getLoadStoreAlignment(I); + unsigned AS = getLoadStoreAddressSpace(I); return TTI.getAddressComputationCost(ValTy) + TTI.getMemoryOpCost(I->getOpcode(), ValTy, Alignment, AS, I); @@ -5826,9 +5621,12 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) { return VectorizationCostTy(InstsToScalarize[VF][I], false); // Forced scalars do not have any scalarization overhead. - if (VF > 1 && ForcedScalars.count(VF) && - ForcedScalars.find(VF)->second.count(I)) - return VectorizationCostTy((getInstructionCost(I, 1).first * VF), false); + auto ForcedScalar = ForcedScalars.find(VF); + if (VF > 1 && ForcedScalar != ForcedScalars.end()) { + auto InstSet = ForcedScalar->second; + if (InstSet.find(I) != InstSet.end()) + return VectorizationCostTy((getInstructionCost(I, 1).first * VF), false); + } Type *VectorTy; unsigned C = getInstructionCost(I, VF, VectorTy); @@ -5849,10 +5647,22 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { if (!Ptr) continue; + // TODO: We should generate better code and update the cost model for + // predicated uniform stores. Today they are treated as any other + // predicated store (see added test cases in + // invariant-store-vectorization.ll). if (isa<StoreInst>(&I) && isScalarWithPredication(&I)) NumPredStores++; - if (isa<LoadInst>(&I) && Legal->isUniform(Ptr)) { - // Scalar load + broadcast + + if (Legal->isUniform(Ptr) && + // Conditional loads and stores should be scalarized and predicated. + // isScalarWithPredication cannot be used here since masked + // gather/scatters are not considered scalar with predication. + !Legal->blockNeedsPredication(I.getParent())) { + // TODO: Avoid replicating loads and stores instead of + // relying on instcombine to remove them. + // Load: Scalar load + broadcast + // Store: Scalar store + isLoopInvariantStoreValue ? 
0 : extract unsigned Cost = getUniformMemOpCost(&I, VF); setWideningDecision(&I, VF, CM_Scalarize, Cost); continue; @@ -5883,7 +5693,8 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { continue; NumAccesses = Group->getNumMembers(); - InterleaveCost = getInterleaveGroupCost(&I, VF); + if (interleavedAccessCanBeWidened(&I, VF)) + InterleaveCost = getInterleaveGroupCost(&I, VF); } unsigned GatherScatterCost = @@ -6001,8 +5812,10 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, bool ScalarPredicatedBB = false; BranchInst *BI = cast<BranchInst>(I); if (VF > 1 && BI->isConditional() && - (PredicatedBBsAfterVectorization.count(BI->getSuccessor(0)) || - PredicatedBBsAfterVectorization.count(BI->getSuccessor(1)))) + (PredicatedBBsAfterVectorization.find(BI->getSuccessor(0)) != + PredicatedBBsAfterVectorization.end() || + PredicatedBBsAfterVectorization.find(BI->getSuccessor(1)) != + PredicatedBBsAfterVectorization.end())) ScalarPredicatedBB = true; if (ScalarPredicatedBB) { @@ -6025,9 +5838,10 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, auto *Phi = cast<PHINode>(I); // First-order recurrences are replaced by vector shuffles inside the loop. + // NOTE: Don't use ToVectorTy as SK_ExtractSubvector expects a vector type. if (VF > 1 && Legal->isFirstOrderRecurrence(Phi)) return TTI.getShuffleCost(TargetTransformInfo::SK_ExtractSubvector, - VectorTy, VF - 1, VectorTy); + VectorTy, VF - 1, VectorType::get(RetTy, 1)); // Phi nodes in non-header blocks (not inductions, reductions, etc.) are // converted into select instructions. We require N - 1 selects per phi @@ -6089,38 +5903,18 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, return 0; // Certain instructions can be cheaper to vectorize if they have a constant // second vector operand. One example of this are shifts on x86. - TargetTransformInfo::OperandValueKind Op1VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueKind Op2VK = - TargetTransformInfo::OK_AnyValue; - TargetTransformInfo::OperandValueProperties Op1VP = - TargetTransformInfo::OP_None; - TargetTransformInfo::OperandValueProperties Op2VP = - TargetTransformInfo::OP_None; Value *Op2 = I->getOperand(1); - - // Check for a splat or for a non uniform vector of constants. - if (isa<ConstantInt>(Op2)) { - ConstantInt *CInt = cast<ConstantInt>(Op2); - if (CInt && CInt->getValue().isPowerOf2()) - Op2VP = TargetTransformInfo::OP_PowerOf2; - Op2VK = TargetTransformInfo::OK_UniformConstantValue; - } else if (isa<ConstantVector>(Op2) || isa<ConstantDataVector>(Op2)) { - Op2VK = TargetTransformInfo::OK_NonUniformConstantValue; - Constant *SplatValue = cast<Constant>(Op2)->getSplatValue(); - if (SplatValue) { - ConstantInt *CInt = dyn_cast<ConstantInt>(SplatValue); - if (CInt && CInt->getValue().isPowerOf2()) - Op2VP = TargetTransformInfo::OP_PowerOf2; - Op2VK = TargetTransformInfo::OK_UniformConstantValue; - } - } else if (Legal->isUniform(Op2)) { + TargetTransformInfo::OperandValueProperties Op2VP; + TargetTransformInfo::OperandValueKind Op2VK = + TTI.getOperandInfo(Op2, Op2VP); + if (Op2VK == TargetTransformInfo::OK_AnyValue && Legal->isUniform(Op2)) Op2VK = TargetTransformInfo::OK_UniformValue; - } + SmallVector<const Value *, 4> Operands(I->operand_values()); unsigned N = isScalarAfterVectorization(I, VF) ? 
VF : 1; - return N * TTI.getArithmeticInstrCost(I->getOpcode(), VectorTy, Op1VK, - Op2VK, Op1VP, Op2VP, Operands); + return N * TTI.getArithmeticInstrCost( + I->getOpcode(), VectorTy, TargetTransformInfo::OK_AnyValue, + Op2VK, TargetTransformInfo::OP_None, Op2VP, Operands); } case Instruction::Select: { SelectInst *SI = cast<SelectInst>(I); @@ -6237,8 +6031,9 @@ INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) namespace llvm { -Pass *createLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { - return new LoopVectorize(NoUnrolling, AlwaysVectorize); +Pass *createLoopVectorizePass(bool InterleaveOnlyWhenForced, + bool VectorizeOnlyWhenForced) { + return new LoopVectorize(InterleaveOnlyWhenForced, VectorizeOnlyWhenForced); } } // end namespace llvm @@ -6316,6 +6111,16 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { if (!MaybeMaxVF.hasValue()) // Cases considered too costly to vectorize. return NoVectorization; + // Invalidate interleave groups if all blocks of loop will be predicated. + if (CM.blockNeedsPredication(OrigLoop->getHeader()) && + !useMaskedInterleavedAccesses(*TTI)) { + LLVM_DEBUG( + dbgs() + << "LV: Invalidate all interleaved groups due to fold-tail by masking " + "which requires masked-interleaved support.\n"); + CM.InterleaveInfo.reset(); + } + if (UserVF) { LLVM_DEBUG(dbgs() << "LV: Using user VF " << UserVF << ".\n"); assert(isPowerOf2_32(UserVF) && "VF needs to be a power of two"); @@ -6372,6 +6177,7 @@ void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV, DT, ILV.Builder, ILV.VectorLoopValueMap, &ILV, CallbackILV}; State.CFG.PrevBB = ILV.createVectorizedLoopSkeleton(); + State.TripCount = ILV.getOrCreateTripCount(nullptr); //===------------------------------------------------===// // @@ -6408,7 +6214,8 @@ void LoopVectorizationPlanner::collectTriviallyDeadInstructions( PHINode *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); if (llvm::all_of(IndUpdate->users(), [&](User *U) -> bool { - return U == Ind || DeadInstructions.count(cast<Instruction>(U)); + return U == Ind || DeadInstructions.find(cast<Instruction>(U)) != + DeadInstructions.end(); })) DeadInstructions.insert(IndUpdate); @@ -6551,9 +6358,17 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { // load/store/gather/scatter. Initialize BlockMask to no-mask. VPValue *BlockMask = nullptr; - // Loop incoming mask is all-one. - if (OrigLoop->getHeader() == BB) + if (OrigLoop->getHeader() == BB) { + if (!CM.blockNeedsPredication(BB)) + return BlockMaskCache[BB] = BlockMask; // Loop incoming mask is all-one. + + // Introduce the early-exit compare IV <= BTC to form header block mask. + // This is used instead of IV < TC because TC may wrap, unlike BTC. + VPValue *IV = Plan->getVPValue(Legal->getPrimaryInduction()); + VPValue *BTC = Plan->getOrCreateBackedgeTakenCount(); + BlockMask = Builder.createNaryOp(VPInstruction::ICmpULE, {IV, BTC}); return BlockMaskCache[BB] = BlockMask; + } // This is the block mask. We OR all incoming edges. 
for (auto *Predecessor : predecessors(BB)) { @@ -6573,8 +6388,9 @@ VPValue *VPRecipeBuilder::createBlockInMask(BasicBlock *BB, VPlanPtr &Plan) { } VPInterleaveRecipe *VPRecipeBuilder::tryToInterleaveMemory(Instruction *I, - VFRange &Range) { - const InterleaveGroup *IG = CM.getInterleavedAccessGroup(I); + VFRange &Range, + VPlanPtr &Plan) { + const InterleaveGroup<Instruction> *IG = CM.getInterleavedAccessGroup(I); if (!IG) return nullptr; @@ -6595,7 +6411,11 @@ VPInterleaveRecipe *VPRecipeBuilder::tryToInterleaveMemory(Instruction *I, assert(I == IG->getInsertPos() && "Generating a recipe for an adjunct member of an interleave group"); - return new VPInterleaveRecipe(IG); + VPValue *Mask = nullptr; + if (Legal->isMaskRequired(I)) + Mask = createBlockInMask(I->getParent(), Plan); + + return new VPInterleaveRecipe(IG, Mask); } VPWidenMemoryInstructionRecipe * @@ -6688,7 +6508,11 @@ VPBlendRecipe *VPRecipeBuilder::tryToBlend(Instruction *I, VPlanPtr &Plan) { bool VPRecipeBuilder::tryToWiden(Instruction *I, VPBasicBlock *VPBB, VFRange &Range) { - if (CM.isScalarWithPredication(I)) + + bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange( + [&](unsigned VF) { return CM.isScalarWithPredication(I, VF); }, Range); + + if (IsPredicated) return false; auto IsVectorizableOpcode = [](unsigned Opcode) { @@ -6795,7 +6619,9 @@ VPBasicBlock *VPRecipeBuilder::handleReplication( [&](unsigned VF) { return CM.isUniformAfterVectorization(I, VF); }, Range); - bool IsPredicated = CM.isScalarWithPredication(I); + bool IsPredicated = LoopVectorizationPlanner::getDecisionAndClampRange( + [&](unsigned VF) { return CM.isScalarWithPredication(I, VF); }, Range); + auto *Recipe = new VPReplicateRecipe(I, IsUniform, IsPredicated); // Find if I uses a predicated instruction. If so, it will use its scalar @@ -6857,7 +6683,7 @@ bool VPRecipeBuilder::tryToCreateRecipe(Instruction *Instr, VFRange &Range, VPRecipeBase *Recipe = nullptr; // Check if Instr should belong to an interleave memory recipe, or already // does. In the latter case Instr is irrelevant. - if ((Recipe = tryToInterleaveMemory(Instr, Range))) { + if ((Recipe = tryToInterleaveMemory(Instr, Range, Plan))) { VPBB->appendRecipe(Recipe); return true; } @@ -6908,6 +6734,11 @@ void LoopVectorizationPlanner::buildVPlansWithVPRecipes(unsigned MinVF, NeedDef.insert(Branch->getCondition()); } + // If the tail is to be folded by masking, the primary induction variable + // needs to be represented in VPlan for it to model early-exit masking. + if (CM.foldTailByMasking()) + NeedDef.insert(Legal->getPrimaryInduction()); + // Collect instructions from the original loop that will become trivially dead // in the vectorized loop. We don't need to vectorize these instructions. For // example, original induction update instructions can become dead because we @@ -6969,18 +6800,21 @@ LoopVectorizationPlanner::buildVPlanWithVPRecipes( // First filter out irrelevant instructions, to ensure no recipes are // built for them. - if (isa<BranchInst>(Instr) || DeadInstructions.count(Instr)) + if (isa<BranchInst>(Instr) || + DeadInstructions.find(Instr) != DeadInstructions.end()) continue; // I is a member of an InterleaveGroup for Range.Start. If it's an adjunct // member of the IG, do not construct any Recipe for it. 
- const InterleaveGroup *IG = CM.getInterleavedAccessGroup(Instr); + const InterleaveGroup<Instruction> *IG = + CM.getInterleavedAccessGroup(Instr); if (IG && Instr != IG->getInsertPos() && Range.Start >= 2 && // Query is illegal for VF == 1 CM.getWideningDecision(Instr, Range.Start) == LoopVectorizationCostModel::CM_Interleave) { - if (SinkAfterInverse.count(Instr)) - Ingredients.push_back(SinkAfterInverse.find(Instr)->second); + auto SinkCandidate = SinkAfterInverse.find(Instr); + if (SinkCandidate != SinkAfterInverse.end()) + Ingredients.push_back(SinkCandidate->second); continue; } @@ -7063,6 +6897,13 @@ LoopVectorizationPlanner::buildVPlan(VFRange &Range) { VPlanHCFGBuilder HCFGBuilder(OrigLoop, LI, *Plan); HCFGBuilder.buildHierarchicalCFG(); + SmallPtrSet<Instruction *, 1> DeadInstructions; + VPlanHCFGTransforms::VPInstructionsToVPRecipes( + Plan, Legal->getInductionVars(), DeadInstructions); + + for (unsigned VF = Range.Start; VF < Range.End; VF *= 2) + Plan->addVF(VF); + return Plan; } @@ -7075,6 +6916,10 @@ void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent) const { O << " +\n" << Indent << "\"INTERLEAVE-GROUP with factor " << IG->getFactor() << " at "; IG->getInsertPos()->printAsOperand(O, false); + if (User) { + O << ", "; + User->getOperand(0)->printAsOperand(O); + } O << "\\l\""; for (unsigned i = 0; i < IG->getFactor(); ++i) if (Instruction *I = IG->getMember(i)) @@ -7137,7 +6982,15 @@ void VPBlendRecipe::execute(VPTransformState &State) { void VPInterleaveRecipe::execute(VPTransformState &State) { assert(!State.Instance && "Interleave group being replicated."); - State.ILV->vectorizeInterleaveGroup(IG->getInsertPos()); + if (!User) + return State.ILV->vectorizeInterleaveGroup(IG->getInsertPos()); + + // Last (and currently only) operand is a mask. + InnerLoopVectorizer::VectorParts MaskValues(State.UF); + VPValue *Mask = User->getOperand(User->getNumOperands() - 1); + for (unsigned Part = 0; Part < State.UF; ++Part) + MaskValues[Part] = State.get(Mask, Part); + State.ILV->vectorizeInterleaveGroup(IG->getInsertPos(), &MaskValues); } void VPReplicateRecipe::execute(VPTransformState &State) { @@ -7264,11 +7117,26 @@ static bool processLoopInVPlanNativePath( Hints.getForce() != LoopVectorizeHints::FK_Enabled && F->optForSize(); // Plan how to best vectorize, return the best VF and its cost. - LVP.planInVPlanNativePath(OptForSize, UserVF); + VectorizationFactor VF = LVP.planInVPlanNativePath(OptForSize, UserVF); - // Returning false. We are currently not generating vector code in the VPlan - // native path. - return false; + // If we are stress testing VPlan builds, do not attempt to generate vector + // code. + if (VPlanBuildStressTest) + return false; + + LVP.setBestPlan(VF.Width, 1); + + InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, UserVF, 1, LVL, + &CM); + LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \"" + << L->getHeader()->getParent()->getName() << "\"\n"); + LVP.executePlan(LB, DT); + + // Mark the loop as already vectorized to avoid vectorizing again. 
+ Hints.setAlreadyVectorized(); + + LLVM_DEBUG(verifyFunction(*L->getHeader()->getParent())); + return true; } bool LoopVectorizePass::processLoop(Loop *L) { @@ -7283,7 +7151,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { << L->getHeader()->getParent()->getName() << "\" from " << DebugLocStr << "\n"); - LoopVectorizeHints Hints(L, DisableUnrolling, *ORE); + LoopVectorizeHints Hints(L, InterleaveOnlyWhenForced, *ORE); LLVM_DEBUG( dbgs() << "LV: Loop hints:" @@ -7307,7 +7175,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // less verbose reporting vectorized loops and unvectorized loops that may // benefit from vectorization, respectively. - if (!Hints.allowVectorization(F, L, AlwaysVectorize)) { + if (!Hints.allowVectorization(F, L, VectorizeOnlyWhenForced)) { LLVM_DEBUG(dbgs() << "LV: Loop hints prevent vectorization.\n"); return false; } @@ -7320,7 +7188,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { &Requirements, &Hints, DB, AC); if (!LVL.canVectorize(EnableVPlanNativePath)) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: Cannot prove legality.\n"); - emitMissedWarning(F, L, Hints, ORE); + Hints.emitRemarkWithHints(); return false; } @@ -7393,7 +7261,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { ORE->emit(createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), "NoImplicitFloat", L) << "loop not vectorized due to NoImplicitFloat attribute"); - emitMissedWarning(F, L, Hints, ORE); + Hints.emitRemarkWithHints(); return false; } @@ -7408,7 +7276,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { ORE->emit( createLVMissedAnalysis(Hints.vectorizeAnalysisPassName(), "UnsafeFP", L) << "loop not vectorized due to unsafe FP support."); - emitMissedWarning(F, L, Hints, ORE); + Hints.emitRemarkWithHints(); return false; } @@ -7421,7 +7289,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // Analyze interleaved memory accesses. if (UseInterleaved) { - IAI.analyzeInterleaving(); + IAI.analyzeInterleaving(useMaskedInterleavedAccesses(*TTI)); } // Use the cost model. @@ -7450,7 +7318,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { if (Requirements.doesNotMeet(F, L, Hints)) { LLVM_DEBUG(dbgs() << "LV: Not vectorizing: loop did not meet vectorization " "requirements.\n"); - emitMissedWarning(F, L, Hints, ORE); + Hints.emitRemarkWithHints(); return false; } @@ -7527,6 +7395,8 @@ bool LoopVectorizePass::processLoop(Loop *L) { LVP.setBestPlan(VF.Width, IC); using namespace ore; + bool DisableRuntimeUnroll = false; + MDNode *OrigLoopID = L->getLoopID(); if (!VectorizeLoop) { assert(IC > 1 && "interleave count should not be 1 or 0"); @@ -7553,7 +7423,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { // no runtime checks about strides and memory. A scalar loop that is // rarely used is not worth unrolling. if (!LB.areSafetyChecksAdded()) - AddRuntimeUnrollDisableMetaData(L); + DisableRuntimeUnroll = true; // Report the vectorization decision. ORE->emit([&]() { @@ -7565,8 +7435,18 @@ bool LoopVectorizePass::processLoop(Loop *L) { }); } - // Mark the loop as already vectorized to avoid vectorizing again. - Hints.setAlreadyVectorized(); + Optional<MDNode *> RemainderLoopID = + makeFollowupLoopID(OrigLoopID, {LLVMLoopVectorizeFollowupAll, + LLVMLoopVectorizeFollowupEpilogue}); + if (RemainderLoopID.hasValue()) { + L->setLoopID(RemainderLoopID.getValue()); + } else { + if (DisableRuntimeUnroll) + AddRuntimeUnrollDisableMetaData(L); + + // Mark the loop as already vectorized to avoid vectorizing again. 
+ Hints.setAlreadyVectorized(); + } LLVM_DEBUG(verifyFunction(*L->getHeader()->getParent())); return true; @@ -7659,8 +7539,15 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, if (!Changed) return PreservedAnalyses::all(); PreservedAnalyses PA; - PA.preserve<LoopAnalysis>(); - PA.preserve<DominatorTreeAnalysis>(); + + // We currently do not preserve loopinfo/dominator analyses with outer loop + // vectorization. Until this is addressed, mark these analyses as preserved + // only for non-VPlan-native path. + // TODO: Preserve Loop and Dominator analyses for VPlan-native path. + if (!EnableVPlanNativePath) { + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + } PA.preserve<BasicAA>(); PA.preserve<GlobalsAA>(); return PA; diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index 5c2efe885e22..2e856a7e6802 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1536,12 +1536,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, // Check for terminator values (e.g. invoke). for (unsigned j = 0; j < VL.size(); ++j) for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { - TerminatorInst *Term = dyn_cast<TerminatorInst>( - cast<PHINode>(VL[j])->getIncomingValueForBlock(PH->getIncomingBlock(i))); - if (Term) { - LLVM_DEBUG( - dbgs() - << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n"); + Instruction *Term = dyn_cast<Instruction>( + cast<PHINode>(VL[j])->getIncomingValueForBlock( + PH->getIncomingBlock(i))); + if (Term && Term->isTerminator()) { + LLVM_DEBUG(dbgs() + << "SLP: Need to swizzle PHINodes (terminator use).\n"); BS.cancelScheduling(VL, VL0); newTreeEntry(VL, false, UserTreeIdx, ReuseShuffleIndicies); return; @@ -2164,7 +2164,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // extractelement/ext pair. DeadCost -= TTI->getExtractWithExtendCost( Ext->getOpcode(), Ext->getType(), VecTy, i); - // Add back the cost of s|zext which is subtracted seperately. + // Add back the cost of s|zext which is subtracted separately. DeadCost += TTI->getCastInstrCost( Ext->getOpcode(), Ext->getType(), E->getType(), Ext); continue; @@ -2536,13 +2536,13 @@ int BoUpSLP::getTreeCost() { // uses. However, we should not compute the cost of duplicate sequences. // For example, if we have a build vector (i.e., insertelement sequence) // that is used by more than one vector instruction, we only need to - // compute the cost of the insertelement instructions once. The redundent + // compute the cost of the insertelement instructions once. The redundant // instructions will be eliminated by CSE. // // We should consider not creating duplicate tree entries for gather // sequences, and instead add additional edges to the tree representing // their uses. Since such an approach results in fewer total entries, - // existing heuristics based on tree size may yeild different results. + // existing heuristics based on tree size may yield different results. // if (TE.NeedToGather && std::any_of(std::next(VectorizableTree.begin(), I + 1), @@ -3109,14 +3109,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } if (NeedToShuffleReuses) { // TODO: Merge this shuffle with the ReorderShuffleMask. 
- if (!E->ReorderIndices.empty()) + if (E->ReorderIndices.empty()) Builder.SetInsertPoint(VL0); - else if (auto *I = dyn_cast<Instruction>(V)) - Builder.SetInsertPoint(I->getParent(), - std::next(I->getIterator())); - else - Builder.SetInsertPoint(&F->getEntryBlock(), - F->getEntryBlock().getFirstInsertionPt()); V = Builder.CreateShuffleVector(V, UndefValue::get(VecTy), E->ReuseShuffleIndices, "shuffle"); } @@ -3649,6 +3643,8 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { auto &Locs = ExternallyUsedValues[Scalar]; ExternallyUsedValues.insert({Ex, Locs}); ExternallyUsedValues.erase(Scalar); + // Required to update internally referenced instructions. + Scalar->replaceAllUsesWith(Ex); continue; } @@ -3658,7 +3654,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { if (PHINode *PH = dyn_cast<PHINode>(User)) { for (int i = 0, e = PH->getNumIncomingValues(); i != e; ++i) { if (PH->getIncomingValue(i) == Scalar) { - TerminatorInst *IncomingTerminator = + Instruction *IncomingTerminator = PH->getIncomingBlock(i)->getTerminator(); if (isa<CatchSwitchInst>(IncomingTerminator)) { Builder.SetInsertPoint(VecI->getParent(), @@ -3966,7 +3962,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, ScheduleEnd = I->getNextNode(); if (isOneOf(S, I) != I) CheckSheduleForI(I); - assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); + assert(ScheduleEnd && "tried to vectorize a terminator?"); LLVM_DEBUG(dbgs() << "SLP: initialize schedule region to " << *I << "\n"); return true; } @@ -4002,7 +3998,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, ScheduleEnd = I->getNextNode(); if (isOneOf(S, I) != I) CheckSheduleForI(I); - assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); + assert(ScheduleEnd && "tried to vectorize a terminator?"); LLVM_DEBUG(dbgs() << "SLP: extend schedule region end to " << *I << "\n"); return true; @@ -4273,7 +4269,7 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { Worklist.push_back(I); // Traverse the expression tree in bottom-up order looking for loads. If we - // encounter an instruciton we don't yet handle, we give up. + // encounter an instruction we don't yet handle, we give up. auto MaxWidth = 0u; auto FoundUnknownInst = false; while (!Worklist.empty() && !FoundUnknownInst) { @@ -4846,7 +4842,7 @@ void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) { continue; if (GEP->getType()->isVectorTy()) continue; - GEPs[GetUnderlyingObject(GEP->getPointerOperand(), *DL)].push_back(GEP); + GEPs[GEP->getPointerOperand()].push_back(GEP); } } } @@ -5132,9 +5128,12 @@ class HorizontalReduction { /// Checks if the reduction operation can be vectorized. bool isVectorizable() const { return LHS && RHS && - // We currently only support adds && min/max reductions. + // We currently only support add/mul/logical && min/max reductions. ((Kind == RK_Arithmetic && - (Opcode == Instruction::Add || Opcode == Instruction::FAdd)) || + (Opcode == Instruction::Add || Opcode == Instruction::FAdd || + Opcode == Instruction::Mul || Opcode == Instruction::FMul || + Opcode == Instruction::And || Opcode == Instruction::Or || + Opcode == Instruction::Xor)) || ((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && (Kind == RK_Min || Kind == RK_Max)) || (Opcode == Instruction::ICmp && @@ -5456,7 +5455,7 @@ class HorizontalReduction { } }; - Instruction *ReductionRoot = nullptr; + WeakTrackingVH ReductionRoot; /// The operation data of the reduction operation. 
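// A minimal standalone sketch (invented helper, not part of this patch) of why
// the ReductionRoot member above is a WeakTrackingVH rather than a plain
// Instruction*: a tracking handle follows replaceAllUsesWith and nulls itself
// out if the cached instruction is erased while the reduction tree is
// rewritten. Assumes Root and Rewritten have the same type.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/ValueHandle.h"

static void sketchCacheReductionRoot(llvm::Instruction *Root,
                                     llvm::Instruction *Rewritten) {
  llvm::WeakTrackingVH RootVH(Root);
  Root->replaceAllUsesWith(Rewritten); // the handle now tracks Rewritten
  if (RootVH) {                        // would be null had Root been erased
    llvm::IRBuilder<> B(llvm::cast<llvm::Instruction>(RootVH));
    (void)B; // safe insertion point, mirroring how ReductionRoot is used
  }
}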
OperationData ReductionData; @@ -5741,7 +5740,7 @@ public: unsigned ReduxWidth = PowerOf2Floor(NumReducedVals); Value *VectorizedTree = nullptr; - IRBuilder<> Builder(ReductionRoot); + IRBuilder<> Builder(cast<Instruction>(ReductionRoot)); FastMathFlags Unsafe; Unsafe.setFast(); Builder.setFastMathFlags(Unsafe); @@ -5750,8 +5749,13 @@ public: BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues; // The same extra argument may be used several time, so log each attempt // to use it. - for (auto &Pair : ExtraArgs) + for (auto &Pair : ExtraArgs) { + assert(Pair.first && "DebugLoc must be set."); ExternallyUsedValues[Pair.second].push_back(Pair.first); + } + // The reduction root is used as the insertion point for new instructions, + // so set it as externally used to prevent it from being deleted. + ExternallyUsedValues[ReductionRoot]; SmallVector<Value *, 16> IgnoreList; for (auto &V : ReductionOps) IgnoreList.append(V.begin(), V.end()); @@ -5803,6 +5807,7 @@ public: Value *VectorizedRoot = V.vectorizeTree(ExternallyUsedValues); // Emit a reduction. + Builder.SetInsertPoint(cast<Instruction>(ReductionRoot)); Value *ReducedSubTree = emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); if (VectorizedTree) { @@ -5829,8 +5834,6 @@ public: VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps); } for (auto &Pair : ExternallyUsedValues) { - assert(!Pair.second.empty() && - "At least one DebugLoc must be inserted"); // Add each externally used value to the final reduction. for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); diff --git a/lib/Transforms/Vectorize/VPRecipeBuilder.h b/lib/Transforms/Vectorize/VPRecipeBuilder.h index f43a8bb123b1..15d38ac9c84c 100644 --- a/lib/Transforms/Vectorize/VPRecipeBuilder.h +++ b/lib/Transforms/Vectorize/VPRecipeBuilder.h @@ -69,7 +69,8 @@ public: /// \return value is <true, nullptr>, as it is handled by another recipe. /// \p Range.End may be decreased to ensure same decision from \p Range.Start /// to \p Range.End. - VPInterleaveRecipe *tryToInterleaveMemory(Instruction *I, VFRange &Range); + VPInterleaveRecipe *tryToInterleaveMemory(Instruction *I, VFRange &Range, + VPlanPtr &Plan); /// Check if \I is a memory instruction to be widened for \p Range.Start and /// potentially masked. Such instructions are handled by a recipe that takes diff --git a/lib/Transforms/Vectorize/VPlan.cpp b/lib/Transforms/Vectorize/VPlan.cpp index 0780e70809d0..05a5400beb4e 100644 --- a/lib/Transforms/Vectorize/VPlan.cpp +++ b/lib/Transforms/Vectorize/VPlan.cpp @@ -44,6 +44,7 @@ #include <vector> using namespace llvm; +extern cl::opt<bool> EnableVPlanNativePath; #define DEBUG_TYPE "vplan" @@ -124,6 +125,20 @@ VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { VPBasicBlock *PredVPBB = PredVPBlock->getExitBasicBlock(); auto &PredVPSuccessors = PredVPBB->getSuccessors(); BasicBlock *PredBB = CFG.VPBB2IRBB[PredVPBB]; + + // In outer loop vectorization scenario, the predecessor BBlock may not yet + // be visited(backedge). Mark the VPBasicBlock for fixup at the end of + // vectorization. We do not encounter this case in inner loop vectorization + // as we start out by building a loop skeleton with the vector loop header + // and latch blocks. As a result, we never enter this function for the + // header block in the non VPlan-native path. 
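// A standalone toy sketch of the deferral scheme described above (invented
// types, not the VPlan classes): blocks are materialized in RPO order, so an
// edge whose source block does not exist yet (a backedge) is queued and only
// wired up in a second pass, once every block has been created.
#include <map>
#include <string>
#include <utility>
#include <vector>

struct ToyBlock {
  std::string Name;
  std::vector<ToyBlock *> Succs;
};

static void wireToyCFG(
    const std::vector<std::string> &RPO,
    const std::vector<std::pair<std::string, std::string>> &Edges,
    std::map<std::string, ToyBlock> &Blocks) {
  std::vector<std::pair<std::string, std::string>> Deferred; // like VPBBsToFix

  for (const std::string &Name : RPO) {
    Blocks[Name].Name = Name; // materialize the block for this RPO slot
    for (const auto &E : Edges) {
      if (E.second != Name)
        continue;
      auto It = Blocks.find(E.first);
      if (It != Blocks.end())
        It->second.Succs.push_back(&Blocks[Name]); // predecessor already built
      else
        Deferred.push_back(E); // backedge: source not emitted yet, fix up later
    }
  }
  // Second pass: every block exists now, so the deferred edges can be wired.
  for (const auto &E : Deferred)
    Blocks[E.first].Succs.push_back(&Blocks[E.second]);
}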
+ if (!PredBB) { + assert(EnableVPlanNativePath && + "Unexpected null predecessor in non VPlan-native path"); + CFG.VPBBsToFix.push_back(PredVPBB); + continue; + } + assert(PredBB && "Predecessor basic-block not found building successor."); auto *PredBBTerminator = PredBB->getTerminator(); LLVM_DEBUG(dbgs() << "LV: draw edge from" << PredBB->getName() << '\n'); @@ -185,6 +200,31 @@ void VPBasicBlock::execute(VPTransformState *State) { for (VPRecipeBase &Recipe : Recipes) Recipe.execute(*State); + VPValue *CBV; + if (EnableVPlanNativePath && (CBV = getCondBit())) { + Value *IRCBV = CBV->getUnderlyingValue(); + assert(IRCBV && "Unexpected null underlying value for condition bit"); + + // Condition bit value in a VPBasicBlock is used as the branch selector. In + // the VPlan-native path case, since all branches are uniform we generate a + // branch instruction using the condition value from vector lane 0 and dummy + // successors. The successors are fixed later when the successor blocks are + // visited. + Value *NewCond = State->Callback.getOrCreateVectorValues(IRCBV, 0); + NewCond = State->Builder.CreateExtractElement(NewCond, + State->Builder.getInt32(0)); + + // Replace the temporary unreachable terminator with the new conditional + // branch. + auto *CurrentTerminator = NewBB->getTerminator(); + assert(isa<UnreachableInst>(CurrentTerminator) && + "Expected to replace unreachable terminator with conditional " + "branch."); + auto *CondBr = BranchInst::Create(NewBB, nullptr, NewCond); + CondBr->setSuccessor(0, nullptr); + ReplaceInstWithInst(CurrentTerminator, CondBr); + } + LLVM_DEBUG(dbgs() << "LV: filled BB:" << *NewBB); } @@ -194,6 +234,20 @@ void VPRegionBlock::execute(VPTransformState *State) { if (!isReplicator()) { // Visit the VPBlocks connected to "this", starting from it. for (VPBlockBase *Block : RPOT) { + if (EnableVPlanNativePath) { + // The inner loop vectorization path does not represent loop preheader + // and exit blocks as part of the VPlan. In the VPlan-native path, skip + // vectorizing loop preheader block. In future, we may replace this + // check with the check for loop preheader. + if (Block->getNumPredecessors() == 0) + continue; + + // Skip vectorizing loop exit block. In future, we may replace this + // check with the check for loop exit. + if (Block->getNumSuccessors() == 0) + continue; + } + LLVM_DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); Block->execute(State); } @@ -249,6 +303,13 @@ void VPInstruction::generateInstruction(VPTransformState &State, State.set(this, V, Part); break; } + case VPInstruction::ICmpULE: { + Value *IV = State.get(getOperand(0), Part); + Value *TC = State.get(getOperand(1), Part); + Value *V = Builder.CreateICmpULE(IV, TC); + State.set(this, V, Part); + break; + } default: llvm_unreachable("Unsupported opcode for instruction"); } @@ -274,6 +335,15 @@ void VPInstruction::print(raw_ostream &O) const { case VPInstruction::Not: O << "not"; break; + case VPInstruction::ICmpULE: + O << "icmp ule"; + break; + case VPInstruction::SLPLoad: + O << "combined load"; + break; + case VPInstruction::SLPStore: + O << "combined store"; + break; default: O << Instruction::getOpcodeName(getOpcode()); } @@ -288,6 +358,15 @@ void VPInstruction::print(raw_ostream &O) const { /// LoopVectorBody basic-block was created for this. Introduce additional /// basic-blocks as needed, and fill them all. void VPlan::execute(VPTransformState *State) { + // -1. Check if the backedge taken count is needed, and if so build it. 
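// A small standalone illustration (plain C++, invented helper) of what the
// backedge-taken count is used for when folding the tail: a lane is live
// exactly when its induction value is <= TripCount - 1, which is the kind of
// comparison the VPInstruction::ICmpULE opcode above can express per part.
#include <cstdint>
#include <vector>

static std::vector<bool> toyActiveLaneMask(uint64_t FirstLaneIV, unsigned VF,
                                           uint64_t TripCount) {
  const uint64_t BTC = TripCount - 1; // the "trip.count.minus.1" value below
  std::vector<bool> Mask(VF);
  for (unsigned Lane = 0; Lane != VF; ++Lane)
    Mask[Lane] = (FirstLaneIV + Lane) <= BTC; // lanes past BTC are masked off
  return Mask;
}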
+ if (BackedgeTakenCount && BackedgeTakenCount->getNumUsers()) { + Value *TC = State->TripCount; + IRBuilder<> Builder(State->CFG.PrevBB->getTerminator()); + auto *TCMO = Builder.CreateSub(TC, ConstantInt::get(TC->getType(), 1), + "trip.count.minus.1"); + Value2VPValue[TCMO] = BackedgeTakenCount; + } + // 0. Set the reverse mapping from VPValues to Values for code generation. for (auto &Entry : Value2VPValue) State->VPValue2Value[Entry.second] = Entry.first; @@ -319,11 +398,32 @@ void VPlan::execute(VPTransformState *State) { for (VPBlockBase *Block : depth_first(Entry)) Block->execute(State); + // Setup branch terminator successors for VPBBs in VPBBsToFix based on + // VPBB's successors. + for (auto VPBB : State->CFG.VPBBsToFix) { + assert(EnableVPlanNativePath && + "Unexpected VPBBsToFix in non VPlan-native path"); + BasicBlock *BB = State->CFG.VPBB2IRBB[VPBB]; + assert(BB && "Unexpected null basic block for VPBB"); + + unsigned Idx = 0; + auto *BBTerminator = BB->getTerminator(); + + for (VPBlockBase *SuccVPBlock : VPBB->getHierarchicalSuccessors()) { + VPBasicBlock *SuccVPBB = SuccVPBlock->getEntryBasicBlock(); + BBTerminator->setSuccessor(Idx, State->CFG.VPBB2IRBB[SuccVPBB]); + ++Idx; + } + } + // 3. Merge the temporary latch created with the last basic-block filled. BasicBlock *LastBB = State->CFG.PrevBB; // Connect LastBB to VectorLatchBB to facilitate their merge. - assert(isa<UnreachableInst>(LastBB->getTerminator()) && - "Expected VPlan CFG to terminate with unreachable"); + assert((EnableVPlanNativePath || + isa<UnreachableInst>(LastBB->getTerminator())) && + "Expected InnerLoop VPlan CFG to terminate with unreachable"); + assert((!EnableVPlanNativePath || isa<BranchInst>(LastBB->getTerminator())) && + "Expected VPlan CFG to terminate with branch in NativePath"); LastBB->getTerminator()->eraseFromParent(); BranchInst::Create(VectorLatchBB, LastBB); @@ -333,7 +433,9 @@ void VPlan::execute(VPTransformState *State) { assert(Merged && "Could not merge last basic block with latch."); VectorLatchBB = LastBB; - updateDominatorTree(State->DT, VectorPreHeaderBB, VectorLatchBB); + // We do not attempt to preserve DT for outer loop vectorization currently. 
+ if (!EnableVPlanNativePath) + updateDominatorTree(State->DT, VectorPreHeaderBB, VectorLatchBB); } void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, @@ -366,7 +468,7 @@ void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, "One successor of a basic block does not lead to the other."); assert(InterimSucc->getSinglePredecessor() && "Interim successor has more than one predecessor."); - assert(pred_size(PostDomSucc) == 2 && + assert(PostDomSucc->hasNPredecessors(2) && "PostDom successor has more than two predecessors."); DT->addNewBlock(InterimSucc, BB); DT->addNewBlock(PostDomSucc, BB); @@ -392,8 +494,11 @@ void VPlanPrinter::dump() { OS << "graph [labelloc=t, fontsize=30; label=\"Vectorization Plan"; if (!Plan.getName().empty()) OS << "\\n" << DOT::EscapeString(Plan.getName()); - if (!Plan.Value2VPValue.empty()) { + if (!Plan.Value2VPValue.empty() || Plan.BackedgeTakenCount) { OS << ", where:"; + if (Plan.BackedgeTakenCount) + OS << "\\n" + << *Plan.getOrCreateBackedgeTakenCount() << " := BackedgeTakenCount"; for (auto Entry : Plan.Value2VPValue) { OS << "\\n" << *Entry.second; OS << DOT::EscapeString(" := "); @@ -466,8 +571,10 @@ void VPlanPrinter::dumpBasicBlock(const VPBasicBlock *BasicBlock) { if (const VPInstruction *CBI = dyn_cast<VPInstruction>(CBV)) { CBI->printAsOperand(OS); OS << " (" << DOT::EscapeString(CBI->getParent()->getName()) << ")\\l\""; - } else + } else { CBV->printAsOperand(OS); + OS << "\""; + } } bumpIndent(-2); @@ -579,3 +686,55 @@ void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, } template void DomTreeBuilder::Calculate<VPDominatorTree>(VPDominatorTree &DT); + +void VPValue::replaceAllUsesWith(VPValue *New) { + for (VPUser *User : users()) + for (unsigned I = 0, E = User->getNumOperands(); I < E; ++I) + if (User->getOperand(I) == this) + User->setOperand(I, New); +} + +void VPInterleavedAccessInfo::visitRegion(VPRegionBlock *Region, + Old2NewTy &Old2New, + InterleavedAccessInfo &IAI) { + ReversePostOrderTraversal<VPBlockBase *> RPOT(Region->getEntry()); + for (VPBlockBase *Base : RPOT) { + visitBlock(Base, Old2New, IAI); + } +} + +void VPInterleavedAccessInfo::visitBlock(VPBlockBase *Block, Old2NewTy &Old2New, + InterleavedAccessInfo &IAI) { + if (VPBasicBlock *VPBB = dyn_cast<VPBasicBlock>(Block)) { + for (VPRecipeBase &VPI : *VPBB) { + assert(isa<VPInstruction>(&VPI) && "Can only handle VPInstructions"); + auto *VPInst = cast<VPInstruction>(&VPI); + auto *Inst = cast<Instruction>(VPInst->getUnderlyingValue()); + auto *IG = IAI.getInterleaveGroup(Inst); + if (!IG) + continue; + + auto NewIGIter = Old2New.find(IG); + if (NewIGIter == Old2New.end()) + Old2New[IG] = new InterleaveGroup<VPInstruction>( + IG->getFactor(), IG->isReverse(), IG->getAlignment()); + + if (Inst == IG->getInsertPos()) + Old2New[IG]->setInsertPos(VPInst); + + InterleaveGroupMap[VPInst] = Old2New[IG]; + InterleaveGroupMap[VPInst]->insertMember( + VPInst, IG->getIndex(Inst), + IG->isReverse() ? 
(-1) * int(IG->getFactor()) : IG->getFactor()); + } + } else if (VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) + visitRegion(Region, Old2New, IAI); + else + llvm_unreachable("Unsupported kind of VPBlock."); +} + +VPInterleavedAccessInfo::VPInterleavedAccessInfo(VPlan &Plan, + InterleavedAccessInfo &IAI) { + Old2NewTy Old2New; + visitRegion(cast<VPRegionBlock>(Plan.getEntry()), Old2New, IAI); +} diff --git a/lib/Transforms/Vectorize/VPlan.h b/lib/Transforms/Vectorize/VPlan.h index 883e6f52369a..5c1b4a83c30e 100644 --- a/lib/Transforms/Vectorize/VPlan.h +++ b/lib/Transforms/Vectorize/VPlan.h @@ -38,6 +38,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist_node.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IRBuilder.h" #include <algorithm> #include <cassert> @@ -52,12 +53,14 @@ class LoopVectorizationCostModel; class BasicBlock; class DominatorTree; class InnerLoopVectorizer; -class InterleaveGroup; +template <class T> class InterleaveGroup; +class LoopInfo; class raw_ostream; class Value; class VPBasicBlock; class VPRegionBlock; class VPlan; +class VPlanSlp; /// A range of powers-of-2 vectorization factors with fixed start and /// adjustable end. The range includes start and excludes end, e.g.,: @@ -293,6 +296,10 @@ struct VPTransformState { /// of replication, maps the BasicBlock of the last replica created. SmallDenseMap<VPBasicBlock *, BasicBlock *> VPBB2IRBB; + /// Vector of VPBasicBlocks whose terminator instruction needs to be fixed + /// up at the end of vector code generation. + SmallVector<VPBasicBlock *, 8> VPBBsToFix; + CFGState() = default; } CFG; @@ -313,6 +320,9 @@ struct VPTransformState { /// Values they correspond to. VPValue2ValueTy VPValue2Value; + /// Hold the trip count of the scalar loop. + Value *TripCount = nullptr; + /// Hold a pointer to InnerLoopVectorizer to reuse its IR generation methods. InnerLoopVectorizer *ILV; @@ -600,10 +610,16 @@ public: /// the VPInstruction is also a single def-use vertex. class VPInstruction : public VPUser, public VPRecipeBase { friend class VPlanHCFGTransforms; + friend class VPlanSlp; public: /// VPlan opcodes, extending LLVM IR with idiomatics instructions. - enum { Not = Instruction::OtherOpsEnd + 1 }; + enum { + Not = Instruction::OtherOpsEnd + 1, + ICmpULE, + SLPLoad, + SLPStore, + }; private: typedef unsigned char OpcodeTy; @@ -613,6 +629,13 @@ private: /// modeled instruction. void generateInstruction(VPTransformState &State, unsigned Part); +protected: + Instruction *getUnderlyingInstr() { + return cast_or_null<Instruction>(getUnderlyingValue()); + } + + void setUnderlyingInstr(Instruction *I) { setUnderlyingValue(I); } + public: VPInstruction(unsigned Opcode, ArrayRef<VPValue *> Operands) : VPUser(VPValue::VPInstructionSC, Operands), @@ -626,6 +649,11 @@ public: return V->getVPValueID() == VPValue::VPInstructionSC; } + VPInstruction *clone() const { + SmallVector<VPValue *, 2> Operands(operands()); + return new VPInstruction(Opcode, Operands); + } + /// Method to support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const VPRecipeBase *R) { return R->getVPRecipeID() == VPRecipeBase::VPInstructionSC; @@ -643,6 +671,14 @@ public: /// Print the VPInstruction. void print(raw_ostream &O) const; + + /// Return true if this instruction may modify memory. + bool mayWriteToMemory() const { + // TODO: we can use attributes of the called function to rule out memory + // modifications. 
+ return Opcode == Instruction::Store || Opcode == Instruction::Call || + Opcode == Instruction::Invoke || Opcode == SLPStore; + } }; /// VPWidenRecipe is a recipe for producing a copy of vector type for each @@ -764,11 +800,15 @@ public: /// or stores into one wide load/store and shuffles. class VPInterleaveRecipe : public VPRecipeBase { private: - const InterleaveGroup *IG; + const InterleaveGroup<Instruction> *IG; + std::unique_ptr<VPUser> User; public: - VPInterleaveRecipe(const InterleaveGroup *IG) - : VPRecipeBase(VPInterleaveSC), IG(IG) {} + VPInterleaveRecipe(const InterleaveGroup<Instruction> *IG, VPValue *Mask) + : VPRecipeBase(VPInterleaveSC), IG(IG) { + if (Mask) // Create a VPInstruction to register as a user of the mask. + User.reset(new VPUser({Mask})); + } ~VPInterleaveRecipe() override = default; /// Method to support type inquiry through isa, cast, and dyn_cast. @@ -782,7 +822,7 @@ public: /// Print the recipe. void print(raw_ostream &O, const Twine &Indent) const override; - const InterleaveGroup *getInterleaveGroup() { return IG; } + const InterleaveGroup<Instruction> *getInterleaveGroup() { return IG; } }; /// VPReplicateRecipe replicates a given instruction producing multiple scalar @@ -1107,6 +1147,10 @@ private: // (operators '==' and '<'). SmallPtrSet<VPValue *, 16> VPExternalDefs; + /// Represents the backedge taken count of the original loop, for folding + /// the tail. + VPValue *BackedgeTakenCount = nullptr; + /// Holds a mapping between Values and their corresponding VPValue inside /// VPlan. Value2VPValueTy Value2VPValue; @@ -1114,6 +1158,9 @@ private: /// Holds the VPLoopInfo analysis for this VPlan. VPLoopInfo VPLInfo; + /// Holds the condition bit values built during VPInstruction to VPRecipe transformation. + SmallVector<VPValue *, 4> VPCBVs; + public: VPlan(VPBlockBase *Entry = nullptr) : Entry(Entry) {} @@ -1121,9 +1168,14 @@ public: if (Entry) VPBlockBase::deleteCFG(Entry); for (auto &MapEntry : Value2VPValue) - delete MapEntry.second; + if (MapEntry.second != BackedgeTakenCount) + delete MapEntry.second; + if (BackedgeTakenCount) + delete BackedgeTakenCount; // Delete once, if in Value2VPValue or not. for (VPValue *Def : VPExternalDefs) delete Def; + for (VPValue *CBV : VPCBVs) + delete CBV; } /// Generate the IR code for this VPlan. @@ -1134,6 +1186,13 @@ public: VPBlockBase *setEntry(VPBlockBase *Block) { return Entry = Block; } + /// The backedge taken count of the original loop. + VPValue *getOrCreateBackedgeTakenCount() { + if (!BackedgeTakenCount) + BackedgeTakenCount = new VPValue(); + return BackedgeTakenCount; + } + void addVF(unsigned VF) { VFs.insert(VF); } bool hasVF(unsigned VF) { return VFs.count(VF); } @@ -1148,6 +1207,11 @@ public: VPExternalDefs.insert(VPVal); } + /// Add \p CBV to the vector of condition bit values. + void addCBV(VPValue *CBV) { + VPCBVs.push_back(CBV); + } + void addVPValue(Value *V) { assert(V && "Trying to add a null Value to VPlan"); assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); @@ -1429,6 +1493,144 @@ public: } }; +class VPInterleavedAccessInfo { +private: + DenseMap<VPInstruction *, InterleaveGroup<VPInstruction> *> + InterleaveGroupMap; + + /// Type for mapping of instruction based interleave groups to VPInstruction + /// interleave groups + using Old2NewTy = DenseMap<InterleaveGroup<Instruction> *, + InterleaveGroup<VPInstruction> *>; + + /// Recursively \p Region and populate VPlan based interleave groups based on + /// \p IAI. 
+ void visitRegion(VPRegionBlock *Region, Old2NewTy &Old2New, + InterleavedAccessInfo &IAI); + /// Recursively traverse \p Block and populate VPlan based interleave groups + /// based on \p IAI. + void visitBlock(VPBlockBase *Block, Old2NewTy &Old2New, + InterleavedAccessInfo &IAI); + +public: + VPInterleavedAccessInfo(VPlan &Plan, InterleavedAccessInfo &IAI); + + ~VPInterleavedAccessInfo() { + SmallPtrSet<InterleaveGroup<VPInstruction> *, 4> DelSet; + // Avoid releasing a pointer twice. + for (auto &I : InterleaveGroupMap) + DelSet.insert(I.second); + for (auto *Ptr : DelSet) + delete Ptr; + } + + /// Get the interleave group that \p Instr belongs to. + /// + /// \returns nullptr if doesn't have such group. + InterleaveGroup<VPInstruction> * + getInterleaveGroup(VPInstruction *Instr) const { + if (InterleaveGroupMap.count(Instr)) + return InterleaveGroupMap.find(Instr)->second; + return nullptr; + } +}; + +/// Class that maps (parts of) an existing VPlan to trees of combined +/// VPInstructions. +class VPlanSlp { +private: + enum class OpMode { Failed, Load, Opcode }; + + /// A DenseMapInfo implementation for using SmallVector<VPValue *, 4> as + /// DenseMap keys. + struct BundleDenseMapInfo { + static SmallVector<VPValue *, 4> getEmptyKey() { + return {reinterpret_cast<VPValue *>(-1)}; + } + + static SmallVector<VPValue *, 4> getTombstoneKey() { + return {reinterpret_cast<VPValue *>(-2)}; + } + + static unsigned getHashValue(const SmallVector<VPValue *, 4> &V) { + return static_cast<unsigned>(hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector<VPValue *, 4> &LHS, + const SmallVector<VPValue *, 4> &RHS) { + return LHS == RHS; + } + }; + + /// Mapping of values in the original VPlan to a combined VPInstruction. + DenseMap<SmallVector<VPValue *, 4>, VPInstruction *, BundleDenseMapInfo> + BundleToCombined; + + VPInterleavedAccessInfo &IAI; + + /// Basic block to operate on. For now, only instructions in a single BB are + /// considered. + const VPBasicBlock &BB; + + /// Indicates whether we managed to combine all visited instructions or not. + bool CompletelySLP = true; + + /// Width of the widest combined bundle in bits. + unsigned WidestBundleBits = 0; + + using MultiNodeOpTy = + typename std::pair<VPInstruction *, SmallVector<VPValue *, 4>>; + + // Input operand bundles for the current multi node. Each multi node operand + // bundle contains values not matching the multi node's opcode. They will + // be reordered in reorderMultiNodeOps, once we completed building a + // multi node. + SmallVector<MultiNodeOpTy, 4> MultiNodeOps; + + /// Indicates whether we are building a multi node currently. + bool MultiNodeActive = false; + + /// Check if we can vectorize Operands together. + bool areVectorizable(ArrayRef<VPValue *> Operands) const; + + /// Add combined instruction \p New for the bundle \p Operands. + void addCombined(ArrayRef<VPValue *> Operands, VPInstruction *New); + + /// Indicate we hit a bundle we failed to combine. Returns nullptr for now. + VPInstruction *markFailed(); + + /// Reorder operands in the multi node to maximize sequential memory access + /// and commutative operations. + SmallVector<MultiNodeOpTy, 4> reorderMultiNodeOps(); + + /// Choose the best candidate to use for the lane after \p Last. The set of + /// candidates to choose from are values with an opcode matching \p Last's + /// or loads consecutive to \p Last. 
+ std::pair<OpMode, VPValue *> getBest(OpMode Mode, VPValue *Last, + SmallPtrSetImpl<VPValue *> &Candidates, + VPInterleavedAccessInfo &IAI); + + /// Print bundle \p Values to dbgs(). + void dumpBundle(ArrayRef<VPValue *> Values); + +public: + VPlanSlp(VPInterleavedAccessInfo &IAI, VPBasicBlock &BB) : IAI(IAI), BB(BB) {} + + ~VPlanSlp() { + for (auto &KV : BundleToCombined) + delete KV.second; + } + + /// Tries to build an SLP tree rooted at \p Operands and returns a + /// VPInstruction combining \p Operands, if they can be combined. + VPInstruction *buildGraph(ArrayRef<VPValue *> Operands); + + /// Return the width of the widest combined bundle in bits. + unsigned getWidestBundleBits() const { return WidestBundleBits; } + + /// Return true if all visited instruction can be combined. + bool isCompletelySLP() const { return CompletelySLP; } +}; } // end namespace llvm #endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H diff --git a/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp b/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp index b6307acb9474..0f42694e193b 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp +++ b/lib/Transforms/Vectorize/VPlanHCFGBuilder.cpp @@ -268,7 +268,7 @@ VPRegionBlock *PlainCFGBuilder::buildPlainCFG() { // Set VPBB successors. We create empty VPBBs for successors if they don't // exist already. Recipes will be created when the successor is visited // during the RPO traversal. - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); assert(TI && "Terminator expected."); unsigned NumSuccs = TI->getNumSuccessors(); diff --git a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp index e3cbab077e61..3ad7fc7e7b96 100644 --- a/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp +++ b/lib/Transforms/Vectorize/VPlanHCFGTransforms.cpp @@ -24,6 +24,18 @@ void VPlanHCFGTransforms::VPInstructionsToVPRecipes( VPRegionBlock *TopRegion = dyn_cast<VPRegionBlock>(Plan->getEntry()); ReversePostOrderTraversal<VPBlockBase *> RPOT(TopRegion->getEntry()); + + // Condition bit VPValues get deleted during transformation to VPRecipes. + // Create new VPValues and save away as condition bits. These will be deleted + // after finalizing the vector IR basic blocks. + for (VPBlockBase *Base : RPOT) { + VPBasicBlock *VPBB = Base->getEntryBasicBlock(); + if (auto *CondBit = VPBB->getCondBit()) { + auto *NCondBit = new VPValue(CondBit->getUnderlyingValue()); + VPBB->setCondBit(NCondBit); + Plan->addCBV(NCondBit); + } + } for (VPBlockBase *Base : RPOT) { // Do not widen instructions in pre-header and exit blocks. if (Base->getNumPredecessors() == 0 || Base->getNumSuccessors() == 0) diff --git a/lib/Transforms/Vectorize/VPlanSLP.cpp b/lib/Transforms/Vectorize/VPlanSLP.cpp new file mode 100644 index 000000000000..ad3a85a6f760 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanSLP.cpp @@ -0,0 +1,468 @@ +//===- VPlanSLP.cpp - SLP Analysis based on VPlan -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// This file implements SLP analysis based on VPlan. The analysis is based on +/// the ideas described in +/// +/// Look-ahead SLP: auto-vectorization in the presence of commutative +/// operations, CGO 2018 by Vasileios Porpodas, Rodrigo C. O. Rocha, +/// Luís F. W. 
Góes +/// +//===----------------------------------------------------------------------===// + +#include "VPlan.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <cassert> +#include <iterator> +#include <string> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "vplan-slp" + +// Number of levels to look ahead when re-ordering multi node operands. +static unsigned LookaheadMaxDepth = 5; + +VPInstruction *VPlanSlp::markFailed() { + // FIXME: Currently this is used to signal we hit instructions we cannot + // trivially SLP'ize. + CompletelySLP = false; + return nullptr; +} + +void VPlanSlp::addCombined(ArrayRef<VPValue *> Operands, VPInstruction *New) { + if (all_of(Operands, [](VPValue *V) { + return cast<VPInstruction>(V)->getUnderlyingInstr(); + })) { + unsigned BundleSize = 0; + for (VPValue *V : Operands) { + Type *T = cast<VPInstruction>(V)->getUnderlyingInstr()->getType(); + assert(!T->isVectorTy() && "Only scalar types supported for now"); + BundleSize += T->getScalarSizeInBits(); + } + WidestBundleBits = std::max(WidestBundleBits, BundleSize); + } + + auto Res = BundleToCombined.try_emplace(to_vector<4>(Operands), New); + assert(Res.second && + "Already created a combined instruction for the operand bundle"); + (void)Res; +} + +bool VPlanSlp::areVectorizable(ArrayRef<VPValue *> Operands) const { + // Currently we only support VPInstructions. + if (!all_of(Operands, [](VPValue *Op) { + return Op && isa<VPInstruction>(Op) && + cast<VPInstruction>(Op)->getUnderlyingInstr(); + })) { + LLVM_DEBUG(dbgs() << "VPSLP: not all operands are VPInstructions\n"); + return false; + } + + // Check if opcodes and type width agree for all instructions in the bundle. + // FIXME: Differing widths/opcodes can be handled by inserting additional + // instructions. + // FIXME: Deal with non-primitive types. + const Instruction *OriginalInstr = + cast<VPInstruction>(Operands[0])->getUnderlyingInstr(); + unsigned Opcode = OriginalInstr->getOpcode(); + unsigned Width = OriginalInstr->getType()->getPrimitiveSizeInBits(); + if (!all_of(Operands, [Opcode, Width](VPValue *Op) { + const Instruction *I = cast<VPInstruction>(Op)->getUnderlyingInstr(); + return I->getOpcode() == Opcode && + I->getType()->getPrimitiveSizeInBits() == Width; + })) { + LLVM_DEBUG(dbgs() << "VPSLP: Opcodes do not agree \n"); + return false; + } + + // For now, all operands must be defined in the same BB. 
+ if (any_of(Operands, [this](VPValue *Op) { + return cast<VPInstruction>(Op)->getParent() != &this->BB; + })) { + LLVM_DEBUG(dbgs() << "VPSLP: operands in different BBs\n"); + return false; + } + + if (any_of(Operands, + [](VPValue *Op) { return Op->hasMoreThanOneUniqueUser(); })) { + LLVM_DEBUG(dbgs() << "VPSLP: Some operands have multiple users.\n"); + return false; + } + + // For loads, check that there are no instructions writing to memory in + // between them. + // TODO: we only have to forbid instructions writing to memory that could + // interfere with any of the loads in the bundle + if (Opcode == Instruction::Load) { + unsigned LoadsSeen = 0; + VPBasicBlock *Parent = cast<VPInstruction>(Operands[0])->getParent(); + for (auto &I : *Parent) { + auto *VPI = cast<VPInstruction>(&I); + if (VPI->getOpcode() == Instruction::Load && + std::find(Operands.begin(), Operands.end(), VPI) != Operands.end()) + LoadsSeen++; + + if (LoadsSeen == Operands.size()) + break; + if (LoadsSeen > 0 && VPI->mayWriteToMemory()) { + LLVM_DEBUG( + dbgs() << "VPSLP: instruction modifying memory between loads\n"); + return false; + } + } + + if (!all_of(Operands, [](VPValue *Op) { + return cast<LoadInst>(cast<VPInstruction>(Op)->getUnderlyingInstr()) + ->isSimple(); + })) { + LLVM_DEBUG(dbgs() << "VPSLP: only simple loads are supported.\n"); + return false; + } + } + + if (Opcode == Instruction::Store) + if (!all_of(Operands, [](VPValue *Op) { + return cast<StoreInst>(cast<VPInstruction>(Op)->getUnderlyingInstr()) + ->isSimple(); + })) { + LLVM_DEBUG(dbgs() << "VPSLP: only simple stores are supported.\n"); + return false; + } + + return true; +} + +static SmallVector<VPValue *, 4> getOperands(ArrayRef<VPValue *> Values, + unsigned OperandIndex) { + SmallVector<VPValue *, 4> Operands; + for (VPValue *V : Values) { + auto *U = cast<VPUser>(V); + Operands.push_back(U->getOperand(OperandIndex)); + } + return Operands; +} + +static bool areCommutative(ArrayRef<VPValue *> Values) { + return Instruction::isCommutative( + cast<VPInstruction>(Values[0])->getOpcode()); +} + +static SmallVector<SmallVector<VPValue *, 4>, 4> +getOperands(ArrayRef<VPValue *> Values) { + SmallVector<SmallVector<VPValue *, 4>, 4> Result; + auto *VPI = cast<VPInstruction>(Values[0]); + + switch (VPI->getOpcode()) { + case Instruction::Load: + llvm_unreachable("Loads terminate a tree, no need to get operands"); + case Instruction::Store: + Result.push_back(getOperands(Values, 0)); + break; + default: + for (unsigned I = 0, NumOps = VPI->getNumOperands(); I < NumOps; ++I) + Result.push_back(getOperands(Values, I)); + break; + } + + return Result; +} + +/// Returns the opcode of Values or ~0 if they do not all agree. +static Optional<unsigned> getOpcode(ArrayRef<VPValue *> Values) { + unsigned Opcode = cast<VPInstruction>(Values[0])->getOpcode(); + if (any_of(Values, [Opcode](VPValue *V) { + return cast<VPInstruction>(V)->getOpcode() != Opcode; + })) + return None; + return {Opcode}; +} + +/// Returns true if A and B access sequential memory if they are loads or +/// stores or if they have identical opcodes otherwise. 
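// A toy restatement of the predicate documented above, using invented plain
// structs instead of VPInstructions and interleave groups: memory accesses
// "fit" when they sit back-to-back in the same interleave group, while
// non-memory operations only need matching opcodes.
struct ToyAccess {
  unsigned Opcode;
  bool IsMemory;  // load or store
  int GroupId;    // interleave group id, -1 when not in any group
  unsigned Index; // position within the group
};

static bool toyConsecutiveOrMatch(const ToyAccess &A, const ToyAccess &B) {
  if (A.Opcode != B.Opcode)
    return false;
  if (!A.IsMemory)
    return true; // non-memory ops only need matching opcodes
  return A.GroupId >= 0 && A.GroupId == B.GroupId && A.Index + 1 == B.Index;
}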
+static bool areConsecutiveOrMatch(VPInstruction *A, VPInstruction *B, + VPInterleavedAccessInfo &IAI) { + if (A->getOpcode() != B->getOpcode()) + return false; + + if (A->getOpcode() != Instruction::Load && + A->getOpcode() != Instruction::Store) + return true; + auto *GA = IAI.getInterleaveGroup(A); + auto *GB = IAI.getInterleaveGroup(B); + + return GA && GB && GA == GB && GA->getIndex(A) + 1 == GB->getIndex(B); +} + +/// Implements getLAScore from Listing 7 in the paper. +/// Traverses and compares operands of V1 and V2 to MaxLevel. +static unsigned getLAScore(VPValue *V1, VPValue *V2, unsigned MaxLevel, + VPInterleavedAccessInfo &IAI) { + if (!isa<VPInstruction>(V1) || !isa<VPInstruction>(V2)) + return 0; + + if (MaxLevel == 0) + return (unsigned)areConsecutiveOrMatch(cast<VPInstruction>(V1), + cast<VPInstruction>(V2), IAI); + + unsigned Score = 0; + for (unsigned I = 0, EV1 = cast<VPUser>(V1)->getNumOperands(); I < EV1; ++I) + for (unsigned J = 0, EV2 = cast<VPUser>(V2)->getNumOperands(); J < EV2; ++J) + Score += getLAScore(cast<VPUser>(V1)->getOperand(I), + cast<VPUser>(V2)->getOperand(J), MaxLevel - 1, IAI); + return Score; +} + +std::pair<VPlanSlp::OpMode, VPValue *> +VPlanSlp::getBest(OpMode Mode, VPValue *Last, + SmallPtrSetImpl<VPValue *> &Candidates, + VPInterleavedAccessInfo &IAI) { + assert((Mode == OpMode::Load || Mode == OpMode::Opcode) && + "Currently we only handle load and commutative opcodes"); + LLVM_DEBUG(dbgs() << " getBest\n"); + + SmallVector<VPValue *, 4> BestCandidates; + LLVM_DEBUG(dbgs() << " Candidates for " + << *cast<VPInstruction>(Last)->getUnderlyingInstr() << " "); + for (auto *Candidate : Candidates) { + auto *LastI = cast<VPInstruction>(Last); + auto *CandidateI = cast<VPInstruction>(Candidate); + if (areConsecutiveOrMatch(LastI, CandidateI, IAI)) { + LLVM_DEBUG(dbgs() << *cast<VPInstruction>(Candidate)->getUnderlyingInstr() + << " "); + BestCandidates.push_back(Candidate); + } + } + LLVM_DEBUG(dbgs() << "\n"); + + if (BestCandidates.empty()) + return {OpMode::Failed, nullptr}; + + if (BestCandidates.size() == 1) + return {Mode, BestCandidates[0]}; + + VPValue *Best = nullptr; + unsigned BestScore = 0; + for (unsigned Depth = 1; Depth < LookaheadMaxDepth; Depth++) { + unsigned PrevScore = ~0u; + bool AllSame = true; + + // FIXME: Avoid visiting the same operands multiple times. 
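// A plain C++ sketch (invented ToyNode type) of the recursive look-ahead score
// that the candidate loop below evaluates at increasing depth: at depth 0 two
// nodes score 1 when they match, and at higher depths the scores of all
// operand pairs are summed, so candidates whose operand trees also line up
// score higher.
#include <vector>

struct ToyNode {
  unsigned Opcode;
  std::vector<const ToyNode *> Operands;
};

static unsigned toyLookAheadScore(const ToyNode *A, const ToyNode *B,
                                  unsigned Depth) {
  if (!A || !B)
    return 0;
  if (Depth == 0)
    return A->Opcode == B->Opcode ? 1u : 0u; // stand-in for areConsecutiveOrMatch
  unsigned Score = 0;
  for (const ToyNode *OpA : A->Operands)
    for (const ToyNode *OpB : B->Operands)
      Score += toyLookAheadScore(OpA, OpB, Depth - 1);
  return Score;
}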
+ for (auto *Candidate : BestCandidates) { + unsigned Score = getLAScore(Last, Candidate, Depth, IAI); + if (PrevScore == ~0u) + PrevScore = Score; + if (PrevScore != Score) + AllSame = false; + PrevScore = Score; + + if (Score > BestScore) { + BestScore = Score; + Best = Candidate; + } + } + if (!AllSame) + break; + } + LLVM_DEBUG(dbgs() << "Found best " + << *cast<VPInstruction>(Best)->getUnderlyingInstr() + << "\n"); + Candidates.erase(Best); + + return {Mode, Best}; +} + +SmallVector<VPlanSlp::MultiNodeOpTy, 4> VPlanSlp::reorderMultiNodeOps() { + SmallVector<MultiNodeOpTy, 4> FinalOrder; + SmallVector<OpMode, 4> Mode; + FinalOrder.reserve(MultiNodeOps.size()); + Mode.reserve(MultiNodeOps.size()); + + LLVM_DEBUG(dbgs() << "Reordering multinode\n"); + + for (auto &Operands : MultiNodeOps) { + FinalOrder.push_back({Operands.first, {Operands.second[0]}}); + if (cast<VPInstruction>(Operands.second[0])->getOpcode() == + Instruction::Load) + Mode.push_back(OpMode::Load); + else + Mode.push_back(OpMode::Opcode); + } + + for (unsigned Lane = 1, E = MultiNodeOps[0].second.size(); Lane < E; ++Lane) { + LLVM_DEBUG(dbgs() << " Finding best value for lane " << Lane << "\n"); + SmallPtrSet<VPValue *, 4> Candidates; + LLVM_DEBUG(dbgs() << " Candidates "); + for (auto Ops : MultiNodeOps) { + LLVM_DEBUG( + dbgs() << *cast<VPInstruction>(Ops.second[Lane])->getUnderlyingInstr() + << " "); + Candidates.insert(Ops.second[Lane]); + } + LLVM_DEBUG(dbgs() << "\n"); + + for (unsigned Op = 0, E = MultiNodeOps.size(); Op < E; ++Op) { + LLVM_DEBUG(dbgs() << " Checking " << Op << "\n"); + if (Mode[Op] == OpMode::Failed) + continue; + + VPValue *Last = FinalOrder[Op].second[Lane - 1]; + std::pair<OpMode, VPValue *> Res = + getBest(Mode[Op], Last, Candidates, IAI); + if (Res.second) + FinalOrder[Op].second.push_back(Res.second); + else + // TODO: handle this case + FinalOrder[Op].second.push_back(markFailed()); + } + } + + return FinalOrder; +} + +void VPlanSlp::dumpBundle(ArrayRef<VPValue *> Values) { + dbgs() << " Ops: "; + for (auto Op : Values) + if (auto *Instr = cast_or_null<VPInstruction>(Op)->getUnderlyingInstr()) + dbgs() << *Instr << " | "; + else + dbgs() << " nullptr | "; + dbgs() << "\n"; +} + +VPInstruction *VPlanSlp::buildGraph(ArrayRef<VPValue *> Values) { + assert(!Values.empty() && "Need some operands!"); + + // If we already visited this instruction bundle, re-use the existing node + auto I = BundleToCombined.find(to_vector<4>(Values)); + if (I != BundleToCombined.end()) { +#ifndef NDEBUG + // Check that the resulting graph is a tree. If we re-use a node, this means + // its values have multiple users. We only allow this, if all users of each + // value are the same instruction. 
+ for (auto *V : Values) { + auto UI = V->user_begin(); + auto *FirstUser = *UI++; + while (UI != V->user_end()) { + assert(*UI == FirstUser && "Currently we only support SLP trees."); + UI++; + } + } +#endif + return I->second; + } + + // Dump inputs + LLVM_DEBUG({ + dbgs() << "buildGraph: "; + dumpBundle(Values); + }); + + if (!areVectorizable(Values)) + return markFailed(); + + assert(getOpcode(Values) && "Opcodes for all values must match"); + unsigned ValuesOpcode = getOpcode(Values).getValue(); + + SmallVector<VPValue *, 4> CombinedOperands; + if (areCommutative(Values)) { + bool MultiNodeRoot = !MultiNodeActive; + MultiNodeActive = true; + for (auto &Operands : getOperands(Values)) { + LLVM_DEBUG({ + dbgs() << " Visiting Commutative"; + dumpBundle(Operands); + }); + + auto OperandsOpcode = getOpcode(Operands); + if (OperandsOpcode && OperandsOpcode == getOpcode(Values)) { + LLVM_DEBUG(dbgs() << " Same opcode, continue building\n"); + CombinedOperands.push_back(buildGraph(Operands)); + } else { + LLVM_DEBUG(dbgs() << " Adding multinode Ops\n"); + // Create dummy VPInstruction, which will we replace later by the + // re-ordered operand. + VPInstruction *Op = new VPInstruction(0, {}); + CombinedOperands.push_back(Op); + MultiNodeOps.emplace_back(Op, Operands); + } + } + + if (MultiNodeRoot) { + LLVM_DEBUG(dbgs() << "Reorder \n"); + MultiNodeActive = false; + + auto FinalOrder = reorderMultiNodeOps(); + + MultiNodeOps.clear(); + for (auto &Ops : FinalOrder) { + VPInstruction *NewOp = buildGraph(Ops.second); + Ops.first->replaceAllUsesWith(NewOp); + for (unsigned i = 0; i < CombinedOperands.size(); i++) + if (CombinedOperands[i] == Ops.first) + CombinedOperands[i] = NewOp; + delete Ops.first; + Ops.first = NewOp; + } + LLVM_DEBUG(dbgs() << "Found final order\n"); + } + } else { + LLVM_DEBUG(dbgs() << " NonCommuntative\n"); + if (ValuesOpcode == Instruction::Load) + for (VPValue *V : Values) + CombinedOperands.push_back(cast<VPInstruction>(V)->getOperand(0)); + else + for (auto &Operands : getOperands(Values)) + CombinedOperands.push_back(buildGraph(Operands)); + } + + unsigned Opcode; + switch (ValuesOpcode) { + case Instruction::Load: + Opcode = VPInstruction::SLPLoad; + break; + case Instruction::Store: + Opcode = VPInstruction::SLPStore; + break; + default: + Opcode = ValuesOpcode; + break; + } + + if (!CompletelySLP) + return markFailed(); + + assert(CombinedOperands.size() > 0 && "Need more some operands"); + auto *VPI = new VPInstruction(Opcode, CombinedOperands); + VPI->setUnderlyingInstr(cast<VPInstruction>(Values[0])->getUnderlyingInstr()); + + LLVM_DEBUG(dbgs() << "Create VPInstruction "; VPI->print(dbgs()); + cast<VPInstruction>(Values[0])->print(dbgs()); dbgs() << "\n"); + addCombined(Values, VPI); + return VPI; +} diff --git a/lib/Transforms/Vectorize/VPlanValue.h b/lib/Transforms/Vectorize/VPlanValue.h index 08f142915b49..b473579b699f 100644 --- a/lib/Transforms/Vectorize/VPlanValue.h +++ b/lib/Transforms/Vectorize/VPlanValue.h @@ -38,6 +38,10 @@ class VPUser; // and live-outs which the VPlan will need to fix accordingly. class VPValue { friend class VPBuilder; + friend class VPlanHCFGTransforms; + friend class VPBasicBlock; + friend class VPInterleavedAccessInfo; + private: const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). @@ -102,6 +106,20 @@ public: const_user_range users() const { return const_user_range(user_begin(), user_end()); } + + /// Returns true if the value has more than one unique user. 
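// Standalone restatement (plain C++, toy template, not the VPValue API) of the
// predicate declared below: repeated uses by the same user, e.g. the two
// operands of x + x, still count as a single unique user; only a second
// distinct user is what the SLP bundling checks above reject.
#include <algorithm>
#include <vector>

template <typename UserT>
static bool toyHasMoreThanOneUniqueUser(const std::vector<UserT *> &Users) {
  if (Users.empty())
    return false;
  // Any user that differs from the first one makes it non-unique.
  return std::any_of(Users.begin() + 1, Users.end(),
                     [&](UserT *U) { return U != Users.front(); });
}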
+ bool hasMoreThanOneUniqueUser() { + if (getNumUsers() == 0) + return false; + + // Check if all users match the first user. + auto Current = std::next(user_begin()); + while (Current != user_end() && *user_begin() == *Current) + Current++; + return Current != user_end(); + } + + void replaceAllUsesWith(VPValue *New); }; typedef DenseMap<Value *, VPValue *> Value2VPValueTy; @@ -147,6 +165,8 @@ public: return Operands[N]; } + void setOperand(unsigned I, VPValue *New) { Operands[I] = New; } + typedef SmallVectorImpl<VPValue *>::iterator operand_iterator; typedef SmallVectorImpl<VPValue *>::const_iterator const_operand_iterator; typedef iterator_range<operand_iterator> operand_range; diff --git a/lib/Transforms/Vectorize/Vectorize.cpp b/lib/Transforms/Vectorize/Vectorize.cpp index f62a88558328..559ab1968844 100644 --- a/lib/Transforms/Vectorize/Vectorize.cpp +++ b/lib/Transforms/Vectorize/Vectorize.cpp @@ -27,7 +27,7 @@ using namespace llvm; void llvm::initializeVectorization(PassRegistry &Registry) { initializeLoopVectorizePass(Registry); initializeSLPVectorizerPass(Registry); - initializeLoadStoreVectorizerPass(Registry); + initializeLoadStoreVectorizerLegacyPassPass(Registry); } void LLVMInitializeVectorization(LLVMPassRegistryRef R) { |